diff --git a/hudi-client/hudi-client-common/src/main/java/org/apache/hudi/keygen/TimestampBasedAvroKeyGenerator.java b/hudi-client/hudi-client-common/src/main/java/org/apache/hudi/keygen/TimestampBasedAvroKeyGenerator.java
index bce7e24c57a5f..0c788188a9c02 100644
--- a/hudi-client/hudi-client-common/src/main/java/org/apache/hudi/keygen/TimestampBasedAvroKeyGenerator.java
+++ b/hudi-client/hudi-client-common/src/main/java/org/apache/hudi/keygen/TimestampBasedAvroKeyGenerator.java
@@ -65,6 +65,29 @@ public enum TimestampType implements Serializable {
protected final boolean encodePartitionPath;
+ /**
+ * Supported configs.
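+ *
+ * Example (illustrative values only), showing the input date format prop holding a
+ * list of formats split by the delimiter regex:
+ *   hoodie.deltastreamer.keygen.timebased.timestamp.type = DATE_STRING
+ *   hoodie.deltastreamer.keygen.timebased.input.dateformat = yyyy-MM-dd HH:mm:ss,yyyy-MM-dd
+ *   hoodie.deltastreamer.keygen.timebased.input.dateformat.list.delimiter.regex = ,
+ *   hoodie.deltastreamer.keygen.timebased.output.dateformat = yyyyMMdd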
+ */
+ public static class Config {
+
+ // One value from TimestampType above
+ public static final String TIMESTAMP_TYPE_FIELD_PROP = "hoodie.deltastreamer.keygen.timebased.timestamp.type";
+ public static final String INPUT_TIME_UNIT =
+ "hoodie.deltastreamer.keygen.timebased.timestamp.scalar.time.unit";
+ // This prop can now accept a list of input date formats.
+ public static final String TIMESTAMP_INPUT_DATE_FORMAT_PROP =
+ "hoodie.deltastreamer.keygen.timebased.input.dateformat";
+ public static final String TIMESTAMP_INPUT_DATE_FORMAT_LIST_DELIMITER_REGEX_PROP = "hoodie.deltastreamer.keygen.timebased.input.dateformat.list.delimiter.regex";
+ public static final String TIMESTAMP_INPUT_TIMEZONE_FORMAT_PROP = "hoodie.deltastreamer.keygen.timebased.input.timezone";
+ public static final String TIMESTAMP_OUTPUT_DATE_FORMAT_PROP =
+ "hoodie.deltastreamer.keygen.timebased.output.dateformat";
+ // Still keeping this prop for backward compatibility so that existing users' functionality does not break.
+ public static final String TIMESTAMP_TIMEZONE_FORMAT_PROP =
+ "hoodie.deltastreamer.keygen.timebased.timezone";
+ public static final String TIMESTAMP_OUTPUT_TIMEZONE_FORMAT_PROP = "hoodie.deltastreamer.keygen.timebased.output.timezone";
+ static final String DATE_TIME_PARSER_PROP = "hoodie.deltastreamer.keygen.datetime.parser.class";
+ }
+
public TimestampBasedAvroKeyGenerator(TypedProperties config) throws IOException {
this(config, config.getString(KeyGeneratorOptions.RECORDKEY_FIELD_NAME.key()),
config.getString(KeyGeneratorOptions.PARTITIONPATH_FIELD_NAME.key()));
diff --git a/hudi-client/hudi-spark-client/src/main/scala/org/apache/hudi/SparkAdapterSupport.scala b/hudi-client/hudi-spark-client/src/main/scala/org/apache/hudi/SparkAdapterSupport.scala
index 4cc1bca1d614d..8e4c6376dd851 100644
--- a/hudi-client/hudi-spark-client/src/main/scala/org/apache/hudi/SparkAdapterSupport.scala
+++ b/hudi-client/hudi-spark-client/src/main/scala/org/apache/hudi/SparkAdapterSupport.scala
@@ -27,7 +27,11 @@ import org.apache.spark.sql.hudi.SparkAdapter
trait SparkAdapterSupport {
lazy val sparkAdapter: SparkAdapter = {
- val adapterClass = if (HoodieSparkUtils.isSpark3) {
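+ // Pick the adapter matching the running Spark version; the more specific Spark 3.1/3.2
+ // adapters take precedence over the generic Spark3Adapter fallback.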
+ val adapterClass = if (HoodieSparkUtils.isSpark3_1) {
+ "org.apache.spark.sql.adapter.Spark3_1Adapter"
+ } else if (HoodieSparkUtils.isSpark3_2) {
+ "org.apache.spark.sql.adapter.Spark3_2Adapter"
+ } else if (HoodieSparkUtils.isSpark3) {
"org.apache.spark.sql.adapter.Spark3Adapter"
} else {
"org.apache.spark.sql.adapter.Spark2Adapter"
diff --git a/hudi-client/hudi-spark-client/src/main/scala/org/apache/spark/sql/hudi/SparkAdapter.scala b/hudi-client/hudi-spark-client/src/main/scala/org/apache/spark/sql/hudi/SparkAdapter.scala
index 20b4d3cc1be73..9509e3fef4424 100644
--- a/hudi-client/hudi-spark-client/src/main/scala/org/apache/spark/sql/hudi/SparkAdapter.scala
+++ b/hudi-client/hudi-spark-client/src/main/scala/org/apache/spark/sql/hudi/SparkAdapter.scala
@@ -18,7 +18,6 @@
package org.apache.spark.sql.hudi
-import org.apache.hudi.HoodieSparkUtils.sparkAdapter
import org.apache.hudi.client.utils.SparkRowSerDe
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.catalyst.catalog.CatalogTable
diff --git a/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/analysis/HoodieAnalysis.scala b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/analysis/HoodieAnalysis.scala
index 12bbd64851001..17f3c1bc4f217 100644
--- a/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/analysis/HoodieAnalysis.scala
+++ b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/analysis/HoodieAnalysis.scala
@@ -17,17 +17,19 @@
package org.apache.spark.sql.hudi.analysis
+import org.apache.hudi.DataSourceReadOptions
import org.apache.hudi.DataSourceWriteOptions.MOR_TABLE_TYPE_OPT_VAL
import org.apache.hudi.common.model.HoodieRecord
import org.apache.hudi.common.util.ReflectionUtils
import org.apache.hudi.{HoodieSparkUtils, SparkAdapterSupport}
-import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedStar}
+import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation, UnresolvedStar}
+import org.apache.spark.sql.catalyst.catalog.{CatalogUtils, HoodieCatalogTable}
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, Expression, Literal, NamedExpression}
import org.apache.spark.sql.catalyst.plans.Inner
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.command._
-import org.apache.spark.sql.execution.datasources.{CreateTable, LogicalRelation}
+import org.apache.spark.sql.execution.datasources.{CreateTable, DataSource, LogicalRelation}
import org.apache.spark.sql.hudi.HoodieSqlCommonUtils.{getTableIdentifier, removeMetaFields}
import org.apache.spark.sql.hudi.HoodieSqlUtils._
import org.apache.spark.sql.hudi.command._
@@ -110,6 +112,7 @@ case class HoodieAnalysis(sparkSession: SparkSession) extends Rule[LogicalPlan]
case _ =>
l
}
+
// Convert to CreateHoodieTableAsSelectCommand
case CreateTable(table, mode, Some(query))
if query.resolved && sparkAdapter.isHoodieTable(table) =>
@@ -133,7 +136,7 @@ case class HoodieAnalysis(sparkSession: SparkSession) extends Rule[LogicalPlan]
// Convert to CompactionShowHoodiePathCommand
case CompactionShowOnPath(path, limit) =>
CompactionShowHoodiePathCommand(path, limit)
- case _=> plan
+ case _ => plan
}
}
}
@@ -350,6 +353,35 @@ case class HoodieResolveReferences(sparkSession: SparkSession) extends Rule[Logi
l
}
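+
+ // Resolve a time travel relation (e.g. `select * from tbl TIMESTAMP AS OF '<instant>'`)
+ // into a LogicalRelation that reads the Hudi table as of that instant, passing the
+ // instant down via DataSourceReadOptions.TIME_TRAVEL_AS_OF_INSTANT.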
+ case TimeTravelRelation(plan: UnresolvedRelation, timestamp, version) =>
+ // TODO: How to use version to perform time travel?
+ if (timestamp.isEmpty && version.nonEmpty) {
+ throw new AnalysisException(
+ "version expression is not support for time travel")
+ }
+
+ val tableIdentifier = sparkAdapter.toTableIdentifier(plan)
+ if (sparkAdapter.isHoodieTable(tableIdentifier, sparkSession)) {
+ val hoodieCatalogTable = HoodieCatalogTable(sparkSession, tableIdentifier)
+ val table = hoodieCatalogTable.table
+ val pathOption = table.storage.locationUri.map("path" -> CatalogUtils.URIToString(_))
+ val instantOption = Map(
+ DataSourceReadOptions.TIME_TRAVEL_AS_OF_INSTANT.key -> timestamp.get.toString())
+ val dataSource =
+ DataSource(
+ sparkSession,
+ userSpecifiedSchema = if (table.schema.isEmpty) None else Some(table.schema),
+ partitionColumns = table.partitionColumnNames,
+ bucketSpec = table.bucketSpec,
+ className = table.provider.get,
+ options = table.storage.properties ++ pathOption ++ instantOption,
+ catalogTable = Some(table))
+
+ LogicalRelation(dataSource.resolveRelation(checkFilesExist = false), table)
+ } else {
+ plan
+ }
+
case p => p
}
diff --git a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/TestTimeTravelParser.scala b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/TestTimeTravelParser.scala
new file mode 100644
index 0000000000000..4ad6674a300eb
--- /dev/null
+++ b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/TestTimeTravelParser.scala
@@ -0,0 +1,75 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hudi
+
+import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.catalyst.analysis.{UnresolvedRelation, UnresolvedStar}
+import org.apache.spark.sql.catalyst.expressions.Literal
+import org.apache.spark.sql.catalyst.plans.logical.{Project, TimeTravelRelation}
+
+class TestTimeTravelParser extends TestHoodieSqlBase {
+ private val parser = spark.sessionState.sqlParser
+
+ test("time travel of timestamp") {
+ val timeTravelPlan1 = parser.parsePlan("SELECT * FROM A.B " +
+ "TIMESTAMP AS OF '2019-01-29 00:37:58'")
+
+ assertResult(Project(Seq(UnresolvedStar(None)),
+ TimeTravelRelation(
+ UnresolvedRelation(new TableIdentifier("B", Option.apply("A"))),
+ Some(Literal("2019-01-29 00:37:58")),
+ None))) {
+ timeTravelPlan1
+ }
+
+ val timeTravelPlan2 = parser.parsePlan("SELECT * FROM A.B " +
+ "TIMESTAMP AS OF 1643119574")
+
+ assertResult(Project(Seq(UnresolvedStar(None)),
+ TimeTravelRelation(
+ UnresolvedRelation(new TableIdentifier("B", Option.apply("A"))),
+ Some(Literal(1643119574)),
+ None))) {
+ timeTravelPlan2
+ }
+ }
+
+ test("time travel of version") {
+ val timeTravelPlan1 = parser.parsePlan("SELECT * FROM A.B " +
+ "VERSION AS OF 'Snapshot123456789'")
+
+ assertResult(Project(Seq(UnresolvedStar(None)),
+ TimeTravelRelation(
+ UnresolvedRelation(new TableIdentifier("B", Option.apply("A"))),
+ None,
+ Some("Snapshot123456789")))) {
+ timeTravelPlan1
+ }
+
+ val timeTravelPlan2 = parser.parsePlan("SELECT * FROM A.B " +
+ "VERSION AS OF 'Snapshot01'")
+
+ assertResult(Project(Seq(UnresolvedStar(None)),
+ TimeTravelRelation(
+ UnresolvedRelation(new TableIdentifier("B", Option.apply("A"))),
+ None,
+ Some("Snapshot01")))) {
+ timeTravelPlan2
+ }
+ }
+}
diff --git a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/TestTimeTravelTable.scala b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/TestTimeTravelTable.scala
new file mode 100644
index 0000000000000..eaebd2d96f084
--- /dev/null
+++ b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/TestTimeTravelTable.scala
@@ -0,0 +1,237 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hudi
+
+import org.apache.hudi.common.table.HoodieTableMetaClient
+
+class TestTimeTravelTable extends TestHoodieSqlBase {
+ test("Test Insert and Update with time travel") {
+ withTempDir { tmp =>
+ val tableName1 = generateTableName
+ spark.sql(
+ s"""
+ |create table $tableName1 (
+ | id int,
+ | name string,
+ | price double,
+ | ts long
+ |) using hudi
+ | tblproperties (
+ | type = 'cow',
+ | primaryKey = 'id',
+ | preCombineField = 'ts'
+ | )
+ | location '${tmp.getCanonicalPath}/$tableName1'
+ """.stripMargin)
+
+ spark.sql(s"insert into $tableName1 values(1, 'a1', 10, 1000)")
+
+ val metaClient1 = HoodieTableMetaClient.builder()
+ .setBasePath(s"${tmp.getCanonicalPath}/$tableName1")
+ .setConf(spark.sessionState.newHadoopConf())
+ .build()
+
+ val instant1 = metaClient1.getActiveTimeline.getAllCommitsTimeline
+ .lastInstant().get().getTimestamp
+
+ spark.sql(s"insert into $tableName1 values(1, 'a2', 20, 2000)")
+
+ checkAnswer(s"select id, name, price, ts from $tableName1")(
+ Seq(1, "a2", 20.0, 2000)
+ )
+
+ // time travel from instant1
+ checkAnswer(
+ s"select id, name, price, ts from $tableName1 TIMESTAMP AS OF '$instant1'")(
+ Seq(1, "a1", 10.0, 1000)
+ )
+ }
+ }
+
+ test("Test Two Table's Union Join with time travel") {
+ withTempDir { tmp =>
+ Seq("cow", "mor").foreach { tableType =>
+ val tableName = generateTableName
+
+ val basePath = tmp.getCanonicalPath
+ val tableName1 = tableName + "_1"
+ val tableName2 = tableName + "_2"
+ val path1 = s"$basePath/$tableName1"
+ val path2 = s"$basePath/$tableName2"
+
+ spark.sql(
+ s"""
+ |create table $tableName1 (
+ | id int,
+ | name string,
+ | price double,
+ | ts long
+ |) using hudi
+ | tblproperties (
+ | type = '$tableType',
+ | primaryKey = 'id',
+ | preCombineField = 'ts'
+ | )
+ | location '$path1'
+ """.stripMargin)
+
+ spark.sql(
+ s"""
+ |create table $tableName2 (
+ | id int,
+ | name string,
+ | price double,
+ | ts long
+ |) using hudi
+ | tblproperties (
+ | type = '$tableType',
+ | primaryKey = 'id',
+ | preCombineField = 'ts'
+ | )
+ | location '$path2'
+ """.stripMargin)
+
+ spark.sql(s"insert into $tableName1 values(1, 'a1', 10, 1000)")
+ spark.sql(s"insert into $tableName1 values(2, 'a2', 20, 1000)")
+
+ checkAnswer(s"select id, name, price, ts from $tableName1")(
+ Seq(1, "a1", 10.0, 1000),
+ Seq(2, "a2", 20.0, 1000)
+ )
+
+ checkAnswer(s"select id, name, price, ts from $tableName1")(
+ Seq(1, "a1", 10.0, 1000),
+ Seq(2, "a2", 20.0, 1000)
+ )
+
+ spark.sql(s"insert into $tableName2 values(3, 'a3', 10, 1000)")
+ spark.sql(s"insert into $tableName2 values(4, 'a4', 20, 1000)")
+
+ checkAnswer(s"select id, name, price, ts from $tableName2")(
+ Seq(3, "a3", 10.0, 1000),
+ Seq(4, "a4", 20.0, 1000)
+ )
+
+ val metaClient1 = HoodieTableMetaClient.builder()
+ .setBasePath(path1)
+ .setConf(spark.sessionState.newHadoopConf())
+ .build()
+
+ val metaClient2 = HoodieTableMetaClient.builder()
+ .setBasePath(path2)
+ .setConf(spark.sessionState.newHadoopConf())
+ .build()
+
+ val instant1 = metaClient1.getActiveTimeline.getAllCommitsTimeline
+ .lastInstant().get().getTimestamp
+
+ val instant2 = metaClient2.getActiveTimeline.getAllCommitsTimeline
+ .lastInstant().get().getTimestamp
+
+ val sql =
+ s"""
+ |select id, name, price, ts from $tableName1 TIMESTAMP AS OF '$instant1' where id=1
+ |union
+ |select id, name, price, ts from $tableName2 TIMESTAMP AS OF '$instant2' where id>1
+ |""".stripMargin
+
+ checkAnswer(sql)(
+ Seq(1, "a1", 10.0, 1000),
+ Seq(3, "a3", 10.0, 1000),
+ Seq(4, "a4", 20.0, 1000)
+ )
+ }
+ }
+ }
+
+ test("Test Insert Into with time travel") {
+ withTempDir { tmp =>
+ // Create Non-Partitioned table
+ val tableName1 = generateTableName
+ spark.sql(
+ s"""
+ |create table $tableName1 (
+ | id int,
+ | name string,
+ | price double,
+ | ts long
+ |) using hudi
+ | tblproperties (
+ | type = 'cow',
+ | primaryKey = 'id',
+ | preCombineField = 'ts'
+ | )
+ | location '${tmp.getCanonicalPath}/$tableName1'
+ """.stripMargin)
+
+ spark.sql(s"insert into $tableName1 values(1, 'a1', 10, 1000)")
+
+ val metaClient1 = HoodieTableMetaClient.builder()
+ .setBasePath(s"${tmp.getCanonicalPath}/$tableName1")
+ .setConf(spark.sessionState.newHadoopConf())
+ .build()
+
+ val instant1 = metaClient1.getActiveTimeline.getAllCommitsTimeline
+ .lastInstant().get().getTimestamp
+
+ val tableName2 = generateTableName
+ // Create a partitioned table
+ spark.sql(
+ s"""
+ |create table $tableName2 (
+ | id int,
+ | name string,
+ | price double,
+ | ts long,
+ | dt string
+ |) using hudi
+ | tblproperties (primaryKey = 'id')
+ | partitioned by (dt)
+ | location '${tmp.getCanonicalPath}/$tableName2'
+ """.stripMargin)
+
+ // Insert into dynamic partition
+ spark.sql(
+ s"""
+ | insert into $tableName2
+ | select id, name, price, ts, '2022-02-14'
+ | from $tableName1 TIMESTAMP AS OF '$instant1'
+ """.stripMargin)
+ checkAnswer(s"select id, name, price, ts, dt from $tableName2")(
+ Seq(1, "a1", 10.0, 1000, "2022-02-14")
+ )
+
+ // Insert into static partition
+ spark.sql(
+ s"""
+ | insert into $tableName2 partition(dt = '2022-02-15')
+ | select 2 as id, 'a2' as name, price, ts from $tableName1 TIMESTAMP AS OF '$instant1'
+ """.stripMargin)
+ checkAnswer(
+ s"select id, name, price, ts, dt from $tableName2")(
+ Seq(1, "a1", 10.0, 1000, "2022-02-14"),
+ Seq(2, "a2", 10.0, 1000, "2022-02-15")
+ )
+ }
+ }
+}
diff --git a/hudi-spark-datasource/hudi-spark2/src/main/antlr4/imports/SqlBase.g4 b/hudi-spark-datasource/hudi-spark2/src/main/antlr4/imports/SqlBase.g4
index 2add2b030f538..a63457d842e5e 100644
--- a/hudi-spark-datasource/hudi-spark2/src/main/antlr4/imports/SqlBase.g4
+++ b/hudi-spark-datasource/hudi-spark2/src/main/antlr4/imports/SqlBase.g4
@@ -414,6 +414,11 @@ fromClause
: FROM relation (',' relation)* lateralView* pivotClause?
;
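+
+// Time travel clause, e.g. `TIMESTAMP AS OF '2019-01-29 00:37:58'` or `VERSION AS OF 'Snapshot01'`.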
+temporalClause
+ : FOR? (SYSTEM_TIME | TIMESTAMP) AS OF timestamp=valueExpression
+ | FOR? (SYSTEM_VERSION | VERSION) AS OF version=(INTEGER_VALUE | STRING)
+ ;
+
aggregation
: GROUP BY groupingExpressions+=expression (',' groupingExpressions+=expression)* (
WITH kind=ROLLUP
@@ -510,7 +515,8 @@ identifierComment
;
relationPrimary
- : tableIdentifier sample? tableAlias #tableName
+ : tableIdentifier temporalClause?
+ sample? tableAlias #tableName
| '(' queryNoWith ')' sample? tableAlias #aliasedQuery
| '(' relation ')' sample? tableAlias #aliasedRelation
| inlineTable #inlineTableDefault2
@@ -778,6 +784,7 @@ nonReserved
| DATABASE | SELECT | FROM | WHERE | HAVING | TO | TABLE | WITH | NOT
| DIRECTORY
| BOTH | LEADING | TRAILING
+ | SYSTEM_VERSION | VERSION | SYSTEM_TIME | TIMESTAMP
;
SELECT: 'SELECT';
@@ -1015,6 +1022,11 @@ ANTI: 'ANTI';
LOCAL: 'LOCAL';
INPATH: 'INPATH';
+SYSTEM_VERSION: 'SYSTEM_VERSION';
+VERSION: 'VERSION';
+SYSTEM_TIME: 'SYSTEM_TIME';
+TIMESTAMP: 'TIMESTAMP';
+
STRING
: '\'' ( ~('\''|'\\') | ('\\' .) )* '\''
| '"' ( ~('"'|'\\') | ('\\' .) )* '"'
diff --git a/hudi-spark-datasource/hudi-spark2/src/main/antlr4/org/apache/hudi/spark/sql/parser/HoodieSqlBase.g4 b/hudi-spark-datasource/hudi-spark2/src/main/antlr4/org/apache/hudi/spark/sql/parser/HoodieSqlBase.g4
index 6544d936243e9..a9fa73beef7bd 100644
--- a/hudi-spark-datasource/hudi-spark2/src/main/antlr4/org/apache/hudi/spark/sql/parser/HoodieSqlBase.g4
+++ b/hudi-spark-datasource/hudi-spark2/src/main/antlr4/org/apache/hudi/spark/sql/parser/HoodieSqlBase.g4
@@ -27,6 +27,32 @@ statement
: mergeInto #mergeIntoTable
| updateTableStmt #updateTable
| deleteTableStmt #deleteTable
+ | query #queryStatement
+ | createTableHeader ('(' colTypeList ')')? tableProvider
+ ((OPTIONS options=tablePropertyList) |
+ (PARTITIONED BY partitionColumnNames=identifierList) |
+ bucketSpec |
+ locationSpec |
+ (COMMENT comment=STRING) |
+ (TBLPROPERTIES tableProps=tablePropertyList))*
+ (AS? query)? #createTable
+ | createTableHeader ('(' columns=colTypeList ')')?
+ ((COMMENT comment=STRING) |
+ (PARTITIONED BY '(' partitionColumns=colTypeList ')') |
+ bucketSpec |
+ skewSpec |
+ rowFormat |
+ createFileFormat |
+ locationSpec |
+ (TBLPROPERTIES tableProps=tablePropertyList))*
+ (AS? query)? #createHiveTable
+ | CREATE (OR REPLACE)? (GLOBAL? TEMPORARY)?
+ VIEW (IF NOT EXISTS)? tableIdentifier
+ identifierCommentList? (COMMENT STRING)?
+ (PARTITIONED ON identifierList)?
+ (TBLPROPERTIES tablePropertyList)? AS query #createView
+ | ALTER VIEW tableIdentifier AS? query #alterViewQuery
+ | CACHE LAZY? TABLE tableIdentifier (AS? query)? #cacheTable
| .*? #passThrough
;
@@ -87,6 +113,7 @@ assignmentList
assignment
: key=qualifiedName EQ value=expression
;
+
qualifiedNameList
: qualifiedName (',' qualifiedName)*
;
diff --git a/hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TimeTravelRelation.scala b/hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TimeTravelRelation.scala
new file mode 100644
index 0000000000000..67f942881c39d
--- /dev/null
+++ b/hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TimeTravelRelation.scala
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.plans.logical
+
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
+
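+/**
+ * Unresolved logical plan node that carries the `TIMESTAMP AS OF` / `VERSION AS OF` clause of
+ * a table reference; it is later resolved into a relation by the Hudi analysis rules.
+ */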
+case class TimeTravelRelation(
+ table: LogicalPlan,
+ timestamp: Option[Expression],
+ version: Option[String]) extends Command {
+ override def children: Seq[LogicalPlan] = Seq.empty
+
+ override def output: Seq[Attribute] = Nil
+
+ override lazy val resolved: Boolean = false
+
+ def withNewChildrenInternal(newChildren: IndexedSeq[LogicalPlan]): LogicalPlan = this
+}
diff --git a/hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/spark/sql/hudi/parser/HoodieSpark2ExtendedSqlAstBuilder.scala b/hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/spark/sql/hudi/parser/HoodieSpark2ExtendedSqlAstBuilder.scala
index bbc9014fe804f..66ac10f503b4a 100644
--- a/hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/spark/sql/hudi/parser/HoodieSpark2ExtendedSqlAstBuilder.scala
+++ b/hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/spark/sql/hudi/parser/HoodieSpark2ExtendedSqlAstBuilder.scala
@@ -16,30 +16,56 @@
*/
package org.apache.spark.sql.hudi.parser
-import org.antlr.v4.runtime.tree.ParseTree
-import org.apache.hudi.spark.sql.parser.HoodieSqlBaseBaseVisitor
+import org.antlr.v4.runtime.tree.{ParseTree, RuleNode, TerminalNode}
+import org.antlr.v4.runtime.{ParserRuleContext, Token}
import org.apache.hudi.spark.sql.parser.HoodieSqlBaseParser._
+import org.apache.hudi.spark.sql.parser.{HoodieSqlBaseBaseVisitor, HoodieSqlBaseParser}
import org.apache.spark.internal.Logging
-import org.apache.spark.sql.catalyst.analysis._
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, _}
+import org.apache.spark.sql.catalyst.catalog.CatalogStorageFormat
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.{First, Last}
import org.apache.spark.sql.catalyst.parser.{ParseException, ParserInterface, ParserUtils}
+import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
+import org.apache.spark.unsafe.types.CalendarInterval
import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types._
+import org.apache.spark.util.random.RandomSampler
+import java.sql.{Date, Timestamp}
+import java.util.Locale
+import javax.xml.bind.DatatypeConverter
import scala.collection.JavaConverters._
+import scala.collection.mutable.ArrayBuffer
/**
* The AstBuilder for HoodieSqlParser to parser the AST tree to Logical Plan.
* Here we only do the parser for the extended sql syntax. e.g MergeInto. For
* other sql syntax we use the delegate sql parser which is the SparkSqlParser.
*/
-class HoodieSpark2ExtendedSqlAstBuilder(conf: SQLConf, delegate: ParserInterface) extends HoodieSqlBaseBaseVisitor[AnyRef] with Logging {
+class HoodieSpark2ExtendedSqlAstBuilder(conf: SQLConf, delegate: ParserInterface)
+ extends HoodieSqlBaseBaseVisitor[AnyRef] with Logging {
import ParserUtils._
- override def visitSingleStatement(ctx: SingleStatementContext): LogicalPlan = withOrigin(ctx) {
- ctx.statement().accept(this).asInstanceOf[LogicalPlan]
+ protected def typedVisit[T](ctx: ParseTree): T = {
+ ctx.accept(this).asInstanceOf[T]
+ }
+
+ /**
+ * Override the default behavior for all visit methods. This will only return a non-null result
+ * when the context has only one child. This is done because there is no generic method to
+ * combine the results of the context children. In all other cases null is returned.
+ */
+ override def visitChildren(node: RuleNode): AnyRef = {
+ if (node.getChildCount == 1) {
+ node.getChild(0).accept(this)
+ } else {
+ null
+ }
}
override def visitMergeIntoTable (ctx: MergeIntoTableContext): LogicalPlan = withOrigin(ctx) {
@@ -141,6 +167,13 @@ class HoodieSpark2ExtendedSqlAstBuilder(conf: SQLConf, delegate: ParserInterface
}
}
+ /**
+ * Parse a qualified name to a multipart name.
+ */
+ override def visitQualifiedName(ctx: QualifiedNameContext): Seq[String] = withOrigin(ctx) {
+ ctx.identifier.asScala.map(_.getText).toSeq
+ }
+
override def visitUpdateTable(ctx: UpdateTableContext): LogicalPlan = withOrigin(ctx) {
val updateStmt = ctx.updateTableStmt()
val table = UnresolvedRelation(visitTableIdentifier(updateStmt.tableIdentifier()))
@@ -170,19 +203,9 @@ class HoodieSpark2ExtendedSqlAstBuilder(conf: SQLConf, delegate: ParserInterface
}
/**
- * Parse the expression tree to spark sql Expression.
- * Here we use the SparkSqlParser to do the parse.
- */
- private def expression(tree: ParseTree): Expression = {
- val expressionText = treeToString(tree)
- delegate.parseExpression(expressionText)
- }
-
- // ============== The following code is fork from org.apache.spark.sql.catalyst.parser.AstBuilder
- /**
- * If aliases specified in a FROM clause, create a subquery alias ([[SubqueryAlias]]) and
- * column aliases for a [[LogicalPlan]].
- */
+ * If aliases specified in a FROM clause, create a subquery alias ([[SubqueryAlias]]) and
+ * column aliases for a [[LogicalPlan]].
+ */
protected def mayApplyAliasPlan(tableAlias: TableAliasContext, plan: LogicalPlan): LogicalPlan = {
if (tableAlias.strictIdentifier != null) {
val alias = tableAlias.strictIdentifier.getText
@@ -196,35 +219,1798 @@ class HoodieSpark2ExtendedSqlAstBuilder(conf: SQLConf, delegate: ParserInterface
plan
}
}
+
/**
- * Parse a qualified name to a multipart name.
- */
- override def visitQualifiedName(ctx: QualifiedNameContext): Seq[String] = withOrigin(ctx) {
- ctx.identifier.asScala.map(_.getText)
+ * Parse the expression tree to spark sql Expression.
+ * Here we use the SparkSqlParser to do the parse.
+ */
+ private def expression(tree: ParseTree): Expression = {
+ val expressionText = treeToString(tree)
+ delegate.parseExpression(expressionText)
}
/**
- * Create a Sequence of Strings for a parenthesis enclosed alias list.
- */
- override def visitIdentifierList(ctx: IdentifierListContext): Seq[String] = withOrigin(ctx) {
- visitIdentifierSeq(ctx.identifierSeq)
+ * Create an aliased table reference. This is typically used in FROM clauses.
+ */
+ override def visitTableName(ctx: TableNameContext): LogicalPlan = withOrigin(ctx) {
+ val tableId = visitTableIdentifier(ctx.tableIdentifier())
+ val relation = UnresolvedRelation(tableId)
+ val table = mayApplyAliasPlan(
+ ctx.tableAlias, relation.optionalMap(ctx.temporalClause)(withTimeTravel))
+ table.optionalMap(ctx.sample)(withSample)
}
- /**
- * Create a Sequence of Strings for an identifier list.
- */
- override def visitIdentifierSeq(ctx: IdentifierSeqContext): Seq[String] = withOrigin(ctx) {
- ctx.identifier.asScala.map(_.getText)
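+ /**
+ * Wrap the relation in a [[TimeTravelRelation]] when a `TIMESTAMP AS OF` / `VERSION AS OF`
+ * temporal clause is present.
+ */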
+ private def withTimeTravel(
+ ctx: TemporalClauseContext, plan: LogicalPlan): LogicalPlan = withOrigin(ctx) {
+ val v = ctx.version
+ val version = if (ctx.INTEGER_VALUE != null) {
+ Some(v.getText)
+ } else {
+ Option(v).map(string)
+ }
+ val timestamp = Option(ctx.timestamp).map(expression)
+ if (timestamp.exists(_.references.nonEmpty)) {
+ throw new ParseException(
+ "timestamp expression cannot refer to any columns", ctx.timestamp)
+ }
+ if (timestamp.exists(e => SubqueryExpression.hasSubquery(e))) {
+ throw new ParseException(
+ "timestamp expression cannot contain subqueries", ctx.timestamp)
+ }
+ TimeTravelRelation(plan, timestamp, version)
+ }
+
+ // ============== The following code is forked from org.apache.spark.sql.catalyst.parser.AstBuilder
+ override def visitSingleStatement(ctx: SingleStatementContext): LogicalPlan = withOrigin(ctx) {
+ visit(ctx.statement).asInstanceOf[LogicalPlan]
+ }
+
+ override def visitSingleExpression(ctx: SingleExpressionContext): Expression = withOrigin(ctx) {
+ visitNamedExpression(ctx.namedExpression)
+ }
+
+ override def visitSingleTableIdentifier(
+ ctx: SingleTableIdentifierContext): TableIdentifier = withOrigin(ctx) {
+ visitTableIdentifier(ctx.tableIdentifier)
+ }
+
+ override def visitSingleFunctionIdentifier(
+ ctx: SingleFunctionIdentifierContext): FunctionIdentifier = withOrigin(ctx) {
+ visitFunctionIdentifier(ctx.functionIdentifier)
+ }
+
+ override def visitSingleDataType(ctx: SingleDataTypeContext): DataType = withOrigin(ctx) {
+ visitSparkDataType(ctx.dataType)
+ }
+
+ override def visitSingleTableSchema(ctx: SingleTableSchemaContext): StructType = {
+ withOrigin(ctx)(StructType(visitColTypeList(ctx.colTypeList)))
}
/* ********************************************************************************************
- * Table Identifier parsing
+ * Plan parsing
* ******************************************************************************************** */
+ protected def plan(tree: ParserRuleContext): LogicalPlan = typedVisit(tree)
+
/**
- * Create a [[TableIdentifier]] from a 'tableName' or 'databaseName'.'tableName' pattern.
- */
- override def visitTableIdentifier(ctx: TableIdentifierContext): TableIdentifier = withOrigin(ctx) {
- TableIdentifier(ctx.table.getText, Option(ctx.db).map(_.getText))
+ * Create a top-level plan with Common Table Expressions.
+ */
+ override def visitQuery(ctx: QueryContext): LogicalPlan = withOrigin(ctx) {
+ val query = plan(ctx.queryNoWith)
+
+ // Apply CTEs
+ query.optional(ctx.ctes) {
+ val ctes = ctx.ctes.namedQuery.asScala.map { nCtx =>
+ val namedQuery = visitNamedQuery(nCtx)
+ (namedQuery.alias, namedQuery)
+ }
+ // Check for duplicate names.
+ checkDuplicateKeys(ctes, ctx)
+ With(query, ctes)
+ }
+ }
+
+ /**
+ * Create a named logical plan.
+ *
+ * This is only used for Common Table Expressions.
+ */
+ override def visitNamedQuery(ctx: NamedQueryContext): SubqueryAlias = withOrigin(ctx) {
+ SubqueryAlias(ctx.name.getText, plan(ctx.query))
+ }
+
+ /**
+ * Create a logical plan which allows for multiple inserts using one 'from' statement. These
+ * queries have the following SQL form:
+ * {{{
+ * [WITH cte...]?
+ * FROM src
+ * [INSERT INTO tbl1 SELECT *]+
+ * }}}
+ * For example:
+ * {{{
+ * FROM db.tbl1 A
+ * INSERT INTO dbo.tbl1 SELECT * WHERE A.value = 10 LIMIT 5
+ * INSERT INTO dbo.tbl2 SELECT * WHERE A.value = 12
+ * }}}
+ * This (Hive) feature cannot be combined with set-operators.
+ */
+ override def visitMultiInsertQuery(ctx: MultiInsertQueryContext): LogicalPlan = withOrigin(ctx) {
+ val from = visitFromClause(ctx.fromClause)
+
+ // Build the insert clauses.
+ val inserts = ctx.multiInsertQueryBody.asScala.map {
+ body =>
+ validate(body.querySpecification.fromClause == null,
+ "Multi-Insert queries cannot have a FROM clause in their individual SELECT statements",
+ body)
+
+ withQuerySpecification(body.querySpecification, from).
+ // Add organization statements.
+ optionalMap(body.queryOrganization)(withQueryResultClauses).
+ // Add insert.
+ optionalMap(body.insertInto())(withInsertInto)
+ }
+
+ // If there are multiple INSERTS just UNION them together into one query.
+ inserts match {
+ case Seq(query) => query
+ case queries => Union(queries)
+ }
+ }
+
+ /**
+ * Create a logical plan for a regular (single-insert) query.
+ */
+ override def visitSingleInsertQuery(
+ ctx: SingleInsertQueryContext): LogicalPlan = withOrigin(ctx) {
+ plan(ctx.queryTerm).
+ // Add organization statements.
+ optionalMap(ctx.queryOrganization)(withQueryResultClauses).
+ // Add insert.
+ optionalMap(ctx.insertInto())(withInsertInto)
+ }
+
+ /**
+ * Parameters used for writing query to a table:
+ * (tableIdentifier, partitionKeys, exists).
+ */
+ type InsertTableParams = (TableIdentifier, Map[String, Option[String]], Boolean)
+
+ /**
+ * Parameters used for writing query to a directory: (isLocal, CatalogStorageFormat, provider).
+ */
+ type InsertDirParams = (Boolean, CatalogStorageFormat, Option[String])
+
+ /**
+ * Add an
+ * {{{
+ * INSERT OVERWRITE TABLE tableIdentifier [partitionSpec [IF NOT EXISTS]]?
+ * INSERT INTO [TABLE] tableIdentifier [partitionSpec]
+ * INSERT OVERWRITE [LOCAL] DIRECTORY STRING [rowFormat] [createFileFormat]
+ * INSERT OVERWRITE [LOCAL] DIRECTORY [STRING] tableProvider [OPTIONS tablePropertyList]
+ * }}}
+ * operation to logical plan
+ */
+ private def withInsertInto(
+ ctx: InsertIntoContext,
+ query: LogicalPlan): LogicalPlan = withOrigin(ctx) {
+ ctx match {
+ case table: InsertIntoTableContext =>
+ val (tableIdent, partitionKeys, exists) = visitInsertIntoTable(table)
+ InsertIntoTable(UnresolvedRelation(tableIdent), partitionKeys, query, false, exists)
+ case table: InsertOverwriteTableContext =>
+ val (tableIdent, partitionKeys, exists) = visitInsertOverwriteTable(table)
+ InsertIntoTable(UnresolvedRelation(tableIdent), partitionKeys, query, true, exists)
+ case dir: InsertOverwriteDirContext =>
+ val (isLocal, storage, provider) = visitInsertOverwriteDir(dir)
+ InsertIntoDir(isLocal, storage, provider, query, overwrite = true)
+ case hiveDir: InsertOverwriteHiveDirContext =>
+ val (isLocal, storage, provider) = visitInsertOverwriteHiveDir(hiveDir)
+ InsertIntoDir(isLocal, storage, provider, query, overwrite = true)
+ case _ =>
+ throw new ParseException("Invalid InsertIntoContext", ctx)
+ }
+ }
+
+ /**
+ * Add an INSERT INTO TABLE operation to the logical plan.
+ */
+ override def visitInsertIntoTable(
+ ctx: InsertIntoTableContext): InsertTableParams = withOrigin(ctx) {
+ val tableIdent = visitTableIdentifier(ctx.tableIdentifier)
+ val partitionKeys = Option(ctx.partitionSpec).map(visitPartitionSpec).getOrElse(Map.empty)
+
+ (tableIdent, partitionKeys, false)
+ }
+
+ /**
+ * Add an INSERT OVERWRITE TABLE operation to the logical plan.
+ */
+ override def visitInsertOverwriteTable(
+ ctx: InsertOverwriteTableContext): InsertTableParams = withOrigin(ctx) {
+ assert(ctx.OVERWRITE() != null)
+ val tableIdent = visitTableIdentifier(ctx.tableIdentifier)
+ val partitionKeys = Option(ctx.partitionSpec).map(visitPartitionSpec).getOrElse(Map.empty)
+
+ val dynamicPartitionKeys: Map[String, Option[String]] = partitionKeys.filter(_._2.isEmpty)
+ if (ctx.EXISTS != null && dynamicPartitionKeys.nonEmpty) {
+ throw new ParseException(s"Dynamic partitions do not support IF NOT EXISTS. Specified " +
+ "partitions with value: " + dynamicPartitionKeys.keys.mkString("[", ",", "]"), ctx)
+ }
+
+ (tableIdent, partitionKeys, ctx.EXISTS() != null)
+ }
+
+ /**
+ * Write to a directory, returning a [[InsertIntoDir]] logical plan.
+ */
+ override def visitInsertOverwriteDir(
+ ctx: InsertOverwriteDirContext): InsertDirParams = withOrigin(ctx) {
+ throw new ParseException("INSERT OVERWRITE DIRECTORY is not supported", ctx)
+ }
+
+ /**
+ * Write to a directory, returning a [[InsertIntoDir]] logical plan.
+ */
+ override def visitInsertOverwriteHiveDir(
+ ctx: InsertOverwriteHiveDirContext): InsertDirParams = withOrigin(ctx) {
+ throw new ParseException("INSERT OVERWRITE DIRECTORY is not supported", ctx)
+ }
+
+ /**
+ * Create a partition specification map.
+ */
+ override def visitPartitionSpec(
+ ctx: PartitionSpecContext): Map[String, Option[String]] = withOrigin(ctx) {
+ val parts = ctx.partitionVal.asScala.map { pVal =>
+ val name = pVal.identifier.getText
+ val value = Option(pVal.constant).map(visitStringConstant)
+ name -> value
+ }
+ // Before calling `toMap`, we check duplicated keys to avoid silently ignore partition values
+ // in partition spec like PARTITION(a='1', b='2', a='3'). The real semantical check for
+ // partition columns will be done in analyzer.
+ checkDuplicateKeys(parts, ctx)
+ parts.toMap
+ }
+
+ /**
+ * Create a partition specification map without optional values.
+ */
+ protected def visitNonOptionalPartitionSpec(
+ ctx: PartitionSpecContext): Map[String, String] = withOrigin(ctx) {
+ visitPartitionSpec(ctx).map {
+ case (key, None) => throw new ParseException(s"Found an empty partition key '$key'.", ctx)
+ case (key, Some(value)) => key -> value
+ }
+ }
+
+ /**
+ * Convert a constant of any type into a string. This is typically used in DDL commands, and its
+ * main purpose is to prevent slight differences due to back to back conversions i.e.:
+ * String -> Literal -> String.
+ */
+ protected def visitStringConstant(ctx: ConstantContext): String = withOrigin(ctx) {
+ ctx match {
+ case s: StringLiteralContext => createString(s)
+ case o => o.getText
+ }
+ }
+
+ /**
+ * Add ORDER BY/SORT BY/CLUSTER BY/DISTRIBUTE BY/LIMIT/WINDOWS clauses to the logical plan. These
+ * clauses determine the shape (ordering/partitioning/rows) of the query result.
+ */
+ private def withQueryResultClauses(
+ ctx: QueryOrganizationContext,
+ query: LogicalPlan): LogicalPlan = withOrigin(ctx) {
+ import ctx._
+
+ // Handle ORDER BY, SORT BY, DISTRIBUTE BY, and CLUSTER BY clause.
+ val withOrder = if (
+ !order.isEmpty && sort.isEmpty && distributeBy.isEmpty && clusterBy.isEmpty) {
+ // ORDER BY ...
+ Sort(order.asScala.map(visitSortItem), global = true, query)
+ } else if (order.isEmpty && !sort.isEmpty && distributeBy.isEmpty && clusterBy.isEmpty) {
+ // SORT BY ...
+ Sort(sort.asScala.map(visitSortItem), global = false, query)
+ } else if (order.isEmpty && sort.isEmpty && !distributeBy.isEmpty && clusterBy.isEmpty) {
+ // DISTRIBUTE BY ...
+ withRepartitionByExpression(ctx, expressionList(distributeBy), query)
+ } else if (order.isEmpty && !sort.isEmpty && !distributeBy.isEmpty && clusterBy.isEmpty) {
+ // SORT BY ... DISTRIBUTE BY ...
+ Sort(
+ sort.asScala.map(visitSortItem),
+ global = false,
+ withRepartitionByExpression(ctx, expressionList(distributeBy), query))
+ } else if (order.isEmpty && sort.isEmpty && distributeBy.isEmpty && !clusterBy.isEmpty) {
+ // CLUSTER BY ...
+ val expressions = expressionList(clusterBy)
+ Sort(
+ expressions.map(SortOrder(_, Ascending)),
+ global = false,
+ withRepartitionByExpression(ctx, expressions, query))
+ } else if (order.isEmpty && sort.isEmpty && distributeBy.isEmpty && clusterBy.isEmpty) {
+ // [EMPTY]
+ query
+ } else {
+ throw new ParseException(
+ "Combination of ORDER BY/SORT BY/DISTRIBUTE BY/CLUSTER BY is not supported", ctx)
+ }
+
+ // WINDOWS
+ val withWindow = withOrder.optionalMap(windows)(withWindows)
+
+ // LIMIT
+ // - LIMIT ALL is the same as omitting the LIMIT clause
+ withWindow.optional(limit) {
+ Limit(typedVisit(limit), withWindow)
+ }
+ }
+
+ /**
+ * Create a clause for DISTRIBUTE BY.
+ */
+ protected def withRepartitionByExpression(
+ ctx: QueryOrganizationContext,
+ expressions: Seq[Expression],
+ query: LogicalPlan): LogicalPlan = {
+ throw new ParseException("DISTRIBUTE BY is not supported", ctx)
+ }
+
+ /**
+ * Create a logical plan using a query specification.
+ */
+ override def visitQuerySpecification(
+ ctx: QuerySpecificationContext): LogicalPlan = withOrigin(ctx) {
+ val from = OneRowRelation().optional(ctx.fromClause) {
+ visitFromClause(ctx.fromClause)
+ }
+ withQuerySpecification(ctx, from)
+ }
+
+ /**
+ * Add a query specification to a logical plan. The query specification is the core of the logical
+ * plan, this is where sourcing (FROM clause), transforming (SELECT TRANSFORM/MAP/REDUCE),
+ * projection (SELECT), aggregation (GROUP BY ... HAVING ...) and filtering (WHERE) takes place.
+ *
+ * Note that query hints are ignored (both by the parser and the builder).
+ */
+ private def withQuerySpecification(
+ ctx: QuerySpecificationContext,
+ relation: LogicalPlan): LogicalPlan = withOrigin(ctx) {
+ import ctx._
+
+ // WHERE
+ def filter(ctx: BooleanExpressionContext, plan: LogicalPlan): LogicalPlan = {
+ Filter(expression(ctx), plan)
+ }
+
+ def withHaving(ctx: BooleanExpressionContext, plan: LogicalPlan): LogicalPlan = {
+ // Note that we add a cast to non-predicate expressions. If the expression itself is
+ // already boolean, the optimizer will get rid of the unnecessary cast.
+ val predicate = expression(ctx) match {
+ case p: Predicate => p
+ case e => Cast(e, BooleanType)
+ }
+ Filter(predicate, plan)
+ }
+
+
+ // Expressions.
+ val expressions = Option(namedExpressionSeq).toSeq
+ .flatMap(_.namedExpression.asScala)
+ .map(typedVisit[Expression])
+
+ // Create either a transform or a regular query.
+ val specType = Option(kind).map(_.getType).getOrElse(HoodieSqlBaseParser.SELECT)
+ specType match {
+ case HoodieSqlBaseParser.MAP | HoodieSqlBaseParser.REDUCE | HoodieSqlBaseParser.TRANSFORM =>
+ // Transform
+
+ // Add where.
+ val withFilter = relation.optionalMap(where)(filter)
+
+ // Create the attributes.
+ val (attributes, schemaLess) = if (colTypeList != null) {
+ // Typed return columns.
+ (createSchema(colTypeList).toAttributes, false)
+ } else if (identifierSeq != null) {
+ // Untyped return columns.
+ val attrs = visitIdentifierSeq(identifierSeq).map { name =>
+ AttributeReference(name, StringType, nullable = true)()
+ }
+ (attrs, false)
+ } else {
+ (Seq(AttributeReference("key", StringType)(),
+ AttributeReference("value", StringType)()), true)
+ }
+
+ // Create the transform.
+ ScriptTransformation(
+ expressions,
+ string(script),
+ attributes,
+ withFilter,
+ withScriptIOSchema(
+ ctx, inRowFormat, recordWriter, outRowFormat, recordReader, schemaLess))
+
+ case HoodieSqlBaseParser.SELECT =>
+ // Regular select
+
+ // Add lateral views.
+ val withLateralView = ctx.lateralView.asScala.foldLeft(relation)(withGenerate)
+
+ // Add where.
+ val withFilter = withLateralView.optionalMap(where)(filter)
+
+ // Add aggregation or a project.
+ val namedExpressions = expressions.map {
+ case e: NamedExpression => e
+ case e: Expression => UnresolvedAlias(e)
+ }
+
+ def createProject() = if (namedExpressions.nonEmpty) {
+ Project(namedExpressions, withFilter)
+ } else {
+ withFilter
+ }
+
+ val withProject = if (aggregation == null && having != null) {
+ if (conf.getConf(SQLConf.LEGACY_HAVING_WITHOUT_GROUP_BY_AS_WHERE)) {
+ // If the legacy conf is set, treat HAVING without GROUP BY as WHERE.
+ withHaving(having, createProject())
+ } else {
+ // According to SQL standard, HAVING without GROUP BY means global aggregate.
+ withHaving(having, Aggregate(Nil, namedExpressions, withFilter))
+ }
+ } else if (aggregation != null) {
+ val aggregate = withAggregation(aggregation, namedExpressions, withFilter)
+ aggregate.optionalMap(having)(withHaving)
+ } else {
+ // When hitting this branch, `having` must be null.
+ createProject()
+ }
+
+ // Distinct
+ val withDistinct = if (setQuantifier() != null && setQuantifier().DISTINCT() != null) {
+ Distinct(withProject)
+ } else {
+ withProject
+ }
+
+ // Window
+ val withWindow = withDistinct.optionalMap(windows)(withWindows)
+
+ // Hint
+ hints.asScala.foldRight(withWindow)(withHints)
+ }
+ }
+
+ /**
+ * Create a (Hive based) [[ScriptInputOutputSchema]].
+ */
+ protected def withScriptIOSchema(
+ ctx: QuerySpecificationContext,
+ inRowFormat: RowFormatContext,
+ recordWriter: Token,
+ outRowFormat: RowFormatContext,
+ recordReader: Token,
+ schemaLess: Boolean): ScriptInputOutputSchema = {
+ throw new ParseException("Script Transform is not supported", ctx)
+ }
+
+ /**
+ * Create a logical plan for a given 'FROM' clause. Note that we support multiple (comma
+ * separated) relations here, these get converted into a single plan by condition-less inner join.
+ */
+ override def visitFromClause(ctx: FromClauseContext): LogicalPlan = withOrigin(ctx) {
+ val from = ctx.relation.asScala.foldLeft(null: LogicalPlan) { (left, relation) =>
+ val right = plan(relation.relationPrimary)
+ val join = right.optionalMap(left)(Join(_, _, Inner, None))
+ withJoinRelations(join, relation)
+ }
+ if (ctx.pivotClause() != null) {
+ if (!ctx.lateralView.isEmpty) {
+ throw new ParseException("LATERAL cannot be used together with PIVOT in FROM clause", ctx)
+ }
+ withPivot(ctx.pivotClause, from)
+ } else {
+ ctx.lateralView.asScala.foldLeft(from)(withGenerate)
+ }
+ }
+
+ /**
+ * Connect two queries by a Set operator.
+ *
+ * Supported Set operators are:
+ * - UNION [ DISTINCT | ALL ]
+ * - EXCEPT [ DISTINCT | ALL ]
+ * - MINUS [ DISTINCT | ALL ]
+ * - INTERSECT [DISTINCT | ALL]
+ */
+ override def visitSetOperation(ctx: SetOperationContext): LogicalPlan = withOrigin(ctx) {
+ val left = plan(ctx.left)
+ val right = plan(ctx.right)
+ val all = Option(ctx.setQuantifier()).exists(_.ALL != null)
+ ctx.operator.getType match {
+ case HoodieSqlBaseParser.UNION if all =>
+ Union(left, right)
+ case HoodieSqlBaseParser.UNION =>
+ Distinct(Union(left, right))
+ case HoodieSqlBaseParser.INTERSECT if all =>
+ Intersect(left, right, isAll = true)
+ case HoodieSqlBaseParser.INTERSECT =>
+ Intersect(left, right, isAll = false)
+ case HoodieSqlBaseParser.EXCEPT if all =>
+ Except(left, right, isAll = true)
+ case HoodieSqlBaseParser.EXCEPT =>
+ Except(left, right, isAll = false)
+ case HoodieSqlBaseParser.SETMINUS if all =>
+ Except(left, right, isAll = true)
+ case HoodieSqlBaseParser.SETMINUS =>
+ Except(left, right, isAll = false)
+ }
+ }
+
+ /**
+ * Add a [[WithWindowDefinition]] operator to a logical plan.
+ */
+ private def withWindows(
+ ctx: WindowsContext,
+ query: LogicalPlan): LogicalPlan = withOrigin(ctx) {
+ // Collect all window specifications defined in the WINDOW clause.
+ val baseWindowMap = ctx.namedWindow.asScala.map {
+ wCtx =>
+ (wCtx.identifier.getText, typedVisit[WindowSpec](wCtx.windowSpec))
+ }.toMap
+
+ // Handle cases like
+ // window w1 as (partition by p_mfgr order by p_name
+ // range between 2 preceding and 2 following),
+ // w2 as w1
+ val windowMapView = baseWindowMap.mapValues {
+ case WindowSpecReference(name) =>
+ baseWindowMap.get(name) match {
+ case Some(spec: WindowSpecDefinition) =>
+ spec
+ case Some(ref) =>
+ throw new ParseException(s"Window reference '$name' is not a window specification", ctx)
+ case None =>
+ throw new ParseException(s"Cannot resolve window reference '$name'", ctx)
+ }
+ case spec: WindowSpecDefinition => spec
+ }
+
+ // Note that mapValues creates a view instead of materialized map. We force materialization by
+ // mapping over identity.
+ WithWindowDefinition(windowMapView.map(identity), query)
+ }
+
+ /**
+ * Add an [[Aggregate]] or [[GroupingSets]] to a logical plan.
+ */
+ private def withAggregation(
+ ctx: AggregationContext,
+ selectExpressions: Seq[NamedExpression],
+ query: LogicalPlan): LogicalPlan = withOrigin(ctx) {
+ val groupByExpressions = expressionList(ctx.groupingExpressions)
+
+ if (ctx.GROUPING != null) {
+ // GROUP BY .... GROUPING SETS (...)
+ val selectedGroupByExprs =
+ ctx.groupingSet.asScala.map(_.expression.asScala.map(e => expression(e)))
+ GroupingSets(selectedGroupByExprs, groupByExpressions, query, selectExpressions)
+ } else {
+ // GROUP BY .... (WITH CUBE | WITH ROLLUP)?
+ val mappedGroupByExpressions = if (ctx.CUBE != null) {
+ Seq(Cube(groupByExpressions))
+ } else if (ctx.ROLLUP != null) {
+ Seq(Rollup(groupByExpressions))
+ } else {
+ groupByExpressions
+ }
+ Aggregate(mappedGroupByExpressions, selectExpressions, query)
+ }
+ }
+
+ /**
+ * Add [[UnresolvedHint]]s to a logical plan.
+ */
+ private def withHints(
+ ctx: HintContext,
+ query: LogicalPlan): LogicalPlan = withOrigin(ctx) {
+ var plan = query
+ ctx.hintStatements.asScala.reverse.foreach { case stmt =>
+ plan = UnresolvedHint(stmt.hintName.getText, stmt.parameters.asScala.map(expression), plan)
+ }
+ plan
+ }
+
+ /**
+ * Add a [[Pivot]] to a logical plan.
+ */
+ private def withPivot(
+ ctx: PivotClauseContext,
+ query: LogicalPlan): LogicalPlan = withOrigin(ctx) {
+ val aggregates = Option(ctx.aggregates).toSeq
+ .flatMap(_.namedExpression.asScala)
+ .map(typedVisit[Expression])
+ val pivotColumn = if (ctx.pivotColumn.identifiers.size == 1) {
+ UnresolvedAttribute.quoted(ctx.pivotColumn.identifier.getText)
+ } else {
+ CreateStruct(
+ ctx.pivotColumn.identifiers.asScala.map(
+ identifier => UnresolvedAttribute.quoted(identifier.getText)))
+ }
+ val pivotValues = ctx.pivotValues.asScala.map(visitPivotValue)
+ Pivot(None, pivotColumn, pivotValues, aggregates, query)
+ }
+
+ /**
+ * Create a Pivot column value with or without an alias.
+ */
+ override def visitPivotValue(ctx: PivotValueContext): Expression = withOrigin(ctx) {
+ val e = expression(ctx.expression)
+ if (ctx.identifier != null) {
+ Alias(e, ctx.identifier.getText)()
+ } else {
+ e
+ }
+ }
+
+ /**
+ * Add a [[Generate]] (Lateral View) to a logical plan.
+ */
+ private def withGenerate(
+ query: LogicalPlan,
+ ctx: LateralViewContext): LogicalPlan = withOrigin(ctx) {
+ val expressions = expressionList(ctx.expression)
+ Generate(
+ UnresolvedGenerator(visitFunctionName(ctx.qualifiedName), expressions),
+ unrequiredChildIndex = Nil,
+ outer = ctx.OUTER != null,
+ Some(ctx.tblName.getText.toLowerCase),
+ ctx.colName.asScala.map(_.getText).map(UnresolvedAttribute.apply),
+ query)
+ }
+
+ /**
+ * Create a single relation referenced in a FROM clause. This method is used when a part of the
+ * join condition is nested, for example:
+ * {{{
+ * select * from t1 join (t2 cross join t3) on col1 = col2
+ * }}}
+ */
+ override def visitRelation(ctx: RelationContext): LogicalPlan = withOrigin(ctx) {
+ withJoinRelations(plan(ctx.relationPrimary), ctx)
+ }
+
+ /**
+ * Join one more [[LogicalPlan]]s to the current logical plan.
+ */
+ private def withJoinRelations(base: LogicalPlan, ctx: RelationContext): LogicalPlan = {
+ ctx.joinRelation.asScala.foldLeft(base) { (left, join) =>
+ withOrigin(join) {
+ val baseJoinType = join.joinType match {
+ case null => Inner
+ case jt if jt.CROSS != null => Cross
+ case jt if jt.FULL != null => FullOuter
+ case jt if jt.SEMI != null => LeftSemi
+ case jt if jt.ANTI != null => LeftAnti
+ case jt if jt.LEFT != null => LeftOuter
+ case jt if jt.RIGHT != null => RightOuter
+ case _ => Inner
+ }
+
+ // Resolve the join type and join condition
+ val (joinType, condition) = Option(join.joinCriteria) match {
+ case Some(c) if c.USING != null =>
+ (UsingJoin(baseJoinType, c.identifier.asScala.map(_.getText)), None)
+ case Some(c) if c.booleanExpression != null =>
+ (baseJoinType, Option(expression(c.booleanExpression)))
+ case None if join.NATURAL != null =>
+ if (baseJoinType == Cross) {
+ throw new ParseException("NATURAL CROSS JOIN is not supported", ctx)
+ }
+ (NaturalJoin(baseJoinType), None)
+ case None =>
+ (baseJoinType, None)
+ }
+ Join(left, plan(join.right), joinType, condition)
+ }
+ }
+ }
+
+ /**
+ * Add a [[Sample]] to a logical plan.
+ *
+ * This currently supports the following sampling methods:
+ * - TABLESAMPLE(x ROWS): Sample the table down to the given number of rows.
+ * - TABLESAMPLE(x PERCENT): Sample the table down to the given percentage. Note that percentages
+ * are defined as a number between 0 and 100.
+ * - TABLESAMPLE(BUCKET x OUT OF y): Sample the table down to a 'x' divided by 'y' fraction.
+ */
+ private def withSample(ctx: SampleContext, query: LogicalPlan): LogicalPlan = withOrigin(ctx) {
+ // Create a sampled plan if we need one.
+ def sample(fraction: Double): Sample = {
+ // The range of fraction accepted by Sample is [0, 1]. Because Hive's block sampling
+ // function takes X PERCENT as the input and the range of X is [0, 100], we need to
+ // adjust the fraction.
+ val eps = RandomSampler.roundingEpsilon
+ validate(fraction >= 0.0 - eps && fraction <= 1.0 + eps,
+ s"Sampling fraction ($fraction) must be on interval [0, 1]",
+ ctx)
+ Sample(0.0, fraction, withReplacement = false, (math.random * 1000).toInt, query)
+ }
+
+ if (ctx.sampleMethod() == null) {
+ throw new ParseException("TABLESAMPLE does not accept empty inputs.", ctx)
+ }
+
+ ctx.sampleMethod() match {
+ case ctx: SampleByRowsContext =>
+ Limit(expression(ctx.expression), query)
+
+ case ctx: SampleByPercentileContext =>
+ val fraction = ctx.percentage.getText.toDouble
+ val sign = if (ctx.negativeSign == null) 1 else -1
+ sample(sign * fraction / 100.0d)
+
+ case ctx: SampleByBytesContext =>
+ val bytesStr = ctx.bytes.getText
+ if (bytesStr.matches("[0-9]+[bBkKmMgG]")) {
+ throw new ParseException("TABLESAMPLE(byteLengthLiteral) is not supported", ctx)
+ } else {
+ throw new ParseException(
+ bytesStr + " is not a valid byte length literal, " +
+ "expected syntax: DIGIT+ ('B' | 'K' | 'M' | 'G')", ctx)
+ }
+
+ case ctx: SampleByBucketContext if ctx.ON() != null =>
+ if (ctx.identifier != null) {
+ throw new ParseException(
+ "TABLESAMPLE(BUCKET x OUT OF y ON colname) is not supported", ctx)
+ } else {
+ throw new ParseException(
+ "TABLESAMPLE(BUCKET x OUT OF y ON function) is not supported", ctx)
+ }
+
+ case ctx: SampleByBucketContext =>
+ sample(ctx.numerator.getText.toDouble / ctx.denominator.getText.toDouble)
+ }
+ }
+
+ /**
+ * Create a logical plan for a sub-query.
+ */
+ override def visitSubquery(ctx: SubqueryContext): LogicalPlan = withOrigin(ctx) {
+ plan(ctx.queryNoWith)
+ }
+
+ /**
+ * Create an un-aliased table reference. This is typically used for top-level table references,
+ * for example:
+ * {{{
+ * INSERT INTO db.tbl2
+ * TABLE db.tbl1
+ * }}}
+ */
+ override def visitTable(ctx: TableContext): LogicalPlan = withOrigin(ctx) {
+ UnresolvedRelation(visitTableIdentifier(ctx.tableIdentifier))
+ }
+
+ /**
+ * Create a table-valued function call with arguments, e.g. range(1000)
+ */
+ override def visitTableValuedFunction(ctx: TableValuedFunctionContext)
+ : LogicalPlan = withOrigin(ctx) {
+ val func = ctx.functionTable
+ val aliases = if (func.tableAlias.identifierList != null) {
+ visitIdentifierList(func.tableAlias.identifierList)
+ } else {
+ Seq.empty
+ }
+
+ val tvf = UnresolvedTableValuedFunction(
+ func.identifier.getText, func.expression.asScala.map(expression), aliases)
+ tvf.optionalMap(func.tableAlias.strictIdentifier)(aliasPlan)
+ }
+
+ /**
+ * Create an inline table (a virtual table in Hive parlance).
+ */
+ override def visitInlineTable(ctx: InlineTableContext): LogicalPlan = withOrigin(ctx) {
+ // Get the backing expressions.
+ val rows = ctx.expression.asScala.map { e =>
+ expression(e) match {
+ // inline table comes in two styles:
+ // style 1: values (1), (2), (3) -- multiple columns are supported
+ // style 2: values 1, 2, 3 -- only a single column is supported here
+ case struct: CreateNamedStruct => struct.valExprs // style 1
+ case child => Seq(child) // style 2
+ }
+ }
+
+ val aliases = if (ctx.tableAlias.identifierList != null) {
+ visitIdentifierList(ctx.tableAlias.identifierList)
+ } else {
+ Seq.tabulate(rows.head.size)(i => s"col${i + 1}")
+ }
+
+ val table = UnresolvedInlineTable(aliases, rows)
+ table.optionalMap(ctx.tableAlias.strictIdentifier)(aliasPlan)
+ }
+
+ /**
+ * Create an alias (SubqueryAlias) for a join relation. This is practically the same as
+ * visitAliasedQuery and visitNamedExpression, ANTLR4 however requires us to use 3 different
+ * hooks. We could add alias names for output columns, for example:
+ * {{{
+ * SELECT a, b, c, d FROM (src1 s1 INNER JOIN src2 s2 ON s1.id = s2.id) dst(a, b, c, d)
+ * }}}
+ */
+ override def visitAliasedRelation(ctx: AliasedRelationContext): LogicalPlan = withOrigin(ctx) {
+ val relation = plan(ctx.relation).optionalMap(ctx.sample)(withSample)
+ mayApplyAliasPlan(ctx.tableAlias, relation)
+ }
+
+ /**
+ * Create an alias (SubqueryAlias) for a sub-query. This is practically the same as
+ * visitAliasedRelation and visitNamedExpression, ANTLR4 however requires us to use 3 different
+ * hooks. We could add alias names for output columns, for example:
+ * {{{
+ * SELECT col1, col2 FROM testData AS t(col1, col2)
+ * }}}
+ */
+ override def visitAliasedQuery(ctx: AliasedQueryContext): LogicalPlan = withOrigin(ctx) {
+ val relation = plan(ctx.queryNoWith).optionalMap(ctx.sample)(withSample)
+ if (ctx.tableAlias.strictIdentifier == null) {
+ // For un-aliased subqueries, use a default alias name that is unlikely to conflict with
+ // normal subquery names, so that parent operators can only access the columns in the
+ // subquery by unqualified names. Users can still use this special qualifier to access
+ // columns if they know it, but that's not recommended.
+ SubqueryAlias("__auto_generated_subquery_name", relation)
+ } else {
+ mayApplyAliasPlan(ctx.tableAlias, relation)
+ }
+ }
+
+ /**
+ * Create an alias ([[SubqueryAlias]]) for a [[LogicalPlan]].
+ */
+ private def aliasPlan(alias: ParserRuleContext, plan: LogicalPlan): LogicalPlan = {
+ SubqueryAlias(alias.getText, plan)
+ }
+
+ /**
+ * Create a Sequence of Strings for a parenthesis enclosed alias list.
+ */
+ override def visitIdentifierList(ctx: IdentifierListContext): Seq[String] = withOrigin(ctx) {
+ visitIdentifierSeq(ctx.identifierSeq)
+ }
+
+ /**
+ * Create a Sequence of Strings for an identifier list.
+ */
+ override def visitIdentifierSeq(ctx: IdentifierSeqContext): Seq[String] = withOrigin(ctx) {
+ ctx.identifier.asScala.map(_.getText)
+ }
+
+ /* ********************************************************************************************
+ * Table Identifier parsing
+ * ******************************************************************************************** */
+ /**
+ * Create a [[TableIdentifier]] from a 'tableName' or 'databaseName'.'tableName' pattern.
+ */
+ override def visitTableIdentifier(
+ ctx: TableIdentifierContext): TableIdentifier = withOrigin(ctx) {
+ TableIdentifier(ctx.table.getText, Option(ctx.db).map(_.getText))
+ }
+
+ /**
+ * Create a [[FunctionIdentifier]] from a 'functionName' or 'databaseName'.'functionName' pattern.
+ */
+ override def visitFunctionIdentifier(
+ ctx: FunctionIdentifierContext): FunctionIdentifier = withOrigin(ctx) {
+ FunctionIdentifier(ctx.function.getText, Option(ctx.db).map(_.getText))
+ }
+
+ /* ********************************************************************************************
+ * Expression parsing
+ * ******************************************************************************************** */
+ /**
+ * Create sequence of expressions from the given sequence of contexts.
+ */
+ private def expressionList(trees: java.util.List[ExpressionContext]): Seq[Expression] = {
+ trees.asScala.map(expression)
+ }
+
+ /**
+ * Create a star (i.e. all) expression; this selects all elements (in the specified object).
+ * Both un-targeted (global) and targeted aliases are supported.
+ */
+ override def visitStar(ctx: StarContext): Expression = withOrigin(ctx) {
+ UnresolvedStar(Option(ctx.qualifiedName()).map(_.identifier.asScala.map(_.getText)))
+ }
+
+ /**
+ * Create an aliased expression if an alias is specified. Both single and multi-aliases are
+ * supported.
+ */
+ override def visitNamedExpression(ctx: NamedExpressionContext): Expression = withOrigin(ctx) {
+ val e = expression(ctx.expression)
+ if (ctx.identifier != null) {
+ Alias(e, ctx.identifier.getText)()
+ } else if (ctx.identifierList != null) {
+ MultiAlias(e, visitIdentifierList(ctx.identifierList))
+ } else {
+ e
+ }
+ }
+
+ /**
+ * Combine a number of boolean expressions into a balanced expression tree. These expressions are
+ * either combined by a logical [[And]] or a logical [[Or]].
+ *
+ * A balanced binary tree is created because regular left recursive trees cause considerable
+ * performance degradations and can cause stack overflows.
+ */
+ override def visitLogicalBinary(ctx: LogicalBinaryContext): Expression = withOrigin(ctx) {
+ val expressionType = ctx.operator.getType
+ val expressionCombiner = expressionType match {
+ case HoodieSqlBaseParser.AND => And.apply _
+ case HoodieSqlBaseParser.OR => Or.apply _
+ }
+
+ // Collect all similar left hand contexts.
+ val contexts = ArrayBuffer(ctx.right)
+ var current = ctx.left
+ def collectContexts: Boolean = current match {
+ case lbc: LogicalBinaryContext if lbc.operator.getType == expressionType =>
+ contexts += lbc.right
+ current = lbc.left
+ true
+ case _ =>
+ contexts += current
+ false
+ }
+ while (collectContexts) {
+ // No body - all updates take place in the collectContexts.
+ }
+
+ // Reverse the contexts to have them in the same sequence as in the SQL statement & turn them
+ // into expressions.
+ val expressions = contexts.reverseMap(expression)
+
+ // Create a balanced tree.
+ def reduceToExpressionTree(low: Int, high: Int): Expression = high - low match {
+ case 0 =>
+ expressions(low)
+ case 1 =>
+ expressionCombiner(expressions(low), expressions(high))
+ case x =>
+ val mid = low + x / 2
+ expressionCombiner(
+ reduceToExpressionTree(low, mid),
+ reduceToExpressionTree(mid + 1, high))
+ }
+ reduceToExpressionTree(0, expressions.size - 1)
+ }
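To make the balancing concrete, here is a minimal standalone Scala sketch (illustrative only, not Hudi/Spark code) of the same reduction applied to strings:

```scala
object BalancedTreeSketch {
  // Mirrors reduceToExpressionTree above: split the operand range in half and recurse,
  // so 'a AND b AND c AND d' becomes ((a AND b) AND (c AND d)) instead of a deep left spine.
  def reduce(parts: IndexedSeq[String])(combine: (String, String) => String): String = {
    def go(low: Int, high: Int): String = high - low match {
      case 0 => parts(low)
      case 1 => combine(parts(low), parts(high))
      case x =>
        val mid = low + x / 2
        combine(go(low, mid), go(mid + 1, high))
    }
    go(0, parts.size - 1)
  }

  def main(args: Array[String]): Unit = {
    println(reduce(Vector("a", "b", "c", "d"))((l, r) => s"($l AND $r)"))
    // prints: ((a AND b) AND (c AND d))
  }
}
```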
+
+ /**
+ * Invert a boolean expression.
+ */
+ override def visitLogicalNot(ctx: LogicalNotContext): Expression = withOrigin(ctx) {
+ Not(expression(ctx.booleanExpression()))
+ }
+
+ /**
+ * Create a filtering correlated sub-query (EXISTS).
+ */
+ override def visitExists(ctx: ExistsContext): Expression = {
+ Exists(plan(ctx.query))
+ }
+
+ /**
+ * Create a comparison expression. This compares two expressions. The following comparison
+ * operators are supported:
+ * - Equal: '=' or '=='
+ * - Null-safe Equal: '<=>'
+ * - Not Equal: '<>' or '!='
+ * - Less than: '<'
+ * - Less than or Equal: '<='
+ * - Greater than: '>'
+ * - Greater than or Equal: '>='
+ */
+ override def visitComparison(ctx: ComparisonContext): Expression = withOrigin(ctx) {
+ val left = expression(ctx.left)
+ val right = expression(ctx.right)
+ val operator = ctx.comparisonOperator().getChild(0).asInstanceOf[TerminalNode]
+ operator.getSymbol.getType match {
+ case HoodieSqlBaseParser.EQ =>
+ EqualTo(left, right)
+ case HoodieSqlBaseParser.NSEQ =>
+ EqualNullSafe(left, right)
+ case HoodieSqlBaseParser.NEQ | HoodieSqlBaseParser.NEQJ =>
+ Not(EqualTo(left, right))
+ case HoodieSqlBaseParser.LT =>
+ LessThan(left, right)
+ case HoodieSqlBaseParser.LTE =>
+ LessThanOrEqual(left, right)
+ case HoodieSqlBaseParser.GT =>
+ GreaterThan(left, right)
+ case HoodieSqlBaseParser.GTE =>
+ GreaterThanOrEqual(left, right)
+ }
+ }
+
+ /**
+ * Create a predicated expression. A predicated expression is a normal expression with a
+ * predicate attached to it, for example:
+ * {{{
+ * a + 1 IS NULL
+ * }}}
+ */
+ override def visitPredicated(ctx: PredicatedContext): Expression = withOrigin(ctx) {
+ val e = expression(ctx.valueExpression)
+ if (ctx.predicate != null) {
+ withPredicate(e, ctx.predicate)
+ } else {
+ e
+ }
+ }
+
+ /**
+ * Add a predicate to the given expression. Supported expressions are:
+ * - (NOT) BETWEEN
+ * - (NOT) IN
+ * - (NOT) LIKE
+ * - (NOT) RLIKE
+ * - IS (NOT) NULL.
+ * - IS (NOT) DISTINCT FROM
+ */
+ private def withPredicate(e: Expression, ctx: PredicateContext): Expression = withOrigin(ctx) {
+ // Invert a predicate if it has a valid NOT clause.
+ def invertIfNotDefined(e: Expression): Expression = ctx.NOT match {
+ case null => e
+ case not => Not(e)
+ }
+
+ def getValueExpressions(e: Expression): Seq[Expression] = e match {
+ case c: CreateNamedStruct => c.valExprs
+ case other => Seq(other)
+ }
+
+ // Create the predicate.
+ ctx.kind.getType match {
+ case HoodieSqlBaseParser.BETWEEN =>
+ // BETWEEN is translated to lower <= e && e <= upper
+ invertIfNotDefined(And(
+ GreaterThanOrEqual(e, expression(ctx.lower)),
+ LessThanOrEqual(e, expression(ctx.upper))))
+ case HoodieSqlBaseParser.IN if ctx.query != null =>
+ invertIfNotDefined(InSubquery(getValueExpressions(e), ListQuery(plan(ctx.query))))
+ case HoodieSqlBaseParser.IN =>
+ invertIfNotDefined(In(e, ctx.expression.asScala.map(expression)))
+ case HoodieSqlBaseParser.LIKE =>
+ invertIfNotDefined(Like(e, expression(ctx.pattern)))
+ case HoodieSqlBaseParser.RLIKE =>
+ invertIfNotDefined(RLike(e, expression(ctx.pattern)))
+ case HoodieSqlBaseParser.NULL if ctx.NOT != null =>
+ IsNotNull(e)
+ case HoodieSqlBaseParser.NULL =>
+ IsNull(e)
+ case HoodieSqlBaseParser.DISTINCT if ctx.NOT != null =>
+ EqualNullSafe(e, expression(ctx.right))
+ case HoodieSqlBaseParser.DISTINCT =>
+ Not(EqualNullSafe(e, expression(ctx.right)))
+ }
+ }
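As a worked example of the BETWEEN branch, a small sketch (assuming Spark's catalyst expression classes on the classpath) of the tree that `a BETWEEN 1 AND 10` turns into:

```scala
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types.IntegerType

val a = AttributeReference("a", IntegerType)()
// BETWEEN is rewritten to a conjunction of two comparisons: 1 <= a AND a <= 10.
val between = And(GreaterThanOrEqual(a, Literal(1)), LessThanOrEqual(a, Literal(10)))
// NOT BETWEEN wraps the same conjunction: invertIfNotDefined adds the Not.
val notBetween = Not(between)
```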
+
+ /**
+ * Create a binary arithmetic expression. The following arithmetic operators are supported:
+ * - Multiplication: '*'
+ * - Division: '/'
+ * - Hive Long Division: 'DIV'
+ * - Modulo: '%'
+ * - Addition: '+'
+ * - Subtraction: '-'
+ * - Binary AND: '&'
+ * - Binary XOR
+ * - Binary OR: '|'
+ */
+ override def visitArithmeticBinary(ctx: ArithmeticBinaryContext): Expression = withOrigin(ctx) {
+ val left = expression(ctx.left)
+ val right = expression(ctx.right)
+ ctx.operator.getType match {
+ case HoodieSqlBaseParser.ASTERISK =>
+ Multiply(left, right)
+ case HoodieSqlBaseParser.SLASH =>
+ Divide(left, right)
+ case HoodieSqlBaseParser.PERCENT =>
+ Remainder(left, right)
+ case HoodieSqlBaseParser.DIV =>
+ Cast(Divide(left, right), LongType)
+ case HoodieSqlBaseParser.PLUS =>
+ Add(left, right)
+ case HoodieSqlBaseParser.MINUS =>
+ Subtract(left, right)
+ case HoodieSqlBaseParser.CONCAT_PIPE =>
+ Concat(left :: right :: Nil)
+ case HoodieSqlBaseParser.AMPERSAND =>
+ BitwiseAnd(left, right)
+ case HoodieSqlBaseParser.HAT =>
+ BitwiseXor(left, right)
+ case HoodieSqlBaseParser.PIPE =>
+ BitwiseOr(left, right)
+ }
+ }
+
+ /**
+ * Create a unary arithmetic expression. The following arithmetic operators are supported:
+ * - Plus: '+'
+ * - Minus: '-'
+ * - Bitwise Not: '~'
+ */
+ override def visitArithmeticUnary(ctx: ArithmeticUnaryContext): Expression = withOrigin(ctx) {
+ val value = expression(ctx.valueExpression)
+ ctx.operator.getType match {
+ case HoodieSqlBaseParser.PLUS =>
+ value
+ case HoodieSqlBaseParser.MINUS =>
+ UnaryMinus(value)
+ case HoodieSqlBaseParser.TILDE =>
+ BitwiseNot(value)
+ }
+ }
+
+ /**
+ * Create a [[Cast]] expression.
+ */
+ override def visitCast(ctx: CastContext): Expression = withOrigin(ctx) {
+ Cast(expression(ctx.expression), visitSparkDataType(ctx.dataType))
+ }
+
+ /**
+ * Create a [[CreateStruct]] expression.
+ */
+ override def visitStruct(ctx: StructContext): Expression = withOrigin(ctx) {
+ CreateStruct(ctx.argument.asScala.map(expression))
+ }
+
+ /**
+ * Create a [[First]] expression.
+ */
+ override def visitFirst(ctx: FirstContext): Expression = withOrigin(ctx) {
+ val ignoreNullsExpr = ctx.IGNORE != null
+ First(expression(ctx.expression), Literal(ignoreNullsExpr)).toAggregateExpression()
+ }
+
+ /**
+ * Create a [[Last]] expression.
+ */
+ override def visitLast(ctx: LastContext): Expression = withOrigin(ctx) {
+ val ignoreNullsExpr = ctx.IGNORE != null
+ Last(expression(ctx.expression), Literal(ignoreNullsExpr)).toAggregateExpression()
+ }
+
+ /**
+ * Create a Position expression.
+ */
+ override def visitPosition(ctx: PositionContext): Expression = withOrigin(ctx) {
+ new StringLocate(expression(ctx.substr), expression(ctx.str))
+ }
+
+ /**
+ * Create an Extract expression.
+ */
+ override def visitExtract(ctx: ExtractContext): Expression = withOrigin(ctx) {
+ ctx.field.getText.toUpperCase(Locale.ROOT) match {
+ case "YEAR" =>
+ Year(expression(ctx.source))
+ case "QUARTER" =>
+ Quarter(expression(ctx.source))
+ case "MONTH" =>
+ Month(expression(ctx.source))
+ case "WEEK" =>
+ WeekOfYear(expression(ctx.source))
+ case "DAY" =>
+ DayOfMonth(expression(ctx.source))
+ case "DAYOFWEEK" =>
+ DayOfWeek(expression(ctx.source))
+ case "HOUR" =>
+ Hour(expression(ctx.source))
+ case "MINUTE" =>
+ Minute(expression(ctx.source))
+ case "SECOND" =>
+ Second(expression(ctx.source))
+ case other =>
+ throw new ParseException(s"Literals of type '$other' are currently not supported.", ctx)
+ }
+ }
+
+ /**
+ * Create a (windowed) Function expression.
+ */
+ override def visitFunctionCall(ctx: FunctionCallContext): Expression = withOrigin(ctx) {
+ def replaceFunctions(
+ funcID: FunctionIdentifier,
+ ctx: FunctionCallContext): FunctionIdentifier = {
+ val opt = ctx.trimOption
+ if (opt != null) {
+ if (ctx.qualifiedName.getText.toLowerCase(Locale.ROOT) != "trim") {
+ throw new ParseException(s"The specified function ${ctx.qualifiedName.getText} " +
+ s"doesn't support with option ${opt.getText}.", ctx)
+ }
+ opt.getType match {
+ case HoodieSqlBaseParser.BOTH => funcID
+ case HoodieSqlBaseParser.LEADING => funcID.copy(funcName = "ltrim")
+ case HoodieSqlBaseParser.TRAILING => funcID.copy(funcName = "rtrim")
+ case _ => throw new ParseException("Function trim doesn't support with " +
+ s"type ${opt.getType}. Please use BOTH, LEADING or Trailing as trim type", ctx)
+ }
+ } else {
+ funcID
+ }
+ }
+ // Create the function call.
+ val name = ctx.qualifiedName.getText
+ val isDistinct = Option(ctx.setQuantifier()).exists(_.DISTINCT != null)
+ val arguments = ctx.argument.asScala.map(expression) match {
+ case Seq(UnresolvedStar(None))
+ if name.toLowerCase(Locale.ROOT) == "count" && !isDistinct =>
+ // Transform COUNT(*) into COUNT(1).
+ Seq(Literal(1))
+ case expressions =>
+ expressions
+ }
+ val funcId = replaceFunctions(visitFunctionName(ctx.qualifiedName), ctx)
+ val function = UnresolvedFunction(funcId, arguments, isDistinct)
+
+
+ // Check if the function is evaluated in a windowed context.
+ ctx.windowSpec match {
+ case spec: WindowRefContext =>
+ UnresolvedWindowExpression(function, visitWindowRef(spec))
+ case spec: WindowDefContext =>
+ WindowExpression(function, visitWindowDef(spec))
+ case _ => function
+ }
+ }
+
+ /**
+ * Create a function database (optional) and name pair.
+ */
+ protected def visitFunctionName(ctx: QualifiedNameContext): FunctionIdentifier = {
+ ctx.identifier().asScala.map(_.getText) match {
+ case Seq(db, fn) => FunctionIdentifier(fn, Option(db))
+ case Seq(fn) => FunctionIdentifier(fn, None)
+ case other => throw new ParseException(s"Unsupported function name '${ctx.getText}'", ctx)
+ }
+ }
+
+ /**
+ * Create a [[LambdaFunction]].
+ */
+ override def visitLambda(ctx: LambdaContext): Expression = withOrigin(ctx) {
+ val arguments = ctx.IDENTIFIER().asScala.map { name =>
+ UnresolvedNamedLambdaVariable(UnresolvedAttribute.quoted(name.getText).nameParts)
+ }
+ val function = expression(ctx.expression).transformUp {
+ case a: UnresolvedAttribute => UnresolvedNamedLambdaVariable(a.nameParts)
+ }
+ LambdaFunction(function, arguments)
+ }
+
+ /**
+ * Create a reference to a window frame, i.e. [[WindowSpecReference]].
+ */
+ override def visitWindowRef(ctx: WindowRefContext): WindowSpecReference = withOrigin(ctx) {
+ WindowSpecReference(ctx.identifier.getText)
+ }
+
+ /**
+ * Create a window definition, i.e. [[WindowSpecDefinition]].
+ */
+ override def visitWindowDef(ctx: WindowDefContext): WindowSpecDefinition = withOrigin(ctx) {
+ // CLUSTER BY ... | PARTITION BY ... ORDER BY ...
+ val partition = ctx.partition.asScala.map(expression)
+ val order = ctx.sortItem.asScala.map(visitSortItem)
+
+ // RANGE/ROWS BETWEEN ...
+ val frameSpecOption = Option(ctx.windowFrame).map { frame =>
+ val frameType = frame.frameType.getType match {
+ case HoodieSqlBaseParser.RANGE => RangeFrame
+ case HoodieSqlBaseParser.ROWS => RowFrame
+ }
+
+ SpecifiedWindowFrame(
+ frameType,
+ visitFrameBound(frame.start),
+ Option(frame.end).map(visitFrameBound).getOrElse(CurrentRow))
+ }
+
+ WindowSpecDefinition(
+ partition,
+ order,
+ frameSpecOption.getOrElse(UnspecifiedFrame))
+ }
+
+ /**
+ * Create or resolve a frame boundary expression.
+ */
+ override def visitFrameBound(ctx: FrameBoundContext): Expression = withOrigin(ctx) {
+ def value: Expression = {
+ val e = expression(ctx.expression)
+ validate(e.resolved && e.foldable, "Frame bound value must be a literal.", ctx)
+ e
+ }
+
+ ctx.boundType.getType match {
+ case HoodieSqlBaseParser.PRECEDING if ctx.UNBOUNDED != null =>
+ UnboundedPreceding
+ case HoodieSqlBaseParser.PRECEDING =>
+ UnaryMinus(value)
+ case HoodieSqlBaseParser.CURRENT =>
+ CurrentRow
+ case HoodieSqlBaseParser.FOLLOWING if ctx.UNBOUNDED != null =>
+ UnboundedFollowing
+ case HoodieSqlBaseParser.FOLLOWING =>
+ value
+ }
+ }
+
+ /**
+ * Create a [[CreateStruct]] expression.
+ */
+ override def visitRowConstructor(ctx: RowConstructorContext): Expression = withOrigin(ctx) {
+ CreateStruct(ctx.namedExpression().asScala.map(expression))
+ }
+
+ /**
+ * Create a [[ScalarSubquery]] expression.
+ */
+ override def visitSubqueryExpression(
+ ctx: SubqueryExpressionContext): Expression = withOrigin(ctx) {
+ ScalarSubquery(plan(ctx.query))
+ }
+
+ /**
+ * Create a value based [[CaseWhen]] expression. This has the following SQL form:
+ * {{{
+ * CASE [expression]
+ * WHEN [value] THEN [expression]
+ * ...
+ * ELSE [expression]
+ * END
+ * }}}
+ */
+ override def visitSimpleCase(ctx: SimpleCaseContext): Expression = withOrigin(ctx) {
+ val e = expression(ctx.value)
+ val branches = ctx.whenClause.asScala.map { wCtx =>
+ (EqualTo(e, expression(wCtx.condition)), expression(wCtx.result))
+ }
+ CaseWhen(branches, Option(ctx.elseExpression).map(expression))
+ }
+
+ /**
+ * Create a condition based [[CaseWhen]] expression. This has the following SQL syntax:
+ * {{{
+ * CASE
+ * WHEN [predicate] THEN [expression]
+ * ...
+ * ELSE [expression]
+ * END
+ * }}}
+ *
+ * @param ctx the parse tree
+ * */
+ override def visitSearchedCase(ctx: SearchedCaseContext): Expression = withOrigin(ctx) {
+ val branches = ctx.whenClause.asScala.map { wCtx =>
+ (expression(wCtx.condition), expression(wCtx.result))
+ }
+ CaseWhen(branches, Option(ctx.elseExpression).map(expression))
+ }
+
+ /**
+ * Currently, regexes are only supported in expressions of SELECT statements; in other
+ * places, e.g. WHERE `(a)?+.+` = 2, a regex is not meaningful.
+ */
+ private def canApplyRegex(ctx: ParserRuleContext): Boolean = withOrigin(ctx) {
+ var parent = ctx.getParent
+ var result = false
+ while (parent != null) {
+ if (parent.isInstanceOf[NamedExpressionContext]) {
+ result = true
+ }
+ parent = parent.getParent
+ }
+ result
+ }
+
+ /**
+ * Create a dereference expression. The result depends on the type of the parent:
+ * if the parent is an [[UnresolvedAttribute]], the result is an [[UnresolvedAttribute]] or
+ * an [[UnresolvedRegex]] for a regex quoted in backticks; if the parent is some other
+ * expression, the result is an [[UnresolvedExtractValue]].
+ */
+ override def visitDereference(ctx: DereferenceContext): Expression = withOrigin(ctx) {
+ val attr = ctx.fieldName.getText
+ expression(ctx.base) match {
+ case unresolved_attr @ UnresolvedAttribute(nameParts) =>
+ ctx.fieldName.getStart.getText match {
+ case escapedIdentifier(columnNameRegex)
+ if conf.supportQuotedRegexColumnName && canApplyRegex(ctx) =>
+ UnresolvedRegex(columnNameRegex, Some(unresolved_attr.name),
+ conf.caseSensitiveAnalysis)
+ case _ =>
+ UnresolvedAttribute(nameParts :+ attr)
+ }
+ case e =>
+ UnresolvedExtractValue(e, Literal(attr))
+ }
+ }
+
+ /**
+ * Create an [[UnresolvedAttribute]] expression, or an [[UnresolvedRegex]] if it is a regex
+ * quoted in backticks.
+ */
+ override def visitColumnReference(ctx: ColumnReferenceContext): Expression = withOrigin(ctx) {
+ ctx.getStart.getText match {
+ case escapedIdentifier(columnNameRegex)
+ if conf.supportQuotedRegexColumnName && canApplyRegex(ctx) =>
+ UnresolvedRegex(columnNameRegex, None, conf.caseSensitiveAnalysis)
+ case _ =>
+ UnresolvedAttribute.quoted(ctx.getText)
+ }
+
+ }
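For illustration, a hypothetical spark-shell sketch (assuming an existing `spark` session) of the quoted-regex column behavior this code mirrors from Spark:

```scala
// Enable quoted regex column names (a standard Spark SQL conf).
spark.conf.set("spark.sql.parser.quotedRegexColumnNames", "true")
spark.range(3).selectExpr("id", "id as idx").createOrReplaceTempView("t")
// In a SELECT list, a backquoted pattern becomes an UnresolvedRegex and expands to
// every column whose name matches the regex: here both id and idx.
spark.sql("SELECT `id.*` FROM t").show()
```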
+
+ /**
+ * Create an [[UnresolvedExtractValue]] expression; this is used for subscript access to an array.
+ */
+ override def visitSubscript(ctx: SubscriptContext): Expression = withOrigin(ctx) {
+ UnresolvedExtractValue(expression(ctx.value), expression(ctx.index))
+ }
+
+ /**
+ * Create an expression for an expression between parentheses. This is needed because the ANTLR
+ * visitor cannot automatically convert the nested context into an expression.
+ */
+ override def visitParenthesizedExpression(
+ ctx: ParenthesizedExpressionContext): Expression = withOrigin(ctx) {
+ expression(ctx.expression)
+ }
+
+ /**
+ * Create a [[SortOrder]] expression.
+ */
+ override def visitSortItem(ctx: SortItemContext): SortOrder = withOrigin(ctx) {
+ val direction = if (ctx.DESC != null) {
+ Descending
+ } else {
+ Ascending
+ }
+ val nullOrdering = if (ctx.FIRST != null) {
+ NullsFirst
+ } else if (ctx.LAST != null) {
+ NullsLast
+ } else {
+ direction.defaultNullOrdering
+ }
+ SortOrder(expression(ctx.expression), direction, nullOrdering, Set.empty)
+ }
+
+ /**
+ * Create a typed Literal expression. A typed literal has the following SQL syntax:
+ * {{{
+ * [TYPE] '[VALUE]'
+ * }}}
+ * Currently Date, Timestamp and Binary typed literals are supported.
+ */
+ override def visitTypeConstructor(ctx: TypeConstructorContext): Literal = withOrigin(ctx) {
+ val value = string(ctx.STRING)
+ val valueType = ctx.identifier.getText.toUpperCase(Locale.ROOT)
+ try {
+ valueType match {
+ case "DATE" =>
+ Literal(Date.valueOf(value))
+ case "TIMESTAMP" =>
+ Literal(Timestamp.valueOf(value))
+ case "X" =>
+ val padding = if (value.length % 2 != 0) "0" else ""
+ Literal(DatatypeConverter.parseHexBinary(padding + value))
+ case other =>
+ throw new ParseException(s"Literals of type '$other' are currently not supported.", ctx)
+ }
+ } catch {
+ case e: IllegalArgumentException =>
+ val message = Option(e.getMessage).getOrElse(s"Exception parsing $valueType")
+ throw new ParseException(message, ctx)
+ }
+ }
+
+ /**
+ * Create a NULL literal expression.
+ */
+ override def visitNullLiteral(ctx: NullLiteralContext): Literal = withOrigin(ctx) {
+ Literal(null)
+ }
+
+ /**
+ * Create a Boolean literal expression.
+ */
+ override def visitBooleanLiteral(ctx: BooleanLiteralContext): Literal = withOrigin(ctx) {
+ if (ctx.getText.toBoolean) {
+ Literal.TrueLiteral
+ } else {
+ Literal.FalseLiteral
+ }
+ }
+
+ /**
+ * Create an integral literal expression. The code selects the narrowest integral type
+ * possible: an Integer, a Long, or otherwise a BigDecimal is returned.
+ */
+ override def visitIntegerLiteral(ctx: IntegerLiteralContext): Literal = withOrigin(ctx) {
+ BigDecimal(ctx.getText) match {
+ case v if v.isValidInt =>
+ Literal(v.intValue())
+ case v if v.isValidLong =>
+ Literal(v.longValue())
+ case v => Literal(v.underlying())
+ }
+ }
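A minimal sketch of the narrowing rule, using plain Scala `BigDecimal` just as the visitor does:

```scala
def narrowIntegral(text: String): Any = BigDecimal(text) match {
  case v if v.isValidInt  => v.intValue      // fits in Int
  case v if v.isValidLong => v.longValue     // fits in Long
  case v                  => v.underlying()  // falls back to java.math.BigDecimal
}

narrowIntegral("42")                    // Int
narrowIntegral("3000000000")            // Long
narrowIntegral("99999999999999999999")  // BigDecimal
```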
+
+ /**
+ * Create a decimal literal for a regular decimal number.
+ */
+ override def visitDecimalLiteral(ctx: DecimalLiteralContext): Literal = withOrigin(ctx) {
+ Literal(BigDecimal(ctx.getText).underlying())
+ }
+
+ /** Create a numeric literal expression. */
+ private def numericLiteral
+ (ctx: NumberContext, minValue: BigDecimal, maxValue: BigDecimal, typeName: String)
+ (converter: String => Any): Literal = withOrigin(ctx) {
+ val rawStrippedQualifier = ctx.getText.substring(0, ctx.getText.length - 1)
+ try {
+ val rawBigDecimal = BigDecimal(rawStrippedQualifier)
+ if (rawBigDecimal < minValue || rawBigDecimal > maxValue) {
+ throw new ParseException(s"Numeric literal ${rawStrippedQualifier} does not " +
+ s"fit in range [${minValue}, ${maxValue}] for type ${typeName}", ctx)
+ }
+ Literal(converter(rawStrippedQualifier))
+ } catch {
+ case e: NumberFormatException =>
+ throw new ParseException(e.getMessage, ctx)
+ }
+ }
+
+ /**
+ * Create a Byte Literal expression.
+ */
+ override def visitTinyIntLiteral(ctx: TinyIntLiteralContext): Literal = {
+ numericLiteral(ctx, Byte.MinValue, Byte.MaxValue, ByteType.simpleString)(_.toByte)
+ }
+
+ /**
+ * Create a Short Literal expression.
+ */
+ override def visitSmallIntLiteral(ctx: SmallIntLiteralContext): Literal = {
+ numericLiteral(ctx, Short.MinValue, Short.MaxValue, ShortType.simpleString)(_.toShort)
+ }
+
+ /**
+ * Create a Long Literal expression.
+ */
+ override def visitBigIntLiteral(ctx: BigIntLiteralContext): Literal = {
+ numericLiteral(ctx, Long.MinValue, Long.MaxValue, LongType.simpleString)(_.toLong)
+ }
+
+ /**
+ * Create a Double Literal expression.
+ */
+ override def visitDoubleLiteral(ctx: DoubleLiteralContext): Literal = {
+ numericLiteral(ctx, Double.MinValue, Double.MaxValue, DoubleType.simpleString)(_.toDouble)
+ }
+
+ /**
+ * Create a BigDecimal Literal expression.
+ */
+ override def visitBigDecimalLiteral(ctx: BigDecimalLiteralContext): Literal = {
+ val raw = ctx.getText.substring(0, ctx.getText.length - 2)
+ try {
+ Literal(BigDecimal(raw).underlying())
+ } catch {
+ case e: AnalysisException =>
+ throw new ParseException(e.message, ctx)
+ }
+ }
+
+ /**
+ * Create a String literal expression.
+ */
+ override def visitStringLiteral(ctx: StringLiteralContext): Literal = withOrigin(ctx) {
+ Literal(createString(ctx))
+ }
+
+ /**
+ * Create a String from a string literal context. This supports multiple consecutive string
+ * literals, which are concatenated; for example, the expression "'hello' 'world'" is
+ * converted into "helloworld".
+ *
+ * Special characters can be escaped by using Hive/C-style escaping.
+ */
+ private def createString(ctx: StringLiteralContext): String = {
+ if (conf.escapedStringLiterals) {
+ ctx.STRING().asScala.map(stringWithoutUnescape).mkString
+ } else {
+ ctx.STRING().asScala.map(string).mkString
+ }
+ }
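A hypothetical spark-shell sketch (assuming an existing `spark` session) of the concatenation described above; the `spark.sql.parser.escapedStringLiterals` flag only changes how backslash escapes are interpreted, not the concatenation:

```scala
// Two adjacent string literals are concatenated by the parser into one value.
spark.sql("SELECT 'hello' 'world' AS greeting").show()
// greeting = "helloworld"
```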
+
+ /**
+ * Create a [[CalendarInterval]] literal expression. An interval expression can contain multiple
+ * unit value pairs, for instance: interval 2 months 2 days.
+ */
+ override def visitInterval(ctx: IntervalContext): Literal = withOrigin(ctx) {
+ val intervals = ctx.intervalField.asScala.map(visitIntervalField)
+ validate(intervals.nonEmpty, "at least one time unit should be given for interval literal", ctx)
+ Literal(intervals.reduce(_.add(_)))
+ }
+
+ /**
+ * Create a [[CalendarInterval]] for a unit value pair. Two unit configuration types are
+ * supported:
+ * - Single unit.
+ * - From-To unit (only 'YEAR TO MONTH' and 'DAY TO SECOND' are supported).
+ */
+ override def visitIntervalField(ctx: IntervalFieldContext): CalendarInterval = withOrigin(ctx) {
+ import ctx._
+ val s = value.getText
+ try {
+ val unitText = unit.getText.toLowerCase(Locale.ROOT)
+ val interval = (unitText, Option(to).map(_.getText.toLowerCase(Locale.ROOT))) match {
+ case (u, None) if u.endsWith("s") =>
+ // Handle plural forms, e.g.: yearS/monthS/weekS/dayS/hourS/minuteS/secondS/...
+ CalendarInterval.fromSingleUnitString(u.substring(0, u.length - 1), s)
+ case (u, None) =>
+ CalendarInterval.fromSingleUnitString(u, s)
+ case ("year", Some("month")) =>
+ CalendarInterval.fromYearMonthString(s)
+ case ("day", Some("second")) =>
+ CalendarInterval.fromDayTimeString(s)
+ case (from, Some(t)) =>
+ throw new ParseException(s"Intervals FROM $from TO $t are not supported.", ctx)
+ }
+ validate(interval != null, "No interval can be constructed", ctx)
+ interval
+ } catch {
+ // Handle Exceptions thrown by CalendarInterval
+ case e: IllegalArgumentException =>
+ val pe = new ParseException(e.getMessage, ctx)
+ pe.setStackTrace(e.getStackTrace)
+ throw pe
+ }
+ }
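A small sketch of how a multi-unit literal such as `interval 2 months 2 days` is assembled, using the CalendarInterval class added later in this patch:

```scala
import org.apache.spark.sql.hudi.parser.unsafe.types.CalendarInterval

val months = CalendarInterval.fromSingleUnitString("month", "2") // 2 months
val days   = CalendarInterval.fromSingleUnitString("day", "2")   // 2 days, stored as microseconds
// visitInterval reduces the per-unit parts with add().
println(months.add(days)) // interval 2 months 2 days
```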
+
+ /* ********************************************************************************************
+ * DataType parsing
+ * ******************************************************************************************** */
+ /**
+ * Create a Spark DataType.
+ */
+ private def visitSparkDataType(ctx: DataTypeContext): DataType = {
+ HiveStringType.replaceCharType(typedVisit(ctx))
+ }
+
+ /**
+ * Resolve/create a primitive type.
+ */
+ override def visitPrimitiveDataType(ctx: PrimitiveDataTypeContext): DataType = withOrigin(ctx) {
+ val dataType = ctx.identifier.getText.toLowerCase(Locale.ROOT)
+ (dataType, ctx.INTEGER_VALUE().asScala.toList) match {
+ case ("boolean", Nil) => BooleanType
+ case ("tinyint" | "byte", Nil) => ByteType
+ case ("smallint" | "short", Nil) => ShortType
+ case ("int" | "integer", Nil) => IntegerType
+ case ("bigint" | "long", Nil) => LongType
+ case ("float", Nil) => FloatType
+ case ("double", Nil) => DoubleType
+ case ("date", Nil) => DateType
+ case ("timestamp", Nil) => TimestampType
+ case ("string", Nil) => StringType
+ case ("char", length :: Nil) => CharType(length.getText.toInt)
+ case ("varchar", length :: Nil) => VarcharType(length.getText.toInt)
+ case ("binary", Nil) => BinaryType
+ case ("decimal", Nil) => DecimalType.USER_DEFAULT
+ case ("decimal", precision :: Nil) => DecimalType(precision.getText.toInt, 0)
+ case ("decimal", precision :: scale :: Nil) =>
+ DecimalType(precision.getText.toInt, scale.getText.toInt)
+ case (dt, params) =>
+ val dtStr = if (params.nonEmpty) s"$dt(${params.mkString(",")})" else dt
+ throw new ParseException(s"DataType $dtStr is not supported.", ctx)
+ }
+ }
+
+ /**
+ * Create a complex DataType. Arrays, Maps and Structures are supported.
+ */
+ override def visitComplexDataType(ctx: ComplexDataTypeContext): DataType = withOrigin(ctx) {
+ ctx.complex.getType match {
+ case HoodieSqlBaseParser.ARRAY =>
+ ArrayType(typedVisit(ctx.dataType(0)))
+ case HoodieSqlBaseParser.MAP =>
+ MapType(typedVisit(ctx.dataType(0)), typedVisit(ctx.dataType(1)))
+ case HoodieSqlBaseParser.STRUCT =>
+ StructType(Option(ctx.complexColTypeList).toSeq.flatMap(visitComplexColTypeList))
+ }
+ }
+
+ /**
+ * Create top level table schema.
+ */
+ protected def createSchema(ctx: ColTypeListContext): StructType = {
+ StructType(Option(ctx).toSeq.flatMap(visitColTypeList))
+ }
+
+ /**
+ * Create a [[StructType]] from a number of column definitions.
+ */
+ override def visitColTypeList(ctx: ColTypeListContext): Seq[StructField] = withOrigin(ctx) {
+ ctx.colType().asScala.map(visitColType)
+ }
+
+ /**
+ * Create a top level [[StructField]] from a column definition.
+ */
+ override def visitColType(ctx: ColTypeContext): StructField = withOrigin(ctx) {
+ import ctx._
+
+ val builder = new MetadataBuilder
+ // Add comment to metadata
+ if (STRING != null) {
+ builder.putString("comment", string(STRING))
+ }
+ // Add Hive type string to metadata.
+ val rawDataType = typedVisit[DataType](ctx.dataType)
+ val cleanedDataType = HiveStringType.replaceCharType(rawDataType)
+ if (rawDataType != cleanedDataType) {
+ builder.putString(HIVE_TYPE_STRING, rawDataType.catalogString)
+ }
+
+ StructField(
+ identifier.getText,
+ cleanedDataType,
+ nullable = true,
+ builder.build())
+ }
+
+ /**
+ * Create a [[StructType]] from a sequence of [[StructField]]s.
+ */
+ protected def createStructType(ctx: ComplexColTypeListContext): StructType = {
+ StructType(Option(ctx).toSeq.flatMap(visitComplexColTypeList))
+ }
+
+ /**
+ * Create a [[StructType]] from a number of column definitions.
+ */
+ override def visitComplexColTypeList(
+ ctx: ComplexColTypeListContext): Seq[StructField] = withOrigin(ctx) {
+ ctx.complexColType().asScala.map(visitComplexColType)
+ }
+
+ /**
+ * Create a [[StructField]] from a column definition.
+ */
+ override def visitComplexColType(ctx: ComplexColTypeContext): StructField = withOrigin(ctx) {
+ import ctx._
+ val structField = StructField(identifier.getText, typedVisit(dataType), nullable = true)
+ if (STRING == null) structField else structField.withComment(string(STRING))
}
}
diff --git a/hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/spark/sql/hudi/parser/unsafe/types/CalendarInterval.java b/hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/spark/sql/hudi/parser/unsafe/types/CalendarInterval.java
new file mode 100644
index 0000000000000..92a7e412ad329
--- /dev/null
+++ b/hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/spark/sql/hudi/parser/unsafe/types/CalendarInterval.java
@@ -0,0 +1,360 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hudi.parser.unsafe.types;
+
+import java.io.Serializable;
+import java.util.Locale;
+import java.util.regex.Matcher;
+import java.util.regex.Pattern;
+
+/**
+ * The internal representation of interval type.
+ */
+public final class CalendarInterval implements Serializable {
+ public static final long MICROS_PER_MILLI = 1000L;
+ public static final long MICROS_PER_SECOND = MICROS_PER_MILLI * 1000;
+ public static final long MICROS_PER_MINUTE = MICROS_PER_SECOND * 60;
+ public static final long MICROS_PER_HOUR = MICROS_PER_MINUTE * 60;
+ public static final long MICROS_PER_DAY = MICROS_PER_HOUR * 24;
+ public static final long MICROS_PER_WEEK = MICROS_PER_DAY * 7;
+
+ /**
+ * A function to generate a regex which matches an interval string's unit part, like "3 years".
+ *
+ * Each unit may be left out of the interval string, and only the value of the unit matters,
+ * so the whole unit expression is wrapped in a non-capturing group.
+ * The actual regex starts by matching the spaces before the unit part.
+ * Next comes the number part, which starts with an optional "-" to represent a negative value;
+ * it is wrapped in a capturing group because the value is needed later.
+ * Finally comes the unit name, which ends with an optional "s".
+ */
+ private static String unitRegex(String unit) {
+ return "(?:\\s+(-?\\d+)\\s+" + unit + "s?)?";
+ }
+
+ private static Pattern p = Pattern.compile("interval" + unitRegex("year") + unitRegex("month") +
+ unitRegex("week") + unitRegex("day") + unitRegex("hour") + unitRegex("minute") +
+ unitRegex("second") + unitRegex("millisecond") + unitRegex("microsecond"));
+
+ private static Pattern yearMonthPattern =
+ Pattern.compile("^(?:['|\"])?([+|-])?(\\d+)-(\\d+)(?:['|\"])?$");
+
+ private static Pattern dayTimePattern =
+ Pattern.compile("^(?:['|\"])?([+|-])?(\\d+) (\\d+):(\\d+):(\\d+)(\\.(\\d+))?(?:['|\"])?$");
+
+ private static Pattern quoteTrimPattern = Pattern.compile("^(?:['|\"])?(.*?)(?:['|\"])?$");
+
+ private static long toLong(String s) {
+ if (s == null) {
+ return 0;
+ } else {
+ return Long.parseLong(s);
+ }
+ }
+
+ /**
+ * Convert a string to CalendarInterval. Return null if the input string is not a valid interval.
+ * This method is case-sensitive and all characters in the input string should be in lower case.
+ */
+ public static CalendarInterval fromString(String s) {
+ if (s == null) {
+ return null;
+ }
+ s = s.trim();
+ Matcher m = p.matcher(s);
+ if (!m.matches() || s.equals("interval")) {
+ return null;
+ } else {
+ long months = toLong(m.group(1)) * 12 + toLong(m.group(2));
+ long microseconds = toLong(m.group(3)) * MICROS_PER_WEEK;
+ microseconds += toLong(m.group(4)) * MICROS_PER_DAY;
+ microseconds += toLong(m.group(5)) * MICROS_PER_HOUR;
+ microseconds += toLong(m.group(6)) * MICROS_PER_MINUTE;
+ microseconds += toLong(m.group(7)) * MICROS_PER_SECOND;
+ microseconds += toLong(m.group(8)) * MICROS_PER_MILLI;
+ microseconds += toLong(m.group(9));
+ return new CalendarInterval((int) months, microseconds);
+ }
+ }
+
+ /**
+ * Convert a string to CalendarInterval. Unlike fromString, this method is case-insensitive and
+ * will throw IllegalArgumentException when the input string is not a valid interval.
+ *
+ * @throws IllegalArgumentException if the string is not a valid interval.
+ */
+ public static CalendarInterval fromCaseInsensitiveString(String s) {
+ if (s == null || s.trim().isEmpty()) {
+ throw new IllegalArgumentException("Interval cannot be null or blank.");
+ }
+ String sInLowerCase = s.trim().toLowerCase(Locale.ROOT);
+ String interval =
+ sInLowerCase.startsWith("interval ") ? sInLowerCase : "interval " + sInLowerCase;
+ CalendarInterval cal = fromString(interval);
+ if (cal == null) {
+ throw new IllegalArgumentException("Invalid interval: " + s);
+ }
+ return cal;
+ }
+
+ public static long toLongWithRange(String fieldName,
+ String s, long minValue, long maxValue) throws IllegalArgumentException {
+ long result = 0;
+ if (s != null) {
+ result = Long.parseLong(s);
+ if (result < minValue || result > maxValue) {
+ throw new IllegalArgumentException(String.format("%s %d outside range [%d, %d]",
+ fieldName, result, minValue, maxValue));
+ }
+ }
+ return result;
+ }
+
+ /**
+ * Parse YearMonth string in form: [-]YYYY-MM
+ *
+ * adapted from HiveIntervalYearMonth.valueOf
+ */
+ public static CalendarInterval fromYearMonthString(String s) throws IllegalArgumentException {
+ CalendarInterval result = null;
+ if (s == null) {
+ throw new IllegalArgumentException("Interval year-month string was null");
+ }
+ s = s.trim();
+ Matcher m = yearMonthPattern.matcher(s);
+ if (!m.matches()) {
+ throw new IllegalArgumentException(
+ "Interval string does not match year-month format of 'y-m': " + s);
+ } else {
+ try {
+ int sign = m.group(1) != null && m.group(1).equals("-") ? -1 : 1;
+ int years = (int) toLongWithRange("year", m.group(2), 0, Integer.MAX_VALUE);
+ int months = (int) toLongWithRange("month", m.group(3), 0, 11);
+ result = new CalendarInterval(sign * (years * 12 + months), 0);
+ } catch (Exception e) {
+ throw new IllegalArgumentException(
+ "Error parsing interval year-month string: " + e.getMessage(), e);
+ }
+ }
+ return result;
+ }
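A short sketch of the year-month format: `'y-m'` becomes `years * 12 + months` months, with an optional leading sign:

```scala
import org.apache.spark.sql.hudi.parser.unsafe.types.CalendarInterval

assert(CalendarInterval.fromYearMonthString("1-2").months == 14)   // 1 year 2 months
assert(CalendarInterval.fromYearMonthString("-1-2").months == -14) // leading '-' flips the sign
```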
+
+ /**
+ * Parse dayTime string in form: [-]d HH:mm:ss.nnnnnnnnn
+ *
+ * adapted from HiveIntervalDayTime.valueOf
+ */
+ public static CalendarInterval fromDayTimeString(String s) throws IllegalArgumentException {
+ CalendarInterval result = null;
+ if (s == null) {
+ throw new IllegalArgumentException("Interval day-time string was null");
+ }
+ s = s.trim();
+ Matcher m = dayTimePattern.matcher(s);
+ if (!m.matches()) {
+ throw new IllegalArgumentException(
+ "Interval string does not match day-time format of 'd h:m:s.n': " + s);
+ } else {
+ try {
+ int sign = m.group(1) != null && m.group(1).equals("-") ? -1 : 1;
+ long days = toLongWithRange("day", m.group(2), 0, Integer.MAX_VALUE);
+ long hours = toLongWithRange("hour", m.group(3), 0, 23);
+ long minutes = toLongWithRange("minute", m.group(4), 0, 59);
+ long seconds = toLongWithRange("second", m.group(5), 0, 59);
+ // Hive allow nanosecond precision interval
+ String nanoStr = m.group(7) == null ? null : (m.group(7) + "000000000").substring(0, 9);
+ long nanos = toLongWithRange("nanosecond", nanoStr, 0L, 999999999L);
+ result = new CalendarInterval(0, sign * (
+ days * MICROS_PER_DAY + hours * MICROS_PER_HOUR + minutes * MICROS_PER_MINUTE +
+ seconds * MICROS_PER_SECOND + nanos / 1000L));
+ } catch (Exception e) {
+ throw new IllegalArgumentException(
+ "Error parsing interval day-time string: " + e.getMessage(), e);
+ }
+ }
+ return result;
+ }
+
+ public static CalendarInterval fromSingleUnitString(String unit, String s)
+ throws IllegalArgumentException {
+
+ CalendarInterval result = null;
+ if (s == null) {
+ throw new IllegalArgumentException(String.format("Interval %s string was null", unit));
+ }
+ s = s.trim();
+ Matcher m = quoteTrimPattern.matcher(s);
+ if (!m.matches()) {
+ throw new IllegalArgumentException(
+ "Interval string does not match day-time format of 'd h:m:s.n': " + s);
+ } else {
+ try {
+ switch (unit) {
+ case "year":
+ int year = (int) toLongWithRange("year", m.group(1),
+ Integer.MIN_VALUE / 12, Integer.MAX_VALUE / 12);
+ result = new CalendarInterval(year * 12, 0L);
+ break;
+ case "month":
+ int month = (int) toLongWithRange("month", m.group(1),
+ Integer.MIN_VALUE, Integer.MAX_VALUE);
+ result = new CalendarInterval(month, 0L);
+ break;
+ case "week":
+ long week = toLongWithRange("week", m.group(1),
+ Long.MIN_VALUE / MICROS_PER_WEEK, Long.MAX_VALUE / MICROS_PER_WEEK);
+ result = new CalendarInterval(0, week * MICROS_PER_WEEK);
+ break;
+ case "day":
+ long day = toLongWithRange("day", m.group(1),
+ Long.MIN_VALUE / MICROS_PER_DAY, Long.MAX_VALUE / MICROS_PER_DAY);
+ result = new CalendarInterval(0, day * MICROS_PER_DAY);
+ break;
+ case "hour":
+ long hour = toLongWithRange("hour", m.group(1),
+ Long.MIN_VALUE / MICROS_PER_HOUR, Long.MAX_VALUE / MICROS_PER_HOUR);
+ result = new CalendarInterval(0, hour * MICROS_PER_HOUR);
+ break;
+ case "minute":
+ long minute = toLongWithRange("minute", m.group(1),
+ Long.MIN_VALUE / MICROS_PER_MINUTE, Long.MAX_VALUE / MICROS_PER_MINUTE);
+ result = new CalendarInterval(0, minute * MICROS_PER_MINUTE);
+ break;
+ case "second": {
+ long micros = parseSecondNano(m.group(1));
+ result = new CalendarInterval(0, micros);
+ break;
+ }
+ case "millisecond":
+ long millisecond = toLongWithRange("millisecond", m.group(1),
+ Long.MIN_VALUE / MICROS_PER_MILLI, Long.MAX_VALUE / MICROS_PER_MILLI);
+ result = new CalendarInterval(0, millisecond * MICROS_PER_MILLI);
+ break;
+ case "microsecond": {
+ long micros = Long.parseLong(m.group(1));
+ result = new CalendarInterval(0, micros);
+ break;
+ }
+ }
+ } catch (Exception e) {
+ throw new IllegalArgumentException("Error parsing interval string: " + e.getMessage(), e);
+ }
+ }
+ return result;
+ }
+
+ /**
+ * Parse second_nano string in ss.nnnnnnnnn format to microseconds
+ */
+ public static long parseSecondNano(String secondNano) throws IllegalArgumentException {
+ String[] parts = secondNano.split("\\.");
+ if (parts.length == 1) {
+ return toLongWithRange("second", parts[0], Long.MIN_VALUE / MICROS_PER_SECOND,
+ Long.MAX_VALUE / MICROS_PER_SECOND) * MICROS_PER_SECOND;
+
+ } else if (parts.length == 2) {
+ long seconds = parts[0].equals("") ? 0L : toLongWithRange("second", parts[0],
+ Long.MIN_VALUE / MICROS_PER_SECOND, Long.MAX_VALUE / MICROS_PER_SECOND);
+ long nanos = toLongWithRange("nanosecond", parts[1], 0L, 999999999L);
+ return seconds * MICROS_PER_SECOND + nanos / 1000L;
+
+ } else {
+ throw new IllegalArgumentException(
+ "Interval string does not match second-nano format of ss.nnnnnnnnn");
+ }
+ }
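A short sketch of the second-nano parsing; note that the fractional part is read as a count of nanoseconds (not as a decimal fraction) and then truncated to microsecond precision:

```scala
import org.apache.spark.sql.hudi.parser.unsafe.types.CalendarInterval

assert(CalendarInterval.parseSecondNano("1") == 1000000L)           // 1 second
assert(CalendarInterval.parseSecondNano("1.500000000") == 1500000L) // 1 s + 500,000,000 ns
assert(CalendarInterval.parseSecondNano("1.5") == 1000000L)         // 5 ns truncates to 0 µs
```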
+
+ public final int months;
+ public final long microseconds;
+
+ public long milliseconds() {
+ return this.microseconds / MICROS_PER_MILLI;
+ }
+
+ public CalendarInterval(int months, long microseconds) {
+ this.months = months;
+ this.microseconds = microseconds;
+ }
+
+ public CalendarInterval add(CalendarInterval that) {
+ int months = this.months + that.months;
+ long microseconds = this.microseconds + that.microseconds;
+ return new CalendarInterval(months, microseconds);
+ }
+
+ public CalendarInterval subtract(CalendarInterval that) {
+ int months = this.months - that.months;
+ long microseconds = this.microseconds - that.microseconds;
+ return new CalendarInterval(months, microseconds);
+ }
+
+ public CalendarInterval negate() {
+ return new CalendarInterval(-this.months, -this.microseconds);
+ }
+
+ @Override
+ public boolean equals(Object other) {
+ if (this == other) return true;
+ if (other == null || !(other instanceof CalendarInterval)) return false;
+
+ CalendarInterval o = (CalendarInterval) other;
+ return this.months == o.months && this.microseconds == o.microseconds;
+ }
+
+ @Override
+ public int hashCode() {
+ return 31 * months + (int) microseconds;
+ }
+
+ @Override
+ public String toString() {
+ StringBuilder sb = new StringBuilder("interval");
+
+ if (months != 0) {
+ appendUnit(sb, months / 12, "year");
+ appendUnit(sb, months % 12, "month");
+ }
+
+ if (microseconds != 0) {
+ long rest = microseconds;
+ appendUnit(sb, rest / MICROS_PER_WEEK, "week");
+ rest %= MICROS_PER_WEEK;
+ appendUnit(sb, rest / MICROS_PER_DAY, "day");
+ rest %= MICROS_PER_DAY;
+ appendUnit(sb, rest / MICROS_PER_HOUR, "hour");
+ rest %= MICROS_PER_HOUR;
+ appendUnit(sb, rest / MICROS_PER_MINUTE, "minute");
+ rest %= MICROS_PER_MINUTE;
+ appendUnit(sb, rest / MICROS_PER_SECOND, "second");
+ rest %= MICROS_PER_SECOND;
+ appendUnit(sb, rest / MICROS_PER_MILLI, "millisecond");
+ rest %= MICROS_PER_MILLI;
+ appendUnit(sb, rest, "microsecond");
+ } else if (months == 0) {
+ sb.append(" 0 microseconds");
+ }
+
+ return sb.toString();
+ }
+
+ private void appendUnit(StringBuilder sb, long value, String unit) {
+ if (value != 0) {
+ sb.append(' ').append(value).append(' ').append(unit).append('s');
+ }
+ }
+}
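To round out the class, a small usage sketch of `toString()`: months are decomposed into years/months, microseconds into week/day/hour/... buckets, and `appendUnit` always pluralizes the unit name:

```scala
import org.apache.spark.sql.hudi.parser.unsafe.types.CalendarInterval

val iv = new CalendarInterval(14,
  CalendarInterval.MICROS_PER_DAY * 8 + CalendarInterval.MICROS_PER_HOUR)
println(iv) // interval 1 years 2 months 1 weeks 1 days 1 hours
```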
diff --git a/hudi-spark-datasource/hudi-spark3-common/pom.xml b/hudi-spark-datasource/hudi-spark3-common/pom.xml
index affa987372963..20faa6744178f 100644
--- a/hudi-spark-datasource/hudi-spark3-common/pom.xml
+++ b/hudi-spark-datasource/hudi-spark3-common/pom.xml
@@ -154,6 +154,42 @@
<groupId>org.jacoco</groupId>
<artifactId>jacoco-maven-plugin</artifactId>
+ <plugin>
+ <groupId>org.antlr</groupId>
+ <artifactId>antlr4-maven-plugin</artifactId>
+ <version>${antlr.version}</version>
+ <executions>
+ <execution>
+ <goals>
+ <goal>antlr4</goal>
+ </goals>
+ </execution>
+ </executions>
+ <configuration>
+ <listener>true</listener>
+ <visitor>true</visitor>
+ <sourceDirectory>../hudi-spark3-common/src/main/antlr4/</sourceDirectory>
+ <libDirectory>../hudi-spark3-common/src/main/antlr4/imports</libDirectory>
+ </configuration>
+ </plugin>
+ <plugin>
+ <groupId>org.antlr</groupId>
+ <artifactId>antlr4-maven-plugin</artifactId>
+ <version>${antlr.version}</version>
+ <executions>
+ <execution>
+ <goals>
+ <goal>antlr4</goal>
+ </goals>
+ </execution>
+ </executions>
+ <configuration>
+ <listener>true</listener>
+ <visitor>true</visitor>
+ <sourceDirectory>../hudi-spark3-common/src/main/antlr4/</sourceDirectory>
+ <libDirectory>../hudi-spark3-common/src/main/antlr4/imports</libDirectory>
+ </configuration>
+ </plugin>
@@ -244,4 +280,4 @@
-</project>
\ No newline at end of file
+</project>
diff --git a/hudi-spark-datasource/hudi-spark3-common/src/main/scala/org/apache/spark/sql/adapter/Spark3Adapter.scala b/hudi-spark-datasource/hudi-spark3-common/src/main/scala/org/apache/spark/sql/adapter/Spark3Adapter.scala
index a1c41e80ab362..ca1ad1afdf58a 100644
--- a/hudi-spark-datasource/hudi-spark3-common/src/main/scala/org/apache/spark/sql/adapter/Spark3Adapter.scala
+++ b/hudi-spark-datasource/hudi-spark3-common/src/main/scala/org/apache/spark/sql/adapter/Spark3Adapter.scala
@@ -20,7 +20,6 @@ package org.apache.spark.sql.adapter
import org.apache.hudi.Spark3RowSerDe
import org.apache.hudi.client.utils.SparkRowSerDe
import org.apache.hudi.spark3.internal.ReflectUtil
-import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions.{Expression, Like}
@@ -31,11 +30,10 @@ import org.apache.spark.sql.catalyst.{AliasIdentifier, TableIdentifier}
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
import org.apache.spark.sql.connector.catalog.Table
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
-import org.apache.spark.sql.execution.datasources.{FilePartition, LogicalRelation, PartitionedFile, Spark3ParsePartitionUtil, SparkParsePartitionUtil}
+import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.hudi.SparkAdapter
import org.apache.spark.sql.internal.SQLConf
-
-import scala.collection.JavaConverters.mapAsScalaMapConverter
+import org.apache.spark.sql.{Row, SparkSession}
/**
* The adapter for spark3.
diff --git a/hudi-spark-datasource/hudi-spark3-common/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TimeTravelRelation.scala b/hudi-spark-datasource/hudi-spark3-common/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TimeTravelRelation.scala
new file mode 100644
index 0000000000000..67f942881c39d
--- /dev/null
+++ b/hudi-spark-datasource/hudi-spark3-common/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TimeTravelRelation.scala
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.plans.logical
+
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
+
+case class TimeTravelRelation(
+ table: LogicalPlan,
+ timestamp: Option[Expression],
+ version: Option[String]) extends Command {
+ override def children: Seq[LogicalPlan] = Seq.empty
+
+ override def output: Seq[Attribute] = Nil
+
+ override lazy val resolved: Boolean = false
+
+ def withNewChildrenInternal(newChildren: IndexedSeq[LogicalPlan]): LogicalPlan = this
+}
diff --git a/hudi-spark-datasource/hudi-spark3.1.x/pom.xml b/hudi-spark-datasource/hudi-spark3.1.x/pom.xml
index f6d9f7d557216..374b1722a365d 100644
--- a/hudi-spark-datasource/hudi-spark3.1.x/pom.xml
+++ b/hudi-spark-datasource/hudi-spark3.1.x/pom.xml
@@ -144,6 +144,24 @@
<groupId>org.jacoco</groupId>
<artifactId>jacoco-maven-plugin</artifactId>
+ <plugin>
+ <groupId>org.antlr</groupId>
+ <artifactId>antlr4-maven-plugin</artifactId>
+ <version>${antlr.version}</version>
+ <executions>
+ <execution>
+ <goals>
+ <goal>antlr4</goal>
+ </goals>
+ </execution>
+ </executions>
+ <configuration>
+ <listener>true</listener>
+ <visitor>true</visitor>
+ <sourceDirectory>../hudi-spark3.1.x/src/main/antlr4</sourceDirectory>
+ <libDirectory>../hudi-spark3.1.x/src/main/antlr4/imports</libDirectory>
+ </configuration>
+ </plugin>
@@ -157,7 +175,7 @@
<groupId>org.apache.spark</groupId>
<artifactId>spark-sql_2.12</artifactId>
- <version>${spark3.version}</version>
+ <version>${spark3.1.version}</version>
<optional>true</optional>
diff --git a/hudi-spark-datasource/hudi-spark3.1.x/src/main/antlr4/imports/SqlBase.g4 b/hudi-spark-datasource/hudi-spark3.1.x/src/main/antlr4/imports/SqlBase.g4
new file mode 100644
index 0000000000000..500f66c9e53f9
--- /dev/null
+++ b/hudi-spark-datasource/hudi-spark3.1.x/src/main/antlr4/imports/SqlBase.g4
@@ -0,0 +1,1852 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ * This file is an adaptation of Presto's presto-parser/src/main/antlr4/com/facebook/presto/sql/parser/SqlBase.g4 grammar.
+ */
+
+// The parser file is forked from spark 3.1.2's SqlBase.g4.
+grammar SqlBase;
+
+@parser::members {
+ /**
+ * When false, INTERSECT is given the greater precedence over the other set
+ * operations (UNION, EXCEPT and MINUS) as per the SQL standard.
+ */
+ public boolean legacy_setops_precedence_enbled = false;
+
+ /**
+ * When false, a literal with an exponent would be converted into
+ * double type rather than decimal type.
+ */
+ public boolean legacy_exponent_literal_as_decimal_enabled = false;
+
+ /**
+ * When true, the behavior of keywords follows ANSI SQL standard.
+ */
+ public boolean SQL_standard_keyword_behavior = false;
+}
+
+@lexer::members {
+ /**
+ * Verify whether current token is a valid decimal token (which contains dot).
+ * Returns true if the character that follows the token is not a digit or letter or underscore.
+ *
+ * For example:
+ * For char stream "2.3", "2." is not a valid decimal token, because it is followed by digit '3'.
+ * For char stream "2.3_", "2.3" is not a valid decimal token, because it is followed by '_'.
+ * For char stream "2.3W", "2.3" is not a valid decimal token, because it is followed by 'W'.
+ * For char stream "12.0D 34.E2+0.12 " 12.0D is a valid decimal token because it is followed
+ * by a space. 34.E2 is a valid decimal token because it is followed by symbol '+'
+ * which is not a digit or letter or underscore.
+ */
+ public boolean isValidDecimal() {
+ int nextChar = _input.LA(1);
+ if (nextChar >= 'A' && nextChar <= 'Z' || nextChar >= '0' && nextChar <= '9' ||
+ nextChar == '_') {
+ return false;
+ } else {
+ return true;
+ }
+ }
+
+ /**
+ * This method will be called when we see '/*' and try to match it as a bracketed comment.
+ * If the next character is '+', it should be parsed as hint later, and we cannot match
+ * it as a bracketed comment.
+ *
+ * Returns true if the next character is '+'.
+ */
+ public boolean isHint() {
+ int nextChar = _input.LA(1);
+ if (nextChar == '+') {
+ return true;
+ } else {
+ return false;
+ }
+ }
+}
+
+singleStatement
+ : statement ';'* EOF
+ ;
+
+singleExpression
+ : namedExpression EOF
+ ;
+
+singleTableIdentifier
+ : tableIdentifier EOF
+ ;
+
+singleMultipartIdentifier
+ : multipartIdentifier EOF
+ ;
+
+singleFunctionIdentifier
+ : functionIdentifier EOF
+ ;
+
+singleDataType
+ : dataType EOF
+ ;
+
+singleTableSchema
+ : colTypeList EOF
+ ;
+
+statement
+ : query #statementDefault
+ | ctes? dmlStatementNoWith #dmlStatement
+ | USE NAMESPACE? multipartIdentifier #use
+ | CREATE namespace (IF NOT EXISTS)? multipartIdentifier
+ (commentSpec |
+ locationSpec |
+ (WITH (DBPROPERTIES | PROPERTIES) tablePropertyList))* #createNamespace
+ | ALTER namespace multipartIdentifier
+ SET (DBPROPERTIES | PROPERTIES) tablePropertyList #setNamespaceProperties
+ | ALTER namespace multipartIdentifier
+ SET locationSpec #setNamespaceLocation
+ | DROP namespace (IF EXISTS)? multipartIdentifier
+ (RESTRICT | CASCADE)? #dropNamespace
+ | SHOW (DATABASES | NAMESPACES) ((FROM | IN) multipartIdentifier)?
+ (LIKE? pattern=STRING)? #showNamespaces
+ | createTableHeader ('(' colTypeList ')')? tableProvider?
+ createTableClauses
+ (AS? query)? #createTable
+ | CREATE TABLE (IF NOT EXISTS)? target=tableIdentifier
+ LIKE source=tableIdentifier
+ (tableProvider |
+ rowFormat |
+ createFileFormat |
+ locationSpec |
+ (TBLPROPERTIES tableProps=tablePropertyList))* #createTableLike
+ | replaceTableHeader ('(' colTypeList ')')? tableProvider?
+ createTableClauses
+ (AS? query)? #replaceTable
+ | ANALYZE TABLE multipartIdentifier partitionSpec? COMPUTE STATISTICS
+ (identifier | FOR COLUMNS identifierSeq | FOR ALL COLUMNS)? #analyze
+ | ALTER TABLE multipartIdentifier
+ ADD (COLUMN | COLUMNS)
+ columns=qualifiedColTypeWithPositionList #addTableColumns
+ | ALTER TABLE multipartIdentifier
+ ADD (COLUMN | COLUMNS)
+ '(' columns=qualifiedColTypeWithPositionList ')' #addTableColumns
+ | ALTER TABLE table=multipartIdentifier
+ RENAME COLUMN
+ from=multipartIdentifier TO to=errorCapturingIdentifier #renameTableColumn
+ | ALTER TABLE multipartIdentifier
+ DROP (COLUMN | COLUMNS)
+ '(' columns=multipartIdentifierList ')' #dropTableColumns
+ | ALTER TABLE multipartIdentifier
+ DROP (COLUMN | COLUMNS) columns=multipartIdentifierList #dropTableColumns
+ | ALTER (TABLE | VIEW) from=multipartIdentifier
+ RENAME TO to=multipartIdentifier #renameTable
+ | ALTER (TABLE | VIEW) multipartIdentifier
+ SET TBLPROPERTIES tablePropertyList #setTableProperties
+ | ALTER (TABLE | VIEW) multipartIdentifier
+ UNSET TBLPROPERTIES (IF EXISTS)? tablePropertyList #unsetTableProperties
+ | ALTER TABLE table=multipartIdentifier
+ (ALTER | CHANGE) COLUMN? column=multipartIdentifier
+ alterColumnAction? #alterTableAlterColumn
+ | ALTER TABLE table=multipartIdentifier partitionSpec?
+ CHANGE COLUMN?
+ colName=multipartIdentifier colType colPosition? #hiveChangeColumn
+ | ALTER TABLE table=multipartIdentifier partitionSpec?
+ REPLACE COLUMNS
+ '(' columns=qualifiedColTypeWithPositionList ')' #hiveReplaceColumns
+ | ALTER TABLE multipartIdentifier (partitionSpec)?
+ SET SERDE STRING (WITH SERDEPROPERTIES tablePropertyList)? #setTableSerDe
+ | ALTER TABLE multipartIdentifier (partitionSpec)?
+ SET SERDEPROPERTIES tablePropertyList #setTableSerDe
+ | ALTER (TABLE | VIEW) multipartIdentifier ADD (IF NOT EXISTS)?
+ partitionSpecLocation+ #addTablePartition
+ | ALTER TABLE multipartIdentifier
+ from=partitionSpec RENAME TO to=partitionSpec #renameTablePartition
+ | ALTER (TABLE | VIEW) multipartIdentifier
+ DROP (IF EXISTS)? partitionSpec (',' partitionSpec)* PURGE? #dropTablePartitions
+ | ALTER TABLE multipartIdentifier
+ (partitionSpec)? SET locationSpec #setTableLocation
+ | ALTER TABLE multipartIdentifier RECOVER PARTITIONS #recoverPartitions
+ | DROP TABLE (IF EXISTS)? multipartIdentifier PURGE? #dropTable
+ | DROP VIEW (IF EXISTS)? multipartIdentifier #dropView
+ | CREATE (OR REPLACE)? (GLOBAL? TEMPORARY)?
+ VIEW (IF NOT EXISTS)? multipartIdentifier
+ identifierCommentList?
+ (commentSpec |
+ (PARTITIONED ON identifierList) |
+ (TBLPROPERTIES tablePropertyList))*
+ AS query #createView
+ | CREATE (OR REPLACE)? GLOBAL? TEMPORARY VIEW
+ tableIdentifier ('(' colTypeList ')')? tableProvider
+ (OPTIONS tablePropertyList)? #createTempViewUsing
+ | ALTER VIEW multipartIdentifier AS? query #alterViewQuery
+ | CREATE (OR REPLACE)? TEMPORARY? FUNCTION (IF NOT EXISTS)?
+ multipartIdentifier AS className=STRING
+ (USING resource (',' resource)*)? #createFunction
+ | DROP TEMPORARY? FUNCTION (IF EXISTS)? multipartIdentifier #dropFunction
+ | EXPLAIN (LOGICAL | FORMATTED | EXTENDED | CODEGEN | COST)?
+ statement #explain
+ | SHOW TABLES ((FROM | IN) multipartIdentifier)?
+ (LIKE? pattern=STRING)? #showTables
+ | SHOW TABLE EXTENDED ((FROM | IN) ns=multipartIdentifier)?
+ LIKE pattern=STRING partitionSpec? #showTable
+ | SHOW TBLPROPERTIES table=multipartIdentifier
+ ('(' key=tablePropertyKey ')')? #showTblProperties
+ | SHOW COLUMNS (FROM | IN) table=multipartIdentifier
+ ((FROM | IN) ns=multipartIdentifier)? #showColumns
+ | SHOW VIEWS ((FROM | IN) multipartIdentifier)?
+ (LIKE? pattern=STRING)? #showViews
+ | SHOW PARTITIONS multipartIdentifier partitionSpec? #showPartitions
+ | SHOW identifier? FUNCTIONS
+ (LIKE? (multipartIdentifier | pattern=STRING))? #showFunctions
+ | SHOW CREATE TABLE multipartIdentifier (AS SERDE)? #showCreateTable
+ | SHOW CURRENT NAMESPACE #showCurrentNamespace
+ | (DESC | DESCRIBE) FUNCTION EXTENDED? describeFuncName #describeFunction
+ | (DESC | DESCRIBE) namespace EXTENDED?
+ multipartIdentifier #describeNamespace
+ | (DESC | DESCRIBE) TABLE? option=(EXTENDED | FORMATTED)?
+ multipartIdentifier partitionSpec? describeColName? #describeRelation
+ | (DESC | DESCRIBE) QUERY? query #describeQuery
+ | COMMENT ON namespace multipartIdentifier IS
+ comment=(STRING | NULL) #commentNamespace
+ | COMMENT ON TABLE multipartIdentifier IS comment=(STRING | NULL) #commentTable
+ | REFRESH TABLE multipartIdentifier #refreshTable
+ | REFRESH FUNCTION multipartIdentifier #refreshFunction
+ | REFRESH (STRING | .*?) #refreshResource
+ | CACHE LAZY? TABLE multipartIdentifier
+ (OPTIONS options=tablePropertyList)? (AS? query)? #cacheTable
+ | UNCACHE TABLE (IF EXISTS)? multipartIdentifier #uncacheTable
+ | CLEAR CACHE #clearCache
+ | LOAD DATA LOCAL? INPATH path=STRING OVERWRITE? INTO TABLE
+ multipartIdentifier partitionSpec? #loadData
+ | TRUNCATE TABLE multipartIdentifier partitionSpec? #truncateTable
+ | MSCK REPAIR TABLE multipartIdentifier #repairTable
+ | op=(ADD | LIST) identifier (STRING | .*?) #manageResource
+ | SET ROLE .*? #failNativeCommand
+ | SET TIME ZONE interval #setTimeZone
+ | SET TIME ZONE timezone=(STRING | LOCAL) #setTimeZone
+ | SET TIME ZONE .*? #setTimeZone
+ | SET configKey EQ configValue #setQuotedConfiguration
+ | SET configKey (EQ .*?)? #setQuotedConfiguration
+ | SET .*? EQ configValue #setQuotedConfiguration
+ | SET .*? #setConfiguration
+ | RESET configKey #resetQuotedConfiguration
+ | RESET .*? #resetConfiguration
+ | unsupportedHiveNativeCommands .*? #failNativeCommand
+ ;
+
+configKey
+ : quotedIdentifier
+ ;
+
+configValue
+ : quotedIdentifier
+ ;
+
+unsupportedHiveNativeCommands
+ : kw1=CREATE kw2=ROLE
+ | kw1=DROP kw2=ROLE
+ | kw1=GRANT kw2=ROLE?
+ | kw1=REVOKE kw2=ROLE?
+ | kw1=SHOW kw2=GRANT
+ | kw1=SHOW kw2=ROLE kw3=GRANT?
+ | kw1=SHOW kw2=PRINCIPALS
+ | kw1=SHOW kw2=ROLES
+ | kw1=SHOW kw2=CURRENT kw3=ROLES
+ | kw1=EXPORT kw2=TABLE
+ | kw1=IMPORT kw2=TABLE
+ | kw1=SHOW kw2=COMPACTIONS
+ | kw1=SHOW kw2=CREATE kw3=TABLE
+ | kw1=SHOW kw2=TRANSACTIONS
+ | kw1=SHOW kw2=INDEXES
+ | kw1=SHOW kw2=LOCKS
+ | kw1=CREATE kw2=INDEX
+ | kw1=DROP kw2=INDEX
+ | kw1=ALTER kw2=INDEX
+ | kw1=LOCK kw2=TABLE
+ | kw1=LOCK kw2=DATABASE
+ | kw1=UNLOCK kw2=TABLE
+ | kw1=UNLOCK kw2=DATABASE
+ | kw1=CREATE kw2=TEMPORARY kw3=MACRO
+ | kw1=DROP kw2=TEMPORARY kw3=MACRO
+ | kw1=ALTER kw2=TABLE tableIdentifier kw3=NOT kw4=CLUSTERED
+ | kw1=ALTER kw2=TABLE tableIdentifier kw3=CLUSTERED kw4=BY
+ | kw1=ALTER kw2=TABLE tableIdentifier kw3=NOT kw4=SORTED
+ | kw1=ALTER kw2=TABLE tableIdentifier kw3=SKEWED kw4=BY
+ | kw1=ALTER kw2=TABLE tableIdentifier kw3=NOT kw4=SKEWED
+ | kw1=ALTER kw2=TABLE tableIdentifier kw3=NOT kw4=STORED kw5=AS kw6=DIRECTORIES
+ | kw1=ALTER kw2=TABLE tableIdentifier kw3=SET kw4=SKEWED kw5=LOCATION
+ | kw1=ALTER kw2=TABLE tableIdentifier kw3=EXCHANGE kw4=PARTITION
+ | kw1=ALTER kw2=TABLE tableIdentifier kw3=ARCHIVE kw4=PARTITION
+ | kw1=ALTER kw2=TABLE tableIdentifier kw3=UNARCHIVE kw4=PARTITION
+ | kw1=ALTER kw2=TABLE tableIdentifier kw3=TOUCH
+ | kw1=ALTER kw2=TABLE tableIdentifier partitionSpec? kw3=COMPACT
+ | kw1=ALTER kw2=TABLE tableIdentifier partitionSpec? kw3=CONCATENATE
+ | kw1=ALTER kw2=TABLE tableIdentifier partitionSpec? kw3=SET kw4=FILEFORMAT
+ | kw1=ALTER kw2=TABLE tableIdentifier partitionSpec? kw3=REPLACE kw4=COLUMNS
+ | kw1=START kw2=TRANSACTION
+ | kw1=COMMIT
+ | kw1=ROLLBACK
+ | kw1=DFS
+ ;
+
+createTableHeader
+ : CREATE TEMPORARY? EXTERNAL? TABLE (IF NOT EXISTS)? multipartIdentifier
+ ;
+
+replaceTableHeader
+ : (CREATE OR)? REPLACE TABLE multipartIdentifier
+ ;
+
+bucketSpec
+ : CLUSTERED BY identifierList
+ (SORTED BY orderedIdentifierList)?
+ INTO INTEGER_VALUE BUCKETS
+ ;
+
+skewSpec
+ : SKEWED BY identifierList
+ ON (constantList | nestedConstantList)
+ (STORED AS DIRECTORIES)?
+ ;
+
+locationSpec
+ : LOCATION STRING
+ ;
+
+commentSpec
+ : COMMENT STRING
+ ;
+
+query
+ : ctes? queryTerm queryOrganization
+ ;
+
+insertInto
+ : INSERT OVERWRITE TABLE? multipartIdentifier (partitionSpec (IF NOT EXISTS)?)? identifierList? #insertOverwriteTable
+ | INSERT INTO TABLE? multipartIdentifier partitionSpec? (IF NOT EXISTS)? identifierList? #insertIntoTable
+ | INSERT OVERWRITE LOCAL? DIRECTORY path=STRING rowFormat? createFileFormat? #insertOverwriteHiveDir
+ | INSERT OVERWRITE LOCAL? DIRECTORY (path=STRING)? tableProvider (OPTIONS options=tablePropertyList)? #insertOverwriteDir
+ ;
+
+partitionSpecLocation
+ : partitionSpec locationSpec?
+ ;
+
+partitionSpec
+ : PARTITION '(' partitionVal (',' partitionVal)* ')'
+ ;
+
+partitionVal
+ : identifier (EQ constant)?
+ ;
+
+namespace
+ : NAMESPACE
+ | DATABASE
+ | SCHEMA
+ ;
+
+describeFuncName
+ : qualifiedName
+ | STRING
+ | comparisonOperator
+ | arithmeticOperator
+ | predicateOperator
+ ;
+
+describeColName
+ : nameParts+=identifier ('.' nameParts+=identifier)*
+ ;
+
+ctes
+ : WITH namedQuery (',' namedQuery)*
+ ;
+
+namedQuery
+ : name=errorCapturingIdentifier (columnAliases=identifierList)? AS? '(' query ')'
+ ;
+
+tableProvider
+ : USING multipartIdentifier
+ ;
+
+createTableClauses
+ :((OPTIONS options=tablePropertyList) |
+ (PARTITIONED BY partitioning=partitionFieldList) |
+ skewSpec |
+ bucketSpec |
+ rowFormat |
+ createFileFormat |
+ locationSpec |
+ commentSpec |
+ (TBLPROPERTIES tableProps=tablePropertyList))*
+ ;
+
+tablePropertyList
+ : '(' tableProperty (',' tableProperty)* ')'
+ ;
+
+tableProperty
+ : key=tablePropertyKey (EQ? value=tablePropertyValue)?
+ ;
+
+tablePropertyKey
+ : identifier ('.' identifier)*
+ | STRING
+ ;
+
+tablePropertyValue
+ : INTEGER_VALUE
+ | DECIMAL_VALUE
+ | booleanValue
+ | STRING
+ ;
+
+constantList
+ : '(' constant (',' constant)* ')'
+ ;
+
+nestedConstantList
+ : '(' constantList (',' constantList)* ')'
+ ;
+
+createFileFormat
+ : STORED AS fileFormat
+ | STORED BY storageHandler
+ ;
+
+fileFormat
+ : INPUTFORMAT inFmt=STRING OUTPUTFORMAT outFmt=STRING #tableFileFormat
+ | identifier #genericFileFormat
+ ;
+
+storageHandler
+ : STRING (WITH SERDEPROPERTIES tablePropertyList)?
+ ;
+
+resource
+ : identifier STRING
+ ;
+
+dmlStatementNoWith
+ : insertInto queryTerm queryOrganization #singleInsertQuery
+ | fromClause multiInsertQueryBody+ #multiInsertQuery
+ | DELETE FROM multipartIdentifier tableAlias whereClause? #deleteFromTable
+ | UPDATE multipartIdentifier tableAlias setClause whereClause? #updateTable
+ | MERGE INTO target=multipartIdentifier targetAlias=tableAlias
+ USING (source=multipartIdentifier |
+ '(' sourceQuery=query')') sourceAlias=tableAlias
+ ON mergeCondition=booleanExpression
+ matchedClause*
+ notMatchedClause* #mergeIntoTable
+ ;
+
+queryOrganization
+ : (ORDER BY order+=sortItem (',' order+=sortItem)*)?
+ (CLUSTER BY clusterBy+=expression (',' clusterBy+=expression)*)?
+ (DISTRIBUTE BY distributeBy+=expression (',' distributeBy+=expression)*)?
+ (SORT BY sort+=sortItem (',' sort+=sortItem)*)?
+ windowClause?
+ (LIMIT (ALL | limit=expression))?
+ ;
+
+multiInsertQueryBody
+ : insertInto fromStatementBody
+ ;
+
+queryTerm
+ : queryPrimary #queryTermDefault
+ | left=queryTerm {legacy_setops_precedence_enbled}?
+ operator=(INTERSECT | UNION | EXCEPT | SETMINUS) setQuantifier? right=queryTerm #setOperation
+ | left=queryTerm {!legacy_setops_precedence_enbled}?
+ operator=INTERSECT setQuantifier? right=queryTerm #setOperation
+ | left=queryTerm {!legacy_setops_precedence_enbled}?
+ operator=(UNION | EXCEPT | SETMINUS) setQuantifier? right=queryTerm #setOperation
+ ;
+
+queryPrimary
+ : querySpecification #queryPrimaryDefault
+ | fromStatement #fromStmt
+ | TABLE multipartIdentifier #table
+ | inlineTable #inlineTableDefault1
+ | '(' query ')' #subquery
+ ;
+
+sortItem
+ : expression ordering=(ASC | DESC)? (NULLS nullOrder=(LAST | FIRST))?
+ ;
+
+fromStatement
+ : fromClause fromStatementBody+
+ ;
+
+fromStatementBody
+ : transformClause
+ whereClause?
+ queryOrganization
+ | selectClause
+ lateralView*
+ whereClause?
+ aggregationClause?
+ havingClause?
+ windowClause?
+ queryOrganization
+ ;
+
+querySpecification
+ : transformClause
+ fromClause?
+ whereClause? #transformQuerySpecification
+ | selectClause
+ fromClause?
+ lateralView*
+ whereClause?
+ aggregationClause?
+ havingClause?
+ windowClause? #regularQuerySpecification
+ ;
+
+transformClause
+ : (SELECT kind=TRANSFORM '(' namedExpressionSeq ')'
+ | kind=MAP namedExpressionSeq
+ | kind=REDUCE namedExpressionSeq)
+ inRowFormat=rowFormat?
+ (RECORDWRITER recordWriter=STRING)?
+ USING script=STRING
+ (AS (identifierSeq | colTypeList | ('(' (identifierSeq | colTypeList) ')')))?
+ outRowFormat=rowFormat?
+ (RECORDREADER recordReader=STRING)?
+ ;
+
+selectClause
+ : SELECT (hints+=hint)* setQuantifier? namedExpressionSeq
+ ;
+
+setClause
+ : SET assignmentList
+ ;
+
+matchedClause
+ : WHEN MATCHED (AND matchedCond=booleanExpression)? THEN matchedAction
+ ;
+notMatchedClause
+ : WHEN NOT MATCHED (AND notMatchedCond=booleanExpression)? THEN notMatchedAction
+ ;
+
+matchedAction
+ : DELETE
+ | UPDATE SET ASTERISK
+ | UPDATE SET assignmentList
+ ;
+
+notMatchedAction
+ : INSERT ASTERISK
+ | INSERT '(' columns=multipartIdentifierList ')'
+ VALUES '(' expression (',' expression)* ')'
+ ;
+
+assignmentList
+ : assignment (',' assignment)*
+ ;
+
+assignment
+ : key=multipartIdentifier EQ value=expression
+ ;
+
+whereClause
+ : WHERE booleanExpression
+ ;
+
+havingClause
+ : HAVING booleanExpression
+ ;
+
+hint
+ : '/*+' hintStatements+=hintStatement (','? hintStatements+=hintStatement)* '*/'
+ ;
+
+hintStatement
+ : hintName=identifier
+ | hintName=identifier '(' parameters+=primaryExpression (',' parameters+=primaryExpression)* ')'
+ ;
+
+fromClause
+ : FROM relation (',' relation)* lateralView* pivotClause?
+ ;
+
+temporalClause
+ : FOR? (SYSTEM_TIME | TIMESTAMP) AS OF timestamp=valueExpression
+ | FOR? (SYSTEM_VERSION | VERSION) AS OF version=(INTEGER_VALUE | STRING)
+ ;
+
+aggregationClause
+ : GROUP BY groupingExpressions+=expression (',' groupingExpressions+=expression)* (
+ WITH kind=ROLLUP
+ | WITH kind=CUBE
+ | kind=GROUPING SETS '(' groupingSet (',' groupingSet)* ')')?
+ | GROUP BY kind=GROUPING SETS '(' groupingSet (',' groupingSet)* ')'
+ ;
+
+groupingSet
+ : '(' (expression (',' expression)*)? ')'
+ | expression
+ ;
+
+pivotClause
+ : PIVOT '(' aggregates=namedExpressionSeq FOR pivotColumn IN '(' pivotValues+=pivotValue (',' pivotValues+=pivotValue)* ')' ')'
+ ;
+
+pivotColumn
+ : identifiers+=identifier
+ | '(' identifiers+=identifier (',' identifiers+=identifier)* ')'
+ ;
+
+pivotValue
+ : expression (AS? identifier)?
+ ;
+
+lateralView
+ : LATERAL VIEW (OUTER)? qualifiedName '(' (expression (',' expression)*)? ')' tblName=identifier (AS? colName+=identifier (',' colName+=identifier)*)?
+ ;
+
+setQuantifier
+ : DISTINCT
+ | ALL
+ ;
+
+relation
+ : relationPrimary joinRelation*
+ ;
+
+joinRelation
+ : (joinType) JOIN right=relationPrimary joinCriteria?
+ | NATURAL joinType JOIN right=relationPrimary
+ ;
+
+joinType
+ : INNER?
+ | CROSS
+ | LEFT OUTER?
+ | LEFT? SEMI
+ | RIGHT OUTER?
+ | FULL OUTER?
+ | LEFT? ANTI
+ ;
+
+joinCriteria
+ : ON booleanExpression
+ | USING identifierList
+ ;
+
+sample
+ : TABLESAMPLE '(' sampleMethod? ')'
+ ;
+
+sampleMethod
+ : negativeSign=MINUS? percentage=(INTEGER_VALUE | DECIMAL_VALUE) PERCENTLIT #sampleByPercentile
+ | expression ROWS #sampleByRows
+ | sampleType=BUCKET numerator=INTEGER_VALUE OUT OF denominator=INTEGER_VALUE
+ (ON (identifier | qualifiedName '(' ')'))? #sampleByBucket
+ | bytes=expression #sampleByBytes
+ ;
+
+identifierList
+ : '(' identifierSeq ')'
+ ;
+
+identifierSeq
+ : ident+=errorCapturingIdentifier (',' ident+=errorCapturingIdentifier)*
+ ;
+
+orderedIdentifierList
+ : '(' orderedIdentifier (',' orderedIdentifier)* ')'
+ ;
+
+orderedIdentifier
+ : ident=errorCapturingIdentifier ordering=(ASC | DESC)?
+ ;
+
+identifierCommentList
+ : '(' identifierComment (',' identifierComment)* ')'
+ ;
+
+identifierComment
+ : identifier commentSpec?
+ ;
+
+relationPrimary
+ : multipartIdentifier temporalClause?
+ sample? tableAlias #tableName
+ | '(' query ')' sample? tableAlias #aliasedQuery
+ | '(' relation ')' sample? tableAlias #aliasedRelation
+ | inlineTable #inlineTableDefault2
+ | functionTable #tableValuedFunction
+ ;
+
+inlineTable
+ : VALUES expression (',' expression)* tableAlias
+ ;
+
+functionTable
+ : funcName=errorCapturingIdentifier '(' (expression (',' expression)*)? ')' tableAlias
+ ;
+
+tableAlias
+ : (AS? strictIdentifier identifierList?)?
+ ;
+
+rowFormat
+ : ROW FORMAT SERDE name=STRING (WITH SERDEPROPERTIES props=tablePropertyList)? #rowFormatSerde
+ | ROW FORMAT DELIMITED
+ (FIELDS TERMINATED BY fieldsTerminatedBy=STRING (ESCAPED BY escapedBy=STRING)?)?
+ (COLLECTION ITEMS TERMINATED BY collectionItemsTerminatedBy=STRING)?
+ (MAP KEYS TERMINATED BY keysTerminatedBy=STRING)?
+ (LINES TERMINATED BY linesSeparatedBy=STRING)?
+ (NULL DEFINED AS nullDefinedAs=STRING)? #rowFormatDelimited
+ ;
+
+multipartIdentifierList
+ : multipartIdentifier (',' multipartIdentifier)*
+ ;
+
+multipartIdentifier
+ : parts+=errorCapturingIdentifier ('.' parts+=errorCapturingIdentifier)*
+ ;
+
+tableIdentifier
+ : (db=errorCapturingIdentifier '.')? table=errorCapturingIdentifier
+ ;
+
+functionIdentifier
+ : (db=errorCapturingIdentifier '.')? function=errorCapturingIdentifier
+ ;
+
+namedExpression
+ : expression (AS? (name=errorCapturingIdentifier | identifierList))?
+ ;
+
+namedExpressionSeq
+ : namedExpression (',' namedExpression)*
+ ;
+
+partitionFieldList
+ : '(' fields+=partitionField (',' fields+=partitionField)* ')'
+ ;
+
+partitionField
+ : transform #partitionTransform
+ | colType #partitionColumn
+ ;
+
+transform
+ : qualifiedName #identityTransform
+ | transformName=identifier
+ '(' argument+=transformArgument (',' argument+=transformArgument)* ')' #applyTransform
+ ;
+
+transformArgument
+ : qualifiedName
+ | constant
+ ;
+
+expression
+ : booleanExpression
+ ;
+
+booleanExpression
+ : NOT booleanExpression #logicalNot
+ | EXISTS '(' query ')' #exists
+ | valueExpression predicate? #predicated
+ | left=booleanExpression operator=AND right=booleanExpression #logicalBinary
+ | left=booleanExpression operator=OR right=booleanExpression #logicalBinary
+ ;
+
+predicate
+ : NOT? kind=BETWEEN lower=valueExpression AND upper=valueExpression
+ | NOT? kind=IN '(' expression (',' expression)* ')'
+ | NOT? kind=IN '(' query ')'
+ | NOT? kind=RLIKE pattern=valueExpression
+ | NOT? kind=LIKE quantifier=(ANY | SOME | ALL) ('('')' | '(' expression (',' expression)* ')')
+ | NOT? kind=LIKE pattern=valueExpression (ESCAPE escapeChar=STRING)?
+ | IS NOT? kind=NULL
+ | IS NOT? kind=(TRUE | FALSE | UNKNOWN)
+ | IS NOT? kind=DISTINCT FROM right=valueExpression
+ ;
+
+valueExpression
+ : primaryExpression #valueExpressionDefault
+ | operator=(MINUS | PLUS | TILDE) valueExpression #arithmeticUnary
+ | left=valueExpression operator=(ASTERISK | SLASH | PERCENT | DIV) right=valueExpression #arithmeticBinary
+ | left=valueExpression operator=(PLUS | MINUS | CONCAT_PIPE) right=valueExpression #arithmeticBinary
+ | left=valueExpression operator=AMPERSAND right=valueExpression #arithmeticBinary
+ | left=valueExpression operator=HAT right=valueExpression #arithmeticBinary
+ | left=valueExpression operator=PIPE right=valueExpression #arithmeticBinary
+ | left=valueExpression comparisonOperator right=valueExpression #comparison
+ ;
+
+primaryExpression
+ : name=(CURRENT_DATE | CURRENT_TIMESTAMP) #currentDatetime
+ | CASE whenClause+ (ELSE elseExpression=expression)? END #searchedCase
+ | CASE value=expression whenClause+ (ELSE elseExpression=expression)? END #simpleCase
+ | CAST '(' expression AS dataType ')' #cast
+ | STRUCT '(' (argument+=namedExpression (',' argument+=namedExpression)*)? ')' #struct
+ | FIRST '(' expression (IGNORE NULLS)? ')' #first
+ | LAST '(' expression (IGNORE NULLS)? ')' #last
+ | POSITION '(' substr=valueExpression IN str=valueExpression ')' #position
+ | constant #constantDefault
+ | ASTERISK #star
+ | qualifiedName '.' ASTERISK #star
+ | '(' namedExpression (',' namedExpression)+ ')' #rowConstructor
+ | '(' query ')' #subqueryExpression
+ | functionName '(' (setQuantifier? argument+=expression (',' argument+=expression)*)? ')'
+ (FILTER '(' WHERE where=booleanExpression ')')? (OVER windowSpec)? #functionCall
+ | identifier '->' expression #lambda
+ | '(' identifier (',' identifier)+ ')' '->' expression #lambda
+ | value=primaryExpression '[' index=valueExpression ']' #subscript
+ | identifier #columnReference
+ | base=primaryExpression '.' fieldName=identifier #dereference
+ | '(' expression ')' #parenthesizedExpression
+ | EXTRACT '(' field=identifier FROM source=valueExpression ')' #extract
+ | (SUBSTR | SUBSTRING) '(' str=valueExpression (FROM | ',') pos=valueExpression
+ ((FOR | ',') len=valueExpression)? ')' #substring
+ | TRIM '(' trimOption=(BOTH | LEADING | TRAILING)? (trimStr=valueExpression)?
+ FROM srcStr=valueExpression ')' #trim
+ | OVERLAY '(' input=valueExpression PLACING replace=valueExpression
+ FROM position=valueExpression (FOR length=valueExpression)? ')' #overlay
+ ;
+
+constant
+ : NULL #nullLiteral
+ | interval #intervalLiteral
+ | identifier STRING #typeConstructor
+ | number #numericLiteral
+ | booleanValue #booleanLiteral
+ | STRING+ #stringLiteral
+ ;
+
+comparisonOperator
+ : EQ | NEQ | NEQJ | LT | LTE | GT | GTE | NSEQ
+ ;
+
+arithmeticOperator
+ : PLUS | MINUS | ASTERISK | SLASH | PERCENT | DIV | TILDE | AMPERSAND | PIPE | CONCAT_PIPE | HAT
+ ;
+
+predicateOperator
+ : OR | AND | IN | NOT
+ ;
+
+booleanValue
+ : TRUE | FALSE
+ ;
+
+interval
+ : INTERVAL (errorCapturingMultiUnitsInterval | errorCapturingUnitToUnitInterval)?
+ ;
+
+errorCapturingMultiUnitsInterval
+ : multiUnitsInterval unitToUnitInterval?
+ ;
+
+multiUnitsInterval
+ : (intervalValue unit+=identifier)+
+ ;
+
+errorCapturingUnitToUnitInterval
+ : body=unitToUnitInterval (error1=multiUnitsInterval | error2=unitToUnitInterval)?
+ ;
+
+unitToUnitInterval
+ : value=intervalValue from=identifier TO to=identifier
+ ;
+
+intervalValue
+ : (PLUS | MINUS)? (INTEGER_VALUE | DECIMAL_VALUE)
+ | STRING
+ ;
+
+colPosition
+ : position=FIRST | position=AFTER afterCol=errorCapturingIdentifier
+ ;
+
+dataType
+ : complex=ARRAY '<' dataType '>' #complexDataType
+ | complex=MAP '<' dataType ',' dataType '>' #complexDataType
+ | complex=STRUCT ('<' complexColTypeList? '>' | NEQ) #complexDataType
+ | identifier ('(' INTEGER_VALUE (',' INTEGER_VALUE)* ')')? #primitiveDataType
+ ;
+
+qualifiedColTypeWithPositionList
+ : qualifiedColTypeWithPosition (',' qualifiedColTypeWithPosition)*
+ ;
+
+qualifiedColTypeWithPosition
+ : name=multipartIdentifier dataType (NOT NULL)? commentSpec? colPosition?
+ ;
+
+colTypeList
+ : colType (',' colType)*
+ ;
+
+colType
+ : colName=errorCapturingIdentifier dataType (NOT NULL)? commentSpec?
+ ;
+
+complexColTypeList
+ : complexColType (',' complexColType)*
+ ;
+
+complexColType
+ : identifier ':' dataType (NOT NULL)? commentSpec?
+ ;
+
+whenClause
+ : WHEN condition=expression THEN result=expression
+ ;
+
+windowClause
+ : WINDOW namedWindow (',' namedWindow)*
+ ;
+
+namedWindow
+ : name=errorCapturingIdentifier AS windowSpec
+ ;
+
+windowSpec
+ : name=errorCapturingIdentifier #windowRef
+ | '('name=errorCapturingIdentifier')' #windowRef
+ | '('
+ ( CLUSTER BY partition+=expression (',' partition+=expression)*
+ | ((PARTITION | DISTRIBUTE) BY partition+=expression (',' partition+=expression)*)?
+ ((ORDER | SORT) BY sortItem (',' sortItem)*)?)
+ windowFrame?
+ ')' #windowDef
+ ;
+
+windowFrame
+ : frameType=RANGE start=frameBound
+ | frameType=ROWS start=frameBound
+ | frameType=RANGE BETWEEN start=frameBound AND end=frameBound
+ | frameType=ROWS BETWEEN start=frameBound AND end=frameBound
+ ;
+
+frameBound
+ : UNBOUNDED boundType=(PRECEDING | FOLLOWING)
+ | boundType=CURRENT ROW
+ | expression boundType=(PRECEDING | FOLLOWING)
+ ;
+
+qualifiedNameList
+ : qualifiedName (',' qualifiedName)*
+ ;
+
+functionName
+ : qualifiedName
+ | FILTER
+ | LEFT
+ | RIGHT
+ ;
+
+qualifiedName
+ : identifier ('.' identifier)*
+ ;
+
+// This rule is used for explicitly capturing wrong identifiers such as test-table, which should actually be `test-table`.
+// Replace identifier with errorCapturingIdentifier where the immediately following symbol is not an expression; otherwise
+// valid expressions such as "a-b" can be recognized as an identifier.
+errorCapturingIdentifier
+ : identifier errorCapturingIdentifierExtra
+ ;
+
+// extra left-factoring grammar
+errorCapturingIdentifierExtra
+ : (MINUS identifier)+ #errorIdent
+ | #realIdent
+ ;
+
+identifier
+ : strictIdentifier
+ | {!SQL_standard_keyword_behavior}? strictNonReserved
+ ;
+
+strictIdentifier
+ : IDENTIFIER #unquotedIdentifier
+ | quotedIdentifier #quotedIdentifierAlternative
+ | {SQL_standard_keyword_behavior}? ansiNonReserved #unquotedIdentifier
+ | {!SQL_standard_keyword_behavior}? nonReserved #unquotedIdentifier
+ ;
+
+quotedIdentifier
+ : BACKQUOTED_IDENTIFIER
+ ;
+
+number
+ : {!legacy_exponent_literal_as_decimal_enabled}? MINUS? EXPONENT_VALUE #exponentLiteral
+ | {!legacy_exponent_literal_as_decimal_enabled}? MINUS? DECIMAL_VALUE #decimalLiteral
+ | {legacy_exponent_literal_as_decimal_enabled}? MINUS? (EXPONENT_VALUE | DECIMAL_VALUE) #legacyDecimalLiteral
+ | MINUS? INTEGER_VALUE #integerLiteral
+ | MINUS? BIGINT_LITERAL #bigIntLiteral
+ | MINUS? SMALLINT_LITERAL #smallIntLiteral
+ | MINUS? TINYINT_LITERAL #tinyIntLiteral
+ | MINUS? DOUBLE_LITERAL #doubleLiteral
+ | MINUS? FLOAT_LITERAL #floatLiteral
+ | MINUS? BIGDECIMAL_LITERAL #bigDecimalLiteral
+ ;
+
+alterColumnAction
+ : TYPE dataType
+ | commentSpec
+ | colPosition
+ | setOrDrop=(SET | DROP) NOT NULL
+ ;
+
+// When `SQL_standard_keyword_behavior=true`, there are 2 kinds of keywords in Spark SQL.
+// - Reserved keywords:
+// Keywords that are reserved and can't be used as identifiers for table, view, column,
+// function, alias, etc.
+// - Non-reserved keywords:
+// Keywords that have a special meaning only in particular contexts and can be used as
+// identifiers in other contexts. For example, `EXPLAIN SELECT ...` is a command, but EXPLAIN
+// can be used as identifiers in other places.
+// You can find the full keywords list by searching "Start of the keywords list" in this file.
+// The non-reserved keywords are listed below. Keywords not in this list are reserved keywords.
+ansiNonReserved
+//--ANSI-NON-RESERVED-START
+ : ADD
+ | AFTER
+ | ALTER
+ | ANALYZE
+ | ANTI
+ | ARCHIVE
+ | ARRAY
+ | ASC
+ | AT
+ | BETWEEN
+ | BUCKET
+ | BUCKETS
+ | BY
+ | CACHE
+ | CASCADE
+ | CHANGE
+ | CLEAR
+ | CLUSTER
+ | CLUSTERED
+ | CODEGEN
+ | COLLECTION
+ | COLUMNS
+ | COMMENT
+ | COMMIT
+ | COMPACT
+ | COMPACTIONS
+ | COMPUTE
+ | CONCATENATE
+ | COST
+ | CUBE
+ | CURRENT
+ | DATA
+ | DATABASE
+ | DATABASES
+ | DBPROPERTIES
+ | DEFINED
+ | DELETE
+ | DELIMITED
+ | DESC
+ | DESCRIBE
+ | DFS
+ | DIRECTORIES
+ | DIRECTORY
+ | DISTRIBUTE
+ | DIV
+ | DROP
+ | ESCAPED
+ | EXCHANGE
+ | EXISTS
+ | EXPLAIN
+ | EXPORT
+ | EXTENDED
+ | EXTERNAL
+ | EXTRACT
+ | FIELDS
+ | FILEFORMAT
+ | FIRST
+ | FOLLOWING
+ | FORMAT
+ | FORMATTED
+ | FUNCTION
+ | FUNCTIONS
+ | GLOBAL
+ | GROUPING
+ | IF
+ | IGNORE
+ | IMPORT
+ | INDEX
+ | INDEXES
+ | INPATH
+ | INPUTFORMAT
+ | INSERT
+ | INTERVAL
+ | ITEMS
+ | KEYS
+ | LAST
+ | LATERAL
+ | LAZY
+ | LIKE
+ | LIMIT
+ | LINES
+ | LIST
+ | LOAD
+ | LOCAL
+ | LOCATION
+ | LOCK
+ | LOCKS
+ | LOGICAL
+ | MACRO
+ | MAP
+ | MATCHED
+ | MERGE
+ | MSCK
+ | NAMESPACE
+ | NAMESPACES
+ | NO
+ | NULLS
+ | OF
+ | OPTION
+ | OPTIONS
+ | OUT
+ | OUTPUTFORMAT
+ | OVER
+ | OVERLAY
+ | OVERWRITE
+ | PARTITION
+ | PARTITIONED
+ | PARTITIONS
+ | PERCENTLIT
+ | PIVOT
+ | PLACING
+ | POSITION
+ | PRECEDING
+ | PRINCIPALS
+ | PROPERTIES
+ | PURGE
+ | QUERY
+ | RANGE
+ | RECORDREADER
+ | RECORDWRITER
+ | RECOVER
+ | REDUCE
+ | REFRESH
+ | RENAME
+ | REPAIR
+ | REPLACE
+ | RESET
+ | RESTRICT
+ | REVOKE
+ | RLIKE
+ | ROLE
+ | ROLES
+ | ROLLBACK
+ | ROLLUP
+ | ROW
+ | ROWS
+ | SCHEMA
+ | SEMI
+ | SEPARATED
+ | SERDE
+ | SERDEPROPERTIES
+ | SET
+ | SETMINUS
+ | SETS
+ | SHOW
+ | SKEWED
+ | SORT
+ | SORTED
+ | START
+ | STATISTICS
+ | STORED
+ | STRATIFY
+ | STRUCT
+ | SUBSTR
+ | SUBSTRING
+ | TABLES
+ | TABLESAMPLE
+ | TBLPROPERTIES
+ | TEMPORARY
+ | TERMINATED
+ | TOUCH
+ | TRANSACTION
+ | TRANSACTIONS
+ | TRANSFORM
+ | TRIM
+ | TRUE
+ | TRUNCATE
+ | TYPE
+ | UNARCHIVE
+ | UNBOUNDED
+ | UNCACHE
+ | UNLOCK
+ | UNSET
+ | UPDATE
+ | USE
+ | VALUES
+ | VIEW
+ | VIEWS
+ | WINDOW
+ | ZONE
+//--ANSI-NON-RESERVED-END
+ ;
+
+// When `SQL_standard_keyword_behavior=false`, there are 2 kinds of keywords in Spark SQL.
+// - Non-reserved keywords:
+// Same definition as the one when `SQL_standard_keyword_behavior=true`.
+// - Strict-non-reserved keywords:
+// A strict version of non-reserved keywords, which can not be used as table alias.
+// You can find the full keywords list by searching "Start of the keywords list" in this file.
+// The strict-non-reserved keywords are listed in `strictNonReserved`.
+// The non-reserved keywords are listed in `nonReserved`.
+// These 2 together contain all the keywords.
+strictNonReserved
+ : ANTI
+ | CROSS
+ | EXCEPT
+ | FULL
+ | INNER
+ | INTERSECT
+ | JOIN
+ | LEFT
+ | NATURAL
+ | ON
+ | RIGHT
+ | SEMI
+ | SETMINUS
+ | UNION
+ | USING
+ ;
+
+nonReserved
+//--DEFAULT-NON-RESERVED-START
+ : ADD
+ | AFTER
+ | ALL
+ | ALTER
+ | ANALYZE
+ | AND
+ | ANY
+ | ARCHIVE
+ | ARRAY
+ | AS
+ | ASC
+ | AT
+ | AUTHORIZATION
+ | BETWEEN
+ | BOTH
+ | BUCKET
+ | BUCKETS
+ | BY
+ | CACHE
+ | CASCADE
+ | CASE
+ | CAST
+ | CHANGE
+ | CHECK
+ | CLEAR
+ | CLUSTER
+ | CLUSTERED
+ | CODEGEN
+ | COLLATE
+ | COLLECTION
+ | COLUMN
+ | COLUMNS
+ | COMMENT
+ | COMMIT
+ | COMPACT
+ | COMPACTIONS
+ | COMPUTE
+ | CONCATENATE
+ | CONSTRAINT
+ | COST
+ | CREATE
+ | CUBE
+ | CURRENT
+ | CURRENT_DATE
+ | CURRENT_TIME
+ | CURRENT_TIMESTAMP
+ | CURRENT_USER
+ | DATA
+ | DATABASE
+ | DATABASES
+ | DBPROPERTIES
+ | DEFINED
+ | DELETE
+ | DELIMITED
+ | DESC
+ | DESCRIBE
+ | DFS
+ | DIRECTORIES
+ | DIRECTORY
+ | DISTINCT
+ | DISTRIBUTE
+ | DIV
+ | DROP
+ | ELSE
+ | END
+ | ESCAPE
+ | ESCAPED
+ | EXCHANGE
+ | EXISTS
+ | EXPLAIN
+ | EXPORT
+ | EXTENDED
+ | EXTERNAL
+ | EXTRACT
+ | FALSE
+ | FETCH
+ | FILTER
+ | FIELDS
+ | FILEFORMAT
+ | FIRST
+ | FOLLOWING
+ | FOR
+ | FOREIGN
+ | FORMAT
+ | FORMATTED
+ | FROM
+ | FUNCTION
+ | FUNCTIONS
+ | GLOBAL
+ | GRANT
+ | GROUP
+ | GROUPING
+ | HAVING
+ | IF
+ | IGNORE
+ | IMPORT
+ | IN
+ | INDEX
+ | INDEXES
+ | INPATH
+ | INPUTFORMAT
+ | INSERT
+ | INTERVAL
+ | INTO
+ | IS
+ | ITEMS
+ | KEYS
+ | LAST
+ | LATERAL
+ | LAZY
+ | LEADING
+ | LIKE
+ | LIMIT
+ | LINES
+ | LIST
+ | LOAD
+ | LOCAL
+ | LOCATION
+ | LOCK
+ | LOCKS
+ | LOGICAL
+ | MACRO
+ | MAP
+ | MATCHED
+ | MERGE
+ | MSCK
+ | NAMESPACE
+ | NAMESPACES
+ | NO
+ | NOT
+ | NULL
+ | NULLS
+ | OF
+ | ONLY
+ | OPTION
+ | OPTIONS
+ | OR
+ | ORDER
+ | OUT
+ | OUTER
+ | OUTPUTFORMAT
+ | OVER
+ | OVERLAPS
+ | OVERLAY
+ | OVERWRITE
+ | PARTITION
+ | PARTITIONED
+ | PARTITIONS
+ | PERCENTLIT
+ | PIVOT
+ | PLACING
+ | POSITION
+ | PRECEDING
+ | PRIMARY
+ | PRINCIPALS
+ | PROPERTIES
+ | PURGE
+ | QUERY
+ | RANGE
+ | RECORDREADER
+ | RECORDWRITER
+ | RECOVER
+ | REDUCE
+ | REFERENCES
+ | REFRESH
+ | RENAME
+ | REPAIR
+ | REPLACE
+ | RESET
+ | RESTRICT
+ | REVOKE
+ | RLIKE
+ | ROLE
+ | ROLES
+ | ROLLBACK
+ | ROLLUP
+ | ROW
+ | ROWS
+ | SCHEMA
+ | SELECT
+ | SEPARATED
+ | SERDE
+ | SERDEPROPERTIES
+ | SESSION_USER
+ | SET
+ | SETS
+ | SHOW
+ | SKEWED
+ | SOME
+ | SORT
+ | SORTED
+ | START
+ | STATISTICS
+ | STORED
+ | STRATIFY
+ | STRUCT
+ | SUBSTR
+ | SUBSTRING
+ | TABLE
+ | TABLES
+ | TABLESAMPLE
+ | TBLPROPERTIES
+ | TEMPORARY
+ | TERMINATED
+ | THEN
+ | TIME
+ | TO
+ | TOUCH
+ | TRAILING
+ | TRANSACTION
+ | TRANSACTIONS
+ | TRANSFORM
+ | TRIM
+ | TRUE
+ | TRUNCATE
+ | TYPE
+ | UNARCHIVE
+ | UNBOUNDED
+ | UNCACHE
+ | UNIQUE
+ | UNKNOWN
+ | UNLOCK
+ | UNSET
+ | UPDATE
+ | USE
+ | USER
+ | VALUES
+ | VIEW
+ | VIEWS
+ | WHEN
+ | WHERE
+ | WINDOW
+ | WITH
+ | ZONE
+ | SYSTEM_VERSION
+ | VERSION
+ | SYSTEM_TIME
+ | TIMESTAMP
+//--DEFAULT-NON-RESERVED-END
+ ;
+
+// NOTE: If you add a new token in the list below, you should update the list of keywords
+// and reserved tag in `docs/sql-ref-ansi-compliance.md#sql-keywords`.
+
+//============================
+// Start of the keywords list
+//============================
+//--SPARK-KEYWORD-LIST-START
+ADD: 'ADD';
+AFTER: 'AFTER';
+ALL: 'ALL';
+ALTER: 'ALTER';
+ANALYZE: 'ANALYZE';
+AND: 'AND';
+ANTI: 'ANTI';
+ANY: 'ANY';
+ARCHIVE: 'ARCHIVE';
+ARRAY: 'ARRAY';
+AS: 'AS';
+ASC: 'ASC';
+AT: 'AT';
+AUTHORIZATION: 'AUTHORIZATION';
+BETWEEN: 'BETWEEN';
+BOTH: 'BOTH';
+BUCKET: 'BUCKET';
+BUCKETS: 'BUCKETS';
+BY: 'BY';
+CACHE: 'CACHE';
+CASCADE: 'CASCADE';
+CASE: 'CASE';
+CAST: 'CAST';
+CHANGE: 'CHANGE';
+CHECK: 'CHECK';
+CLEAR: 'CLEAR';
+CLUSTER: 'CLUSTER';
+CLUSTERED: 'CLUSTERED';
+CODEGEN: 'CODEGEN';
+COLLATE: 'COLLATE';
+COLLECTION: 'COLLECTION';
+COLUMN: 'COLUMN';
+COLUMNS: 'COLUMNS';
+COMMENT: 'COMMENT';
+COMMIT: 'COMMIT';
+COMPACT: 'COMPACT';
+COMPACTIONS: 'COMPACTIONS';
+COMPUTE: 'COMPUTE';
+CONCATENATE: 'CONCATENATE';
+CONSTRAINT: 'CONSTRAINT';
+COST: 'COST';
+CREATE: 'CREATE';
+CROSS: 'CROSS';
+CUBE: 'CUBE';
+CURRENT: 'CURRENT';
+CURRENT_DATE: 'CURRENT_DATE';
+CURRENT_TIME: 'CURRENT_TIME';
+CURRENT_TIMESTAMP: 'CURRENT_TIMESTAMP';
+CURRENT_USER: 'CURRENT_USER';
+DATA: 'DATA';
+DATABASE: 'DATABASE';
+DATABASES: 'DATABASES' | 'SCHEMAS';
+DBPROPERTIES: 'DBPROPERTIES';
+DEFINED: 'DEFINED';
+DELETE: 'DELETE';
+DELIMITED: 'DELIMITED';
+DESC: 'DESC';
+DESCRIBE: 'DESCRIBE';
+DFS: 'DFS';
+DIRECTORIES: 'DIRECTORIES';
+DIRECTORY: 'DIRECTORY';
+DISTINCT: 'DISTINCT';
+DISTRIBUTE: 'DISTRIBUTE';
+DIV: 'DIV';
+DROP: 'DROP';
+ELSE: 'ELSE';
+END: 'END';
+ESCAPE: 'ESCAPE';
+ESCAPED: 'ESCAPED';
+EXCEPT: 'EXCEPT';
+EXCHANGE: 'EXCHANGE';
+EXISTS: 'EXISTS';
+EXPLAIN: 'EXPLAIN';
+EXPORT: 'EXPORT';
+EXTENDED: 'EXTENDED';
+EXTERNAL: 'EXTERNAL';
+EXTRACT: 'EXTRACT';
+FALSE: 'FALSE';
+FETCH: 'FETCH';
+FIELDS: 'FIELDS';
+FILTER: 'FILTER';
+FILEFORMAT: 'FILEFORMAT';
+FIRST: 'FIRST';
+FOLLOWING: 'FOLLOWING';
+FOR: 'FOR';
+FOREIGN: 'FOREIGN';
+FORMAT: 'FORMAT';
+FORMATTED: 'FORMATTED';
+FROM: 'FROM';
+FULL: 'FULL';
+FUNCTION: 'FUNCTION';
+FUNCTIONS: 'FUNCTIONS';
+GLOBAL: 'GLOBAL';
+GRANT: 'GRANT';
+GROUP: 'GROUP';
+GROUPING: 'GROUPING';
+HAVING: 'HAVING';
+IF: 'IF';
+IGNORE: 'IGNORE';
+IMPORT: 'IMPORT';
+IN: 'IN';
+INDEX: 'INDEX';
+INDEXES: 'INDEXES';
+INNER: 'INNER';
+INPATH: 'INPATH';
+INPUTFORMAT: 'INPUTFORMAT';
+INSERT: 'INSERT';
+INTERSECT: 'INTERSECT';
+INTERVAL: 'INTERVAL';
+INTO: 'INTO';
+IS: 'IS';
+ITEMS: 'ITEMS';
+JOIN: 'JOIN';
+KEYS: 'KEYS';
+LAST: 'LAST';
+LATERAL: 'LATERAL';
+LAZY: 'LAZY';
+LEADING: 'LEADING';
+LEFT: 'LEFT';
+LIKE: 'LIKE';
+LIMIT: 'LIMIT';
+LINES: 'LINES';
+LIST: 'LIST';
+LOAD: 'LOAD';
+LOCAL: 'LOCAL';
+LOCATION: 'LOCATION';
+LOCK: 'LOCK';
+LOCKS: 'LOCKS';
+LOGICAL: 'LOGICAL';
+MACRO: 'MACRO';
+MAP: 'MAP';
+MATCHED: 'MATCHED';
+MERGE: 'MERGE';
+MSCK: 'MSCK';
+NAMESPACE: 'NAMESPACE';
+NAMESPACES: 'NAMESPACES';
+NATURAL: 'NATURAL';
+NO: 'NO';
+NOT: 'NOT' | '!';
+NULL: 'NULL';
+NULLS: 'NULLS';
+OF: 'OF';
+ON: 'ON';
+ONLY: 'ONLY';
+OPTION: 'OPTION';
+OPTIONS: 'OPTIONS';
+OR: 'OR';
+ORDER: 'ORDER';
+OUT: 'OUT';
+OUTER: 'OUTER';
+OUTPUTFORMAT: 'OUTPUTFORMAT';
+OVER: 'OVER';
+OVERLAPS: 'OVERLAPS';
+OVERLAY: 'OVERLAY';
+OVERWRITE: 'OVERWRITE';
+PARTITION: 'PARTITION';
+PARTITIONED: 'PARTITIONED';
+PARTITIONS: 'PARTITIONS';
+PERCENTLIT: 'PERCENT';
+PIVOT: 'PIVOT';
+PLACING: 'PLACING';
+POSITION: 'POSITION';
+PRECEDING: 'PRECEDING';
+PRIMARY: 'PRIMARY';
+PRINCIPALS: 'PRINCIPALS';
+PROPERTIES: 'PROPERTIES';
+PURGE: 'PURGE';
+QUERY: 'QUERY';
+RANGE: 'RANGE';
+RECORDREADER: 'RECORDREADER';
+RECORDWRITER: 'RECORDWRITER';
+RECOVER: 'RECOVER';
+REDUCE: 'REDUCE';
+REFERENCES: 'REFERENCES';
+REFRESH: 'REFRESH';
+RENAME: 'RENAME';
+REPAIR: 'REPAIR';
+REPLACE: 'REPLACE';
+RESET: 'RESET';
+RESTRICT: 'RESTRICT';
+REVOKE: 'REVOKE';
+RIGHT: 'RIGHT';
+RLIKE: 'RLIKE' | 'REGEXP';
+ROLE: 'ROLE';
+ROLES: 'ROLES';
+ROLLBACK: 'ROLLBACK';
+ROLLUP: 'ROLLUP';
+ROW: 'ROW';
+ROWS: 'ROWS';
+SCHEMA: 'SCHEMA';
+SELECT: 'SELECT';
+SEMI: 'SEMI';
+SEPARATED: 'SEPARATED';
+SERDE: 'SERDE';
+SERDEPROPERTIES: 'SERDEPROPERTIES';
+SESSION_USER: 'SESSION_USER';
+SET: 'SET';
+SETMINUS: 'MINUS';
+SETS: 'SETS';
+SHOW: 'SHOW';
+SKEWED: 'SKEWED';
+SOME: 'SOME';
+SORT: 'SORT';
+SORTED: 'SORTED';
+START: 'START';
+STATISTICS: 'STATISTICS';
+STORED: 'STORED';
+STRATIFY: 'STRATIFY';
+STRUCT: 'STRUCT';
+SUBSTR: 'SUBSTR';
+SUBSTRING: 'SUBSTRING';
+TABLE: 'TABLE';
+TABLES: 'TABLES';
+TABLESAMPLE: 'TABLESAMPLE';
+TBLPROPERTIES: 'TBLPROPERTIES';
+TEMPORARY: 'TEMPORARY' | 'TEMP';
+TERMINATED: 'TERMINATED';
+THEN: 'THEN';
+TIME: 'TIME';
+TO: 'TO';
+TOUCH: 'TOUCH';
+TRAILING: 'TRAILING';
+TRANSACTION: 'TRANSACTION';
+TRANSACTIONS: 'TRANSACTIONS';
+TRANSFORM: 'TRANSFORM';
+TRIM: 'TRIM';
+TRUE: 'TRUE';
+TRUNCATE: 'TRUNCATE';
+TYPE: 'TYPE';
+UNARCHIVE: 'UNARCHIVE';
+UNBOUNDED: 'UNBOUNDED';
+UNCACHE: 'UNCACHE';
+UNION: 'UNION';
+UNIQUE: 'UNIQUE';
+UNKNOWN: 'UNKNOWN';
+UNLOCK: 'UNLOCK';
+UNSET: 'UNSET';
+UPDATE: 'UPDATE';
+USE: 'USE';
+USER: 'USER';
+USING: 'USING';
+VALUES: 'VALUES';
+VIEW: 'VIEW';
+VIEWS: 'VIEWS';
+WHEN: 'WHEN';
+WHERE: 'WHERE';
+WINDOW: 'WINDOW';
+WITH: 'WITH';
+ZONE: 'ZONE';
+
+SYSTEM_VERSION: 'SYSTEM_VERSION';
+VERSION: 'VERSION';
+SYSTEM_TIME: 'SYSTEM_TIME';
+TIMESTAMP: 'TIMESTAMP';
+//--SPARK-KEYWORD-LIST-END
+//============================
+// End of the keywords list
+//============================
+
+EQ : '=' | '==';
+NSEQ: '<=>';
+NEQ : '<>';
+NEQJ: '!=';
+LT : '<';
+LTE : '<=' | '!>';
+GT : '>';
+GTE : '>=' | '!<';
+
+PLUS: '+';
+MINUS: '-';
+ASTERISK: '*';
+SLASH: '/';
+PERCENT: '%';
+TILDE: '~';
+AMPERSAND: '&';
+PIPE: '|';
+CONCAT_PIPE: '||';
+HAT: '^';
+
+STRING
+ : '\'' ( ~('\''|'\\') | ('\\' .) )* '\''
+ | '"' ( ~('"'|'\\') | ('\\' .) )* '"'
+ ;
+
+BIGINT_LITERAL
+ : DIGIT+ 'L'
+ ;
+
+SMALLINT_LITERAL
+ : DIGIT+ 'S'
+ ;
+
+TINYINT_LITERAL
+ : DIGIT+ 'Y'
+ ;
+
+INTEGER_VALUE
+ : DIGIT+
+ ;
+
+EXPONENT_VALUE
+ : DIGIT+ EXPONENT
+ | DECIMAL_DIGITS EXPONENT {isValidDecimal()}?
+ ;
+
+DECIMAL_VALUE
+ : DECIMAL_DIGITS {isValidDecimal()}?
+ ;
+
+FLOAT_LITERAL
+ : DIGIT+ EXPONENT? 'F'
+ | DECIMAL_DIGITS EXPONENT? 'F' {isValidDecimal()}?
+ ;
+
+DOUBLE_LITERAL
+ : DIGIT+ EXPONENT? 'D'
+ | DECIMAL_DIGITS EXPONENT? 'D' {isValidDecimal()}?
+ ;
+
+BIGDECIMAL_LITERAL
+ : DIGIT+ EXPONENT? 'BD'
+ | DECIMAL_DIGITS EXPONENT? 'BD' {isValidDecimal()}?
+ ;
+
+IDENTIFIER
+ : (LETTER | DIGIT | '_')+
+ ;
+
+BACKQUOTED_IDENTIFIER
+ : '`' ( ~'`' | '``' )* '`'
+ ;
+
+fragment DECIMAL_DIGITS
+ : DIGIT+ '.' DIGIT*
+ | '.' DIGIT+
+ ;
+
+fragment EXPONENT
+ : 'E' [+-]? DIGIT+
+ ;
+
+fragment DIGIT
+ : [0-9]
+ ;
+
+fragment LETTER
+ : [A-Z]
+ ;
+
+SIMPLE_COMMENT
+ : '--' ('\\\n' | ~[\r\n])* '\r'? '\n'? -> channel(HIDDEN)
+ ;
+
+BRACKETED_COMMENT
+ : '/*' {!isHint()}? (BRACKETED_COMMENT|.)*? '*/' -> channel(HIDDEN)
+ ;
+
+WS
+ : [ \r\n\t]+ -> channel(HIDDEN)
+ ;
+
+// Catch-all for anything we can't recognize.
+// We use this to be able to ignore and recover all the text
+// when splitting statements with DelimiterLexer
+UNRECOGNIZED
+ : .
+ ;
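
The reserved/non-reserved keyword lists above behave exactly as in stock Spark, from which this grammar is forked: once SQL_standard_keyword_behavior is in effect, a keyword such as TABLE (absent from ansiNonReserved) can no longer be used as an identifier, while the default mode accepts it. A minimal sketch, assuming a plain local Spark 3.1 session and that the flag is driven by spark.sql.ansi.enabled; both are assumptions about the surrounding build, not part of this patch.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.parser.ParseException

object KeywordBehaviorSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").getOrCreate()

    // TABLE appears in nonReserved, so in the default mode it is a valid alias.
    spark.conf.set("spark.sql.ansi.enabled", "false")
    spark.sql("SELECT 1 AS table").show()

    // TABLE does not appear in ansiNonReserved, so the same query is rejected
    // once SQL_standard_keyword_behavior is in effect.
    spark.conf.set("spark.sql.ansi.enabled", "true")
    try spark.sql("SELECT 1 AS table").show()
    catch { case e: ParseException => println("rejected as reserved keyword: " + e.getMessage) }

    spark.stop()
  }
}
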
diff --git a/hudi-spark-datasource/hudi-spark3.1.x/src/main/antlr4/org/apache/hudi/spark/sql/parser/HoodieSqlBase.g4 b/hudi-spark-datasource/hudi-spark3.1.x/src/main/antlr4/org/apache/hudi/spark/sql/parser/HoodieSqlBase.g4
new file mode 100644
index 0000000000000..0ee6bee1970c1
--- /dev/null
+++ b/hudi-spark-datasource/hudi-spark3.1.x/src/main/antlr4/org/apache/hudi/spark/sql/parser/HoodieSqlBase.g4
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+grammar HoodieSqlBase;
+
+import SqlBase;
+
+singleStatement
+ : statement EOF
+ ;
+
+statement
+ : query #queryStatement
+ | ctes? dmlStatementNoWith #dmlStatement
+ | createTableHeader ('(' colTypeList ')')? tableProvider?
+ createTableClauses
+ (AS? query)? #createTable
+ | replaceTableHeader ('(' colTypeList ')')? tableProvider?
+ createTableClauses
+ (AS? query)? #replaceTable
+ | CREATE (OR REPLACE)? (GLOBAL? TEMPORARY)?
+ VIEW (IF NOT EXISTS)? multipartIdentifier
+ identifierCommentList?
+ (commentSpec |
+ (PARTITIONED ON identifierList) |
+ (TBLPROPERTIES tablePropertyList))*
+ AS query #createView
+ | ALTER VIEW multipartIdentifier AS? query #alterViewQuery
+ | (DESC | DESCRIBE) QUERY? query #describeQuery
+ | CACHE LAZY? TABLE multipartIdentifier
+ (OPTIONS options=tablePropertyList)? (AS? query)? #cacheTable
+ | .*? #passThrough
+ ;
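
HoodieSqlBase.g4 imports the forked SqlBase.g4 above; its statement rule keeps only the alternatives Hudi rewrites itself and routes everything else to #passThrough, so the delegate Spark parser still handles the rest. A minimal sketch of driving the generated parser directly over a time-travel query; the class name HoodieSqlBaseLexer is an assumption derived from the grammar name (HoodieSqlBaseParser is imported by the AST builder below), and keywords/identifiers are kept upper case because the raw lexer, unlike Spark's wrapped char stream, only matches [A-Z] letters.

import org.antlr.v4.runtime.{CharStreams, CommonTokenStream}
// Assumed ANTLR-generated classes for HoodieSqlBase.g4 (same package as the
// HoodieSqlBaseParser used by HoodieSpark3_1ExtendedSqlAstBuilder below).
import org.apache.hudi.spark.sql.parser.{HoodieSqlBaseLexer, HoodieSqlBaseParser}

object TemporalClauseParseSketch {
  def main(args: Array[String]): Unit = {
    // Table name and timestamp are illustrative; FOR SYSTEM_TIME AS OF comes from
    // the temporalClause rule in the forked SqlBase.g4 above.
    val sql = "SELECT * FROM HUDI_TBL FOR SYSTEM_TIME AS OF '2021-07-28 14:11:08'"

    val lexer = new HoodieSqlBaseLexer(CharStreams.fromString(sql))
    val parser = new HoodieSqlBaseParser(new CommonTokenStream(lexer))

    // singleStatement is the entry rule declared in HoodieSqlBase.g4.
    val tree = parser.singleStatement()
    println(tree.toStringTree(parser))
  }
}
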
diff --git a/hudi-spark-datasource/hudi-spark3.1.x/src/main/scala/org/apache/spark/sql/adapter/Spark3_1Adapter.scala b/hudi-spark-datasource/hudi-spark3.1.x/src/main/scala/org/apache/spark/sql/adapter/Spark3_1Adapter.scala
new file mode 100644
index 0000000000000..94e18d916b9f3
--- /dev/null
+++ b/hudi-spark-datasource/hudi-spark3.1.x/src/main/scala/org/apache/spark/sql/adapter/Spark3_1Adapter.scala
@@ -0,0 +1,112 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.adapter
+
+import org.apache.hudi.Spark3RowSerDe
+import org.apache.hudi.client.utils.SparkRowSerDe
+import org.apache.hudi.spark3.internal.ReflectUtil
+import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.catalyst.expressions.{Expression, Like}
+import org.apache.spark.sql.catalyst.parser.ParserInterface
+import org.apache.spark.sql.catalyst.plans.JoinType
+import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoStatement, Join, JoinHint, LogicalPlan}
+import org.apache.spark.sql.catalyst.{AliasIdentifier, TableIdentifier}
+import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
+import org.apache.spark.sql.execution.datasources.{FilePartition, PartitionedFile, Spark3ParsePartitionUtil, SparkParsePartitionUtil}
+import org.apache.spark.sql.hudi.SparkAdapter
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.parser.HoodieSpark3_1ExtendedSqlParser
+import org.apache.spark.sql.{Row, SparkSession}
+
+/**
+ * The adapter for Spark 3.1.
+ */
+class Spark3_1Adapter extends SparkAdapter {
+
+ override def createSparkRowSerDe(encoder: ExpressionEncoder[Row]): SparkRowSerDe = {
+ new Spark3RowSerDe(encoder)
+ }
+
+ override def toTableIdentifier(aliasId: AliasIdentifier): TableIdentifier = {
+ aliasId match {
+ case AliasIdentifier(name, Seq(database)) =>
+ TableIdentifier(name, Some(database))
+ case AliasIdentifier(name, Seq(_, database)) =>
+ TableIdentifier(name, Some(database))
+ case AliasIdentifier(name, Seq()) =>
+ TableIdentifier(name, None)
+ case _ => throw new IllegalArgumentException(s"Cannot cast $aliasId to TableIdentifier")
+ }
+ }
+
+ override def toTableIdentifier(relation: UnresolvedRelation): TableIdentifier = {
+ relation.multipartIdentifier.asTableIdentifier
+ }
+
+ override def createJoin(left: LogicalPlan, right: LogicalPlan, joinType: JoinType): Join = {
+ Join(left, right, joinType, None, JoinHint.NONE)
+ }
+
+ override def isInsertInto(plan: LogicalPlan): Boolean = {
+ plan.isInstanceOf[InsertIntoStatement]
+ }
+
+ override def getInsertIntoChildren(plan: LogicalPlan):
+ Option[(LogicalPlan, Map[String, Option[String]], LogicalPlan, Boolean, Boolean)] = {
+ plan match {
+ case insert: InsertIntoStatement =>
+ Some((insert.table, insert.partitionSpec, insert.query, insert.overwrite, insert.ifPartitionNotExists))
+ case _ =>
+ None
+ }
+ }
+
+ override def createInsertInto(table: LogicalPlan, partition: Map[String, Option[String]],
+ query: LogicalPlan, overwrite: Boolean, ifPartitionNotExists: Boolean): LogicalPlan = {
+ ReflectUtil.createInsertInto(table, partition, Seq.empty[String], query, overwrite, ifPartitionNotExists)
+ }
+
+ override def createSparkParsePartitionUtil(conf: SQLConf): SparkParsePartitionUtil = {
+ new Spark3ParsePartitionUtil(conf)
+ }
+
+ override def createLike(left: Expression, right: Expression): Expression = {
+ new Like(left, right)
+ }
+
+ override def parseMultipartIdentifier(parser: ParserInterface, sqlText: String): Seq[String] = {
+ parser.parseMultipartIdentifier(sqlText)
+ }
+
+ override def createExtendedSparkParser: Option[(SparkSession, ParserInterface) => ParserInterface] = {
+ Some(
+ (spark: SparkSession, delegate: ParserInterface) => new HoodieSpark3_1ExtendedSqlParser(spark, delegate)
+ )
+ }
+
+ /**
+ * Combine [[PartitionedFile]]s into [[FilePartition]]s according to `maxSplitBytes`.
+ */
+ override def getFilePartitions(
+ sparkSession: SparkSession,
+ partitionedFiles: Seq[PartitionedFile],
+ maxSplitBytes: Long): Seq[FilePartition] = {
+ FilePartition.getFilePartitions(sparkSession, partitionedFiles, maxSplitBytes)
+ }
+}
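
createExtendedSparkParser hands back a (SparkSession, ParserInterface) => ParserInterface factory rather than a parser instance, so it can be plugged into Spark's standard parser-injection hook. A minimal wiring sketch, assuming a local session and a reflective adapter lookup similar to SparkAdapterSupport; the bootstrap Hudi actually ships is not part of this hunk.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.hudi.SparkAdapter

object ParserInjectionSketch {
  def main(args: Array[String]): Unit = {
    // Reflective lookup of the adapter, mirroring SparkAdapterSupport.
    val adapter = Class.forName("org.apache.spark.sql.adapter.Spark3_1Adapter")
      .getConstructor().newInstance().asInstanceOf[SparkAdapter]

    val spark = SparkSession.builder()
      .master("local[1]")
      .withExtensions { extensions =>
        // Register the extended parser factory, if the adapter provides one.
        adapter.createExtendedSparkParser.foreach(extensions.injectParser)
      }
      .getOrCreate()

    // Statements the Hoodie grammar does not recognize fall through to the
    // delegate Spark parser via the #passThrough alternative.
    spark.sql("SELECT 1").show()
    spark.stop()
  }
}
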
diff --git a/hudi-spark-datasource/hudi-spark3.1.x/src/main/scala/org/apache/spark/sql/parser/HoodieSpark3_1ExtendedSqlAstBuilder.scala b/hudi-spark-datasource/hudi-spark3.1.x/src/main/scala/org/apache/spark/sql/parser/HoodieSpark3_1ExtendedSqlAstBuilder.scala
new file mode 100644
index 0000000000000..eefd3cf407d95
--- /dev/null
+++ b/hudi-spark-datasource/hudi-spark3.1.x/src/main/scala/org/apache/spark/sql/parser/HoodieSpark3_1ExtendedSqlAstBuilder.scala
@@ -0,0 +1,3152 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.parser
+
+import org.antlr.v4.runtime.tree.{ParseTree, RuleNode, TerminalNode}
+import org.antlr.v4.runtime.{ParserRuleContext, Token}
+import org.apache.hudi.spark.sql.parser.HoodieSqlBaseParser._
+import org.apache.hudi.spark.sql.parser.{HoodieSqlBaseBaseVisitor, HoodieSqlBaseParser}
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.analysis._
+import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, FunctionResource, FunctionResourceType}
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.{First, Last}
+import org.apache.spark.sql.catalyst.parser.ParserUtils.{EnhancedLogicalPlan, checkDuplicateClauses, checkDuplicateKeys, escapedIdentifier, operationNotAllowed, source, string, stringWithoutUnescape, validate, withOrigin}
+import org.apache.spark.sql.catalyst.parser.{ParseException, ParserInterface}
+import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.util.DateTimeUtils.{getZoneId, stringToDate, stringToTimestamp}
+import org.apache.spark.sql.catalyst.util.IntervalUtils.IntervalUnit
+import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, IntervalUtils, truncatedString}
+import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
+import org.apache.spark.sql.connector.catalog.TableChange.ColumnPosition
+import org.apache.spark.sql.connector.catalog.{SupportsNamespaces, TableCatalog}
+import org.apache.spark.sql.connector.expressions.{ApplyTransform, BucketTransform, DaysTransform, FieldReference, HoursTransform, IdentityTransform, LiteralValue, MonthsTransform, Transform, YearsTransform, Expression => V2Expression}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
+import org.apache.spark.util.random.RandomSampler
+
+import java.util.Locale
+import javax.xml.bind.DatatypeConverter
+import scala.collection.JavaConverters._
+import scala.collection.mutable.ArrayBuffer
+
+/**
+ * The AstBuilder for HoodieSqlParser, which converts the ANTLR parse tree into a logical plan.
+ * Only the extended SQL syntax (e.g. MERGE INTO) is handled here; all other SQL syntax is
+ * handled by the delegate SQL parser, i.e. the SparkSqlParser.
+ */
+class HoodieSpark3_1ExtendedSqlAstBuilder(conf: SQLConf, delegate: ParserInterface)
+ extends HoodieSqlBaseBaseVisitor[AnyRef] with Logging {
+
+ protected def typedVisit[T](ctx: ParseTree): T = {
+ ctx.accept(this).asInstanceOf[T]
+ }
+
+ /**
+ * Override the default behavior for all visit methods. This will only return a non-null result
+ * when the context has only one child. This is done because there is no generic method to
+ * combine the results of the context children. In all other cases null is returned.
+ */
+ override def visitChildren(node: RuleNode): AnyRef = {
+ if (node.getChildCount == 1) {
+ node.getChild(0).accept(this)
+ } else {
+ null
+ }
+ }
+
+ /**
+ * Create an aliased table reference. This is typically used in FROM clauses.
+ */
+ override def visitTableName(ctx: TableNameContext): LogicalPlan = withOrigin(ctx) {
+ val tableId = visitMultipartIdentifier(ctx.multipartIdentifier())
+ val relation = UnresolvedRelation(tableId)
+ val table = mayApplyAliasPlan(
+ ctx.tableAlias, relation.optionalMap(ctx.temporalClause)(withTimeTravel))
+ table.optionalMap(ctx.sample)(withSample)
+ }
+
+ private def withTimeTravel(
+ ctx: TemporalClauseContext, plan: LogicalPlan): LogicalPlan = withOrigin(ctx) {
+ val v = ctx.version
+ val version = if (ctx.INTEGER_VALUE != null) {
+ Some(v.getText)
+ } else {
+ Option(v).map(string)
+ }
+
+ val timestamp = Option(ctx.timestamp).map(expression)
+ if (timestamp.exists(_.references.nonEmpty)) {
+ throw new ParseException(
+ "timestamp expression cannot refer to any columns", ctx.timestamp)
+ }
+ if (timestamp.exists(e => SubqueryExpression.hasSubquery(e))) {
+ throw new ParseException(
+ "timestamp expression cannot contain subqueries", ctx.timestamp)
+ }
+
+ TimeTravelRelation(plan, timestamp, version)
+ }
+
+ // ============== The following code is forked from org.apache.spark.sql.catalyst.parser.AstBuilder
+ override def visitSingleStatement(ctx: SingleStatementContext): LogicalPlan = withOrigin(ctx) {
+ visit(ctx.statement).asInstanceOf[LogicalPlan]
+ }
+
+ override def visitSingleExpression(ctx: SingleExpressionContext): Expression = withOrigin(ctx) {
+ visitNamedExpression(ctx.namedExpression)
+ }
+
+ override def visitSingleTableIdentifier(
+ ctx: SingleTableIdentifierContext): TableIdentifier = withOrigin(ctx) {
+ visitTableIdentifier(ctx.tableIdentifier)
+ }
+
+ override def visitSingleFunctionIdentifier(
+ ctx: SingleFunctionIdentifierContext): FunctionIdentifier = withOrigin(ctx) {
+ visitFunctionIdentifier(ctx.functionIdentifier)
+ }
+
+ override def visitSingleMultipartIdentifier(
+ ctx: SingleMultipartIdentifierContext): Seq[String] = withOrigin(ctx) {
+ visitMultipartIdentifier(ctx.multipartIdentifier)
+ }
+
+ override def visitSingleDataType(ctx: SingleDataTypeContext): DataType = withOrigin(ctx) {
+ typedVisit[DataType](ctx.dataType)
+ }
+
+ override def visitSingleTableSchema(ctx: SingleTableSchemaContext): StructType = {
+ val schema = StructType(visitColTypeList(ctx.colTypeList))
+ withOrigin(ctx)(schema)
+ }
+
+ /* ********************************************************************************************
+ * Plan parsing
+ * ******************************************************************************************** */
+ protected def plan(tree: ParserRuleContext): LogicalPlan = typedVisit(tree)
+
+ /**
+ * Create a top-level plan with Common Table Expressions.
+ */
+ override def visitQuery(ctx: QueryContext): LogicalPlan = withOrigin(ctx) {
+ val query = plan(ctx.queryTerm).optionalMap(ctx.queryOrganization)(withQueryResultClauses)
+
+ // Apply CTEs
+ query.optionalMap(ctx.ctes)(withCTE)
+ }
+
+ override def visitDmlStatement(ctx: DmlStatementContext): AnyRef = withOrigin(ctx) {
+ val dmlStmt = plan(ctx.dmlStatementNoWith)
+ // Apply CTEs
+ dmlStmt.optionalMap(ctx.ctes)(withCTE)
+ }
+
+ private def withCTE(ctx: CtesContext, plan: LogicalPlan): LogicalPlan = {
+ val ctes = ctx.namedQuery.asScala.map { nCtx =>
+ val namedQuery = visitNamedQuery(nCtx)
+ (namedQuery.alias, namedQuery)
+ }
+ // Check for duplicate names.
+ val duplicates = ctes.groupBy(_._1).filter(_._2.size > 1).keys
+ if (duplicates.nonEmpty) {
+ throw new ParseException(
+ s"CTE definition can't have duplicate names: ${duplicates.mkString("'", "', '", "'")}.",
+ ctx)
+ }
+ With(plan, ctes.toSeq)
+ }
+
+ /**
+ * Create a logical query plan for a hive-style FROM statement body.
+ */
+ private def withFromStatementBody(
+ ctx: FromStatementBodyContext, plan: LogicalPlan): LogicalPlan = withOrigin(ctx) {
+ // two cases for transforms and selects
+ if (ctx.transformClause != null) {
+ withTransformQuerySpecification(
+ ctx,
+ ctx.transformClause,
+ ctx.whereClause,
+ plan
+ )
+ } else {
+ withSelectQuerySpecification(
+ ctx,
+ ctx.selectClause,
+ ctx.lateralView,
+ ctx.whereClause,
+ ctx.aggregationClause,
+ ctx.havingClause,
+ ctx.windowClause,
+ plan
+ )
+ }
+ }
+
+ override def visitFromStatement(ctx: FromStatementContext): LogicalPlan = withOrigin(ctx) {
+ val from = visitFromClause(ctx.fromClause)
+ val selects = ctx.fromStatementBody.asScala.map { body =>
+ withFromStatementBody(body, from).
+ // Add organization statements.
+ optionalMap(body.queryOrganization)(withQueryResultClauses)
+ }
+ // If there are multiple SELECT just UNION them together into one query.
+ if (selects.length == 1) {
+ selects.head
+ } else {
+ Union(selects.toSeq)
+ }
+ }
+
+ /**
+ * Create a named logical plan.
+ *
+ * This is only used for Common Table Expressions.
+ */
+ override def visitNamedQuery(ctx: NamedQueryContext): SubqueryAlias = withOrigin(ctx) {
+ val subQuery: LogicalPlan = plan(ctx.query).optionalMap(ctx.columnAliases)(
+ (columnAliases, plan) =>
+ UnresolvedSubqueryColumnAliases(visitIdentifierList(columnAliases), plan)
+ )
+ SubqueryAlias(ctx.name.getText, subQuery)
+ }
+
+ /**
+ * Create a logical plan which allows for multiple inserts using one 'from' statement. These
+ * queries have the following SQL form:
+ * {{{
+ * [WITH cte...]?
+ * FROM src
+ * [INSERT INTO tbl1 SELECT *]+
+ * }}}
+ * For example:
+ * {{{
+ * FROM db.tbl1 A
+ * INSERT INTO dbo.tbl1 SELECT * WHERE A.value = 10 LIMIT 5
+ * INSERT INTO dbo.tbl2 SELECT * WHERE A.value = 12
+ * }}}
+ * This (Hive) feature cannot be combined with set-operators.
+ */
+ override def visitMultiInsertQuery(ctx: MultiInsertQueryContext): LogicalPlan = withOrigin(ctx) {
+ val from = visitFromClause(ctx.fromClause)
+
+ // Build the insert clauses.
+ val inserts = ctx.multiInsertQueryBody.asScala.map { body =>
+ withInsertInto(body.insertInto,
+ withFromStatementBody(body.fromStatementBody, from).
+ optionalMap(body.fromStatementBody.queryOrganization)(withQueryResultClauses))
+ }
+
+ // If there are multiple INSERTS just UNION them together into one query.
+ if (inserts.length == 1) {
+ inserts.head
+ } else {
+ Union(inserts.toSeq)
+ }
+ }
+
+ /**
+ * Create a logical plan for a regular (single-insert) query.
+ */
+ override def visitSingleInsertQuery(
+ ctx: SingleInsertQueryContext): LogicalPlan = withOrigin(ctx) {
+ withInsertInto(
+ ctx.insertInto(),
+ plan(ctx.queryTerm).optionalMap(ctx.queryOrganization)(withQueryResultClauses))
+ }
+
+ /**
+ * Parameters used for writing a query to a table:
+ * (multipartIdentifier, tableColumnList, partitionKeys, ifPartitionNotExists).
+ */
+ type InsertTableParams = (Seq[String], Seq[String], Map[String, Option[String]], Boolean)
+
+ /**
+ * Parameters used for writing a query to a directory: (isLocal, CatalogStorageFormat, provider).
+ */
+ type InsertDirParams = (Boolean, CatalogStorageFormat, Option[String])
+
+ /**
+ * Add an
+ * {{{
+ * INSERT OVERWRITE TABLE tableIdentifier [partitionSpec [IF NOT EXISTS]]? [identifierList]
+ * INSERT INTO [TABLE] tableIdentifier [partitionSpec] [identifierList]
+ * INSERT OVERWRITE [LOCAL] DIRECTORY STRING [rowFormat] [createFileFormat]
+ * INSERT OVERWRITE [LOCAL] DIRECTORY [STRING] tableProvider [OPTIONS tablePropertyList]
+ * }}}
+ * operation to logical plan
+ */
+ private def withInsertInto(
+ ctx: InsertIntoContext,
+ query: LogicalPlan): LogicalPlan = withOrigin(ctx) {
+ ctx match {
+ case table: InsertIntoTableContext =>
+ val (tableIdent, cols, partition, ifPartitionNotExists) = visitInsertIntoTable(table)
+ InsertIntoStatement(
+ UnresolvedRelation(tableIdent),
+ partition,
+ cols,
+ query,
+ overwrite = false,
+ ifPartitionNotExists)
+ case table: InsertOverwriteTableContext =>
+ val (tableIdent, cols, partition, ifPartitionNotExists) = visitInsertOverwriteTable(table)
+ InsertIntoStatement(
+ UnresolvedRelation(tableIdent),
+ partition,
+ cols,
+ query,
+ overwrite = true,
+ ifPartitionNotExists)
+ case dir: InsertOverwriteDirContext =>
+ val (isLocal, storage, provider) = visitInsertOverwriteDir(dir)
+ InsertIntoDir(isLocal, storage, provider, query, overwrite = true)
+ case hiveDir: InsertOverwriteHiveDirContext =>
+ val (isLocal, storage, provider) = visitInsertOverwriteHiveDir(hiveDir)
+ InsertIntoDir(isLocal, storage, provider, query, overwrite = true)
+ case _ =>
+ throw new ParseException("Invalid InsertIntoContext", ctx)
+ }
+ }
+
+ /**
+ * Add an INSERT INTO TABLE operation to the logical plan.
+ */
+ override def visitInsertIntoTable(
+ ctx: InsertIntoTableContext): InsertTableParams = withOrigin(ctx) {
+ val tableIdent = visitMultipartIdentifier(ctx.multipartIdentifier)
+ val cols = Option(ctx.identifierList()).map(visitIdentifierList).getOrElse(Nil)
+ val partitionKeys = Option(ctx.partitionSpec).map(visitPartitionSpec).getOrElse(Map.empty)
+
+ if (ctx.EXISTS != null) {
+ operationNotAllowed("INSERT INTO ... IF NOT EXISTS", ctx)
+ }
+
+ (tableIdent, cols, partitionKeys, false)
+ }
+
+ /**
+ * Add an INSERT OVERWRITE TABLE operation to the logical plan.
+ */
+ override def visitInsertOverwriteTable(
+ ctx: InsertOverwriteTableContext): InsertTableParams = withOrigin(ctx) {
+ assert(ctx.OVERWRITE() != null)
+ val tableIdent = visitMultipartIdentifier(ctx.multipartIdentifier)
+ val cols = Option(ctx.identifierList()).map(visitIdentifierList).getOrElse(Nil)
+ val partitionKeys = Option(ctx.partitionSpec).map(visitPartitionSpec).getOrElse(Map.empty)
+
+ val dynamicPartitionKeys: Map[String, Option[String]] = partitionKeys.filter(_._2.isEmpty)
+ if (ctx.EXISTS != null && dynamicPartitionKeys.nonEmpty) {
+ operationNotAllowed("IF NOT EXISTS with dynamic partitions: " +
+ dynamicPartitionKeys.keys.mkString(", "), ctx)
+ }
+
+ (tableIdent, cols, partitionKeys, ctx.EXISTS() != null)
+ }
+
+ /**
+ * Write to a directory. INSERT OVERWRITE DIRECTORY is not supported by this builder,
+ * so a [[ParseException]] is thrown instead of returning an [[InsertIntoDir]] plan.
+ */
+ override def visitInsertOverwriteDir(
+ ctx: InsertOverwriteDirContext): InsertDirParams = withOrigin(ctx) {
+ throw new ParseException("INSERT OVERWRITE DIRECTORY is not supported", ctx)
+ }
+
+ /**
+ * Write to a Hive-style directory. INSERT OVERWRITE DIRECTORY is not supported by this builder,
+ * so a [[ParseException]] is thrown instead of returning an [[InsertIntoDir]] plan.
+ */
+ override def visitInsertOverwriteHiveDir(
+ ctx: InsertOverwriteHiveDirContext): InsertDirParams = withOrigin(ctx) {
+ throw new ParseException("INSERT OVERWRITE DIRECTORY is not supported", ctx)
+ }
+
+ private def getTableAliasWithoutColumnAlias(
+ ctx: TableAliasContext, op: String): Option[String] = {
+ if (ctx == null) {
+ None
+ } else {
+ val ident = ctx.strictIdentifier()
+ if (ctx.identifierList() != null) {
+ throw new ParseException(s"Columns aliases are not allowed in $op.", ctx.identifierList())
+ }
+ if (ident != null) Some(ident.getText) else None
+ }
+ }
+
+ override def visitDeleteFromTable(
+ ctx: DeleteFromTableContext): LogicalPlan = withOrigin(ctx) {
+ val table = UnresolvedRelation(visitMultipartIdentifier(ctx.multipartIdentifier()))
+ val tableAlias = getTableAliasWithoutColumnAlias(ctx.tableAlias(), "DELETE")
+ val aliasedTable = tableAlias.map(SubqueryAlias(_, table)).getOrElse(table)
+ val predicate = if (ctx.whereClause() != null) {
+ Some(expression(ctx.whereClause().booleanExpression()))
+ } else {
+ None
+ }
+ DeleteFromTable(aliasedTable, predicate)
+ }
+
+ override def visitUpdateTable(ctx: UpdateTableContext): LogicalPlan = withOrigin(ctx) {
+ val table = UnresolvedRelation(visitMultipartIdentifier(ctx.multipartIdentifier()))
+ val tableAlias = getTableAliasWithoutColumnAlias(ctx.tableAlias(), "UPDATE")
+ val aliasedTable = tableAlias.map(SubqueryAlias(_, table)).getOrElse(table)
+ val assignments = withAssignments(ctx.setClause().assignmentList())
+ val predicate = if (ctx.whereClause() != null) {
+ Some(expression(ctx.whereClause().booleanExpression()))
+ } else {
+ None
+ }
+
+ UpdateTable(aliasedTable, assignments, predicate)
+ }
+
+ private def withAssignments(assignCtx: AssignmentListContext): Seq[Assignment] =
+ withOrigin(assignCtx) {
+ assignCtx.assignment().asScala.map { assign =>
+ Assignment(UnresolvedAttribute(visitMultipartIdentifier(assign.key)),
+ expression(assign.value))
+ }.toSeq
+ }
+
+ override def visitMergeIntoTable(ctx: MergeIntoTableContext): LogicalPlan = withOrigin(ctx) {
+ val targetTable = UnresolvedRelation(visitMultipartIdentifier(ctx.target))
+ val targetTableAlias = getTableAliasWithoutColumnAlias(ctx.targetAlias, "MERGE")
+ val aliasedTarget = targetTableAlias.map(SubqueryAlias(_, targetTable)).getOrElse(targetTable)
+
+ val sourceTableOrQuery = if (ctx.source != null) {
+ UnresolvedRelation(visitMultipartIdentifier(ctx.source))
+ } else if (ctx.sourceQuery != null) {
+ visitQuery(ctx.sourceQuery)
+ } else {
+ throw new ParseException("Empty source for merge: you should specify a source" +
+ " table/subquery in merge.", ctx.source)
+ }
+ val sourceTableAlias = getTableAliasWithoutColumnAlias(ctx.sourceAlias, "MERGE")
+ val aliasedSource =
+ sourceTableAlias.map(SubqueryAlias(_, sourceTableOrQuery)).getOrElse(sourceTableOrQuery)
+
+ val mergeCondition = expression(ctx.mergeCondition)
+
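+ // Illustrative mapping of WHEN clauses (each may carry an extra condition):
+ //   WHEN MATCHED THEN DELETE          -> DeleteAction
+ //   WHEN MATCHED THEN UPDATE SET ...  -> UpdateAction
+ //   WHEN NOT MATCHED THEN INSERT ...  -> InsertAction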
+ val matchedActions = ctx.matchedClause().asScala.map {
+ clause => {
+ if (clause.matchedAction().DELETE() != null) {
+ DeleteAction(Option(clause.matchedCond).map(expression))
+ } else if (clause.matchedAction().UPDATE() != null) {
+ val condition = Option(clause.matchedCond).map(expression)
+ if (clause.matchedAction().ASTERISK() != null) {
+ UpdateAction(condition, Seq())
+ } else {
+ UpdateAction(condition, withAssignments(clause.matchedAction().assignmentList()))
+ }
+ } else {
+ // This branch should be unreachable.
+ throw new ParseException(
+ s"Unrecognized matched action: ${clause.matchedAction().getText}",
+ clause.matchedAction())
+ }
+ }
+ }
+ val notMatchedActions = ctx.notMatchedClause().asScala.map {
+ clause => {
+ if (clause.notMatchedAction().INSERT() != null) {
+ val condition = Option(clause.notMatchedCond).map(expression)
+ if (clause.notMatchedAction().ASTERISK() != null) {
+ InsertAction(condition, Seq())
+ } else {
+ val columns = clause.notMatchedAction().columns.multipartIdentifier()
+ .asScala.map(attr => UnresolvedAttribute(visitMultipartIdentifier(attr)))
+ val values = clause.notMatchedAction().expression().asScala.map(expression)
+ if (columns.size != values.size) {
+ throw new ParseException("The number of inserted values cannot match the fields.",
+ clause.notMatchedAction())
+ }
+ InsertAction(condition, columns.zip(values).map(kv => Assignment(kv._1, kv._2)).toSeq)
+ }
+ } else {
+ // This branch should be unreachable.
+ throw new ParseException(
+ s"Unrecognized not matched action: ${clause.notMatchedAction().getText}",
+ clause.notMatchedAction())
+ }
+ }
+ }
+ if (matchedActions.isEmpty && notMatchedActions.isEmpty) {
+ throw new ParseException("There must be at least one WHEN clause in a MERGE statement", ctx)
+ }
+ // children being empty means that the condition is not set
+ val matchedActionSize = matchedActions.length
+ if (matchedActionSize >= 2 && !matchedActions.init.forall(_.condition.nonEmpty)) {
+ throw new ParseException("When there are more than one MATCHED clauses in a MERGE " +
+ "statement, only the last MATCHED clause can omit the condition.", ctx)
+ }
+ val notMatchedActionSize = notMatchedActions.length
+ if (notMatchedActionSize >= 2 && !notMatchedActions.init.forall(_.condition.nonEmpty)) {
+ throw new ParseException("When there are more than one NOT MATCHED clauses in a MERGE " +
+ "statement, only the last NOT MATCHED clause can omit the condition.", ctx)
+ }
+
+ MergeIntoTable(
+ aliasedTarget,
+ aliasedSource,
+ mergeCondition,
+ matchedActions.toSeq,
+ notMatchedActions.toSeq)
+ }
+
+ /**
+ * Create a partition specification map.
+ */
+ override def visitPartitionSpec(
+ ctx: PartitionSpecContext): Map[String, Option[String]] = withOrigin(ctx) {
+ val legacyNullAsString =
+ conf.getConf(SQLConf.LEGACY_PARSE_NULL_PARTITION_SPEC_AS_STRING_LITERAL)
+ val parts = ctx.partitionVal.asScala.map { pVal =>
+ val name = pVal.identifier.getText
+ val value = Option(pVal.constant).map(v => visitStringConstant(v, legacyNullAsString))
+ name -> value
+ }
+ // Before calling `toMap`, we check for duplicated keys to avoid silently ignoring partition
+ // values in a partition spec like PARTITION(a='1', b='2', a='3'). The real semantic check for
+ // partition columns will be done in the analyzer.
+ if (conf.caseSensitiveAnalysis) {
+ checkDuplicateKeys(parts.toSeq, ctx)
+ } else {
+ checkDuplicateKeys(parts.map(kv => kv._1.toLowerCase(Locale.ROOT) -> kv._2).toSeq, ctx)
+ }
+ parts.toMap
+ }
+
+ /**
+ * Create a partition specification map without optional values.
+ */
+ protected def visitNonOptionalPartitionSpec(
+ ctx: PartitionSpecContext): Map[String, String] = withOrigin(ctx) {
+ visitPartitionSpec(ctx).map {
+ case (key, None) => throw new ParseException(s"Found an empty partition key '$key'.", ctx)
+ case (key, Some(value)) => key -> value
+ }
+ }
+
+ /**
+ * Convert a constant of any type into a string. This is typically used in DDL commands, and its
+ * main purpose is to prevent slight differences due to back-to-back conversions, i.e.:
+ * String -> Literal -> String.
+ */
+ protected def visitStringConstant(
+ ctx: ConstantContext,
+ legacyNullAsString: Boolean): String = withOrigin(ctx) {
+ ctx match {
+ case _: NullLiteralContext if !legacyNullAsString => null
+ case s: StringLiteralContext => createString(s)
+ case o => o.getText
+ }
+ }
+
+ /**
+ * Add ORDER BY/SORT BY/CLUSTER BY/DISTRIBUTE BY/LIMIT/WINDOWS clauses to the logical plan. These
+ * clauses determine the shape (ordering/partitioning/rows) of the query result.
+ */
+ private def withQueryResultClauses(
+ ctx: QueryOrganizationContext,
+ query: LogicalPlan): LogicalPlan = withOrigin(ctx) {
+ import ctx._
+
+ // Handle the ORDER BY, SORT BY, DISTRIBUTE BY, and CLUSTER BY clauses.
+ val withOrder = if (
+ !order.isEmpty && sort.isEmpty && distributeBy.isEmpty && clusterBy.isEmpty) {
+ // ORDER BY ...
+ Sort(order.asScala.map(visitSortItem).toSeq, global = true, query)
+ } else if (order.isEmpty && !sort.isEmpty && distributeBy.isEmpty && clusterBy.isEmpty) {
+ // SORT BY ...
+ Sort(sort.asScala.map(visitSortItem).toSeq, global = false, query)
+ } else if (order.isEmpty && sort.isEmpty && !distributeBy.isEmpty && clusterBy.isEmpty) {
+ // DISTRIBUTE BY ...
+ withRepartitionByExpression(ctx, expressionList(distributeBy), query)
+ } else if (order.isEmpty && !sort.isEmpty && !distributeBy.isEmpty && clusterBy.isEmpty) {
+ // SORT BY ... DISTRIBUTE BY ...
+ Sort(
+ sort.asScala.map(visitSortItem).toSeq,
+ global = false,
+ withRepartitionByExpression(ctx, expressionList(distributeBy), query))
+ } else if (order.isEmpty && sort.isEmpty && distributeBy.isEmpty && !clusterBy.isEmpty) {
+ // CLUSTER BY ...
+ val expressions = expressionList(clusterBy)
+ Sort(
+ expressions.map(SortOrder(_, Ascending)),
+ global = false,
+ withRepartitionByExpression(ctx, expressions, query))
+ } else if (order.isEmpty && sort.isEmpty && distributeBy.isEmpty && clusterBy.isEmpty) {
+ // [EMPTY]
+ query
+ } else {
+ throw new ParseException(
+ "Combination of ORDER BY/SORT BY/DISTRIBUTE BY/CLUSTER BY is not supported", ctx)
+ }
+
+ // WINDOWS
+ val withWindow = withOrder.optionalMap(windowClause)(withWindowClause)
+
+ // LIMIT
+ // - LIMIT ALL is the same as omitting the LIMIT clause
+ withWindow.optional(limit) {
+ Limit(typedVisit(limit), withWindow)
+ }
+ }
+
+ /**
+ * Create a clause for DISTRIBUTE BY.
+ */
+ protected def withRepartitionByExpression(
+ ctx: QueryOrganizationContext,
+ expressions: Seq[Expression],
+ query: LogicalPlan): LogicalPlan = {
+ throw new ParseException("DISTRIBUTE BY is not supported", ctx)
+ }
+
+ override def visitTransformQuerySpecification(
+ ctx: TransformQuerySpecificationContext): LogicalPlan = withOrigin(ctx) {
+ val from = OneRowRelation().optional(ctx.fromClause) {
+ visitFromClause(ctx.fromClause)
+ }
+ withTransformQuerySpecification(ctx, ctx.transformClause, ctx.whereClause, from)
+ }
+
+ override def visitRegularQuerySpecification(
+ ctx: RegularQuerySpecificationContext): LogicalPlan = withOrigin(ctx) {
+ val from = OneRowRelation().optional(ctx.fromClause) {
+ visitFromClause(ctx.fromClause)
+ }
+ withSelectQuerySpecification(
+ ctx,
+ ctx.selectClause,
+ ctx.lateralView,
+ ctx.whereClause,
+ ctx.aggregationClause,
+ ctx.havingClause,
+ ctx.windowClause,
+ from
+ )
+ }
+
+ override def visitNamedExpressionSeq(
+ ctx: NamedExpressionSeqContext): Seq[Expression] = {
+ Option(ctx).toSeq
+ .flatMap(_.namedExpression.asScala)
+ .map(typedVisit[Expression])
+ }
+
+ /**
+ * Create a logical plan using a having clause.
+ */
+ private def withHavingClause(
+ ctx: HavingClauseContext, plan: LogicalPlan): LogicalPlan = {
+ // Note that we add a cast to non-predicate expressions. If the expression itself is
+ // already boolean, the optimizer will get rid of the unnecessary cast.
+ val predicate = expression(ctx.booleanExpression) match {
+ case p: Predicate => p
+ case e => Cast(e, BooleanType)
+ }
+ UnresolvedHaving(predicate, plan)
+ }
+
+ /**
+ * Create a logical plan using a where clause.
+ */
+ private def withWhereClause(ctx: WhereClauseContext, plan: LogicalPlan): LogicalPlan = {
+ Filter(expression(ctx.booleanExpression), plan)
+ }
+
+ /**
+ * Add a Hive-style transform (SELECT TRANSFORM/MAP/REDUCE) query specification to a logical plan.
+ */
+ private def withTransformQuerySpecification(
+ ctx: ParserRuleContext,
+ transformClause: TransformClauseContext,
+ whereClause: WhereClauseContext,
+ relation: LogicalPlan): LogicalPlan = withOrigin(ctx) {
+ // Add where.
+ val withFilter = relation.optionalMap(whereClause)(withWhereClause)
+
+ // Create the transform.
+ val expressions = visitNamedExpressionSeq(transformClause.namedExpressionSeq)
+
+ // Create the attributes.
+ val (attributes, schemaLess) = if (transformClause.colTypeList != null) {
+ // Typed return columns.
+ (createSchema(transformClause.colTypeList).toAttributes, false)
+ } else if (transformClause.identifierSeq != null) {
+ // Untyped return columns.
+ val attrs = visitIdentifierSeq(transformClause.identifierSeq).map { name =>
+ AttributeReference(name, StringType, nullable = true)()
+ }
+ (attrs, false)
+ } else {
+ (Seq(AttributeReference("key", StringType)(),
+ AttributeReference("value", StringType)()), true)
+ }
+
+ // Create the transform.
+ ScriptTransformation(
+ expressions,
+ string(transformClause.script),
+ attributes,
+ withFilter,
+ withScriptIOSchema(
+ ctx,
+ transformClause.inRowFormat,
+ transformClause.recordWriter,
+ transformClause.outRowFormat,
+ transformClause.recordReader,
+ schemaLess
+ )
+ )
+ }
+
+ /**
+ * Add a regular (SELECT) query specification to a logical plan. The query specification
+ * is the core of the logical plan; this is where sourcing (FROM clause), projection (SELECT),
+ * aggregation (GROUP BY ... HAVING ...) and filtering (WHERE) take place.
+ *
+ * Note that query hints are ignored (both by the parser and the builder).
+ */
+ private def withSelectQuerySpecification(
+ ctx: ParserRuleContext,
+ selectClause: SelectClauseContext,
+ lateralView: java.util.List[LateralViewContext],
+ whereClause: WhereClauseContext,
+ aggregationClause: AggregationClauseContext,
+ havingClause: HavingClauseContext,
+ windowClause: WindowClauseContext,
+ relation: LogicalPlan): LogicalPlan = withOrigin(ctx) {
+ // Add lateral views.
+ val withLateralView = lateralView.asScala.foldLeft(relation)(withGenerate)
+
+ // Add where.
+ val withFilter = withLateralView.optionalMap(whereClause)(withWhereClause)
+
+ val expressions = visitNamedExpressionSeq(selectClause.namedExpressionSeq)
+ // Add aggregation or a project.
+ val namedExpressions = expressions.map {
+ case e: NamedExpression => e
+ case e: Expression => UnresolvedAlias(e)
+ }
+
+ def createProject() = if (namedExpressions.nonEmpty) {
+ Project(namedExpressions, withFilter)
+ } else {
+ withFilter
+ }
+
+ val withProject = if (aggregationClause == null && havingClause != null) {
+ if (conf.getConf(SQLConf.LEGACY_HAVING_WITHOUT_GROUP_BY_AS_WHERE)) {
+ // If the legacy conf is set, treat HAVING without GROUP BY as WHERE.
+ val predicate = expression(havingClause.booleanExpression) match {
+ case p: Predicate => p
+ case e => Cast(e, BooleanType)
+ }
+ Filter(predicate, createProject())
+ } else {
+ // According to SQL standard, HAVING without GROUP BY means global aggregate.
+ withHavingClause(havingClause, Aggregate(Nil, namedExpressions, withFilter))
+ }
+ } else if (aggregationClause != null) {
+ val aggregate = withAggregationClause(aggregationClause, namedExpressions, withFilter)
+ aggregate.optionalMap(havingClause)(withHavingClause)
+ } else {
+ // When hitting this branch, `having` must be null.
+ createProject()
+ }
+
+ // Distinct
+ val withDistinct = if (
+ selectClause.setQuantifier() != null &&
+ selectClause.setQuantifier().DISTINCT() != null) {
+ Distinct(withProject)
+ } else {
+ withProject
+ }
+
+ // Window
+ val withWindow = withDistinct.optionalMap(windowClause)(withWindowClause)
+
+ // Hint
+ selectClause.hints.asScala.foldRight(withWindow)(withHints)
+ }
+
+ /**
+ * Create a (Hive based) [[ScriptInputOutputSchema]].
+ */
+ protected def withScriptIOSchema(
+ ctx: ParserRuleContext,
+ inRowFormat: RowFormatContext,
+ recordWriter: Token,
+ outRowFormat: RowFormatContext,
+ recordReader: Token,
+ schemaLess: Boolean): ScriptInputOutputSchema = {
+ throw new ParseException("Script Transform is not supported", ctx)
+ }
+
+ /**
+ * Create a logical plan for a given 'FROM' clause. Note that we support multiple (comma
+ * separated) relations here; these get converted into a single plan by a condition-less inner join.
+ */
+ override def visitFromClause(ctx: FromClauseContext): LogicalPlan = withOrigin(ctx) {
+ val from = ctx.relation.asScala.foldLeft(null: LogicalPlan) { (left, relation) =>
+ val right = plan(relation.relationPrimary)
+ val join = right.optionalMap(left)(Join(_, _, Inner, None, JoinHint.NONE))
+ withJoinRelations(join, relation)
+ }
+ if (ctx.pivotClause() != null) {
+ if (!ctx.lateralView.isEmpty) {
+ throw new ParseException("LATERAL cannot be used together with PIVOT in FROM clause", ctx)
+ }
+ withPivot(ctx.pivotClause, from)
+ } else {
+ ctx.lateralView.asScala.foldLeft(from)(withGenerate)
+ }
+ }
+
+ /**
+ * Connect two queries by a Set operator.
+ *
+ * Supported Set operators are:
+ * - UNION [ DISTINCT | ALL ]
+ * - EXCEPT [ DISTINCT | ALL ]
+ * - MINUS [ DISTINCT | ALL ]
+ * - INTERSECT [DISTINCT | ALL]
+ */
+ override def visitSetOperation(ctx: SetOperationContext): LogicalPlan = withOrigin(ctx) {
+ val left = plan(ctx.left)
+ val right = plan(ctx.right)
+ val all = Option(ctx.setQuantifier()).exists(_.ALL != null)
+ ctx.operator.getType match {
+ case HoodieSqlBaseParser.UNION if all =>
+ Union(left, right)
+ case HoodieSqlBaseParser.UNION =>
+ Distinct(Union(left, right))
+ case HoodieSqlBaseParser.INTERSECT if all =>
+ Intersect(left, right, isAll = true)
+ case HoodieSqlBaseParser.INTERSECT =>
+ Intersect(left, right, isAll = false)
+ case HoodieSqlBaseParser.EXCEPT if all =>
+ Except(left, right, isAll = true)
+ case HoodieSqlBaseParser.EXCEPT =>
+ Except(left, right, isAll = false)
+ case HoodieSqlBaseParser.SETMINUS if all =>
+ Except(left, right, isAll = true)
+ case HoodieSqlBaseParser.SETMINUS =>
+ Except(left, right, isAll = false)
+ }
+ }
+
+ /**
+ * Add a [[WithWindowDefinition]] operator to a logical plan.
+ */
+ private def withWindowClause(
+ ctx: WindowClauseContext,
+ query: LogicalPlan): LogicalPlan = withOrigin(ctx) {
+ // Collect all window specifications defined in the WINDOW clause.
+ val baseWindowTuples = ctx.namedWindow.asScala.map {
+ wCtx =>
+ (wCtx.name.getText, typedVisit[WindowSpec](wCtx.windowSpec))
+ }
+ baseWindowTuples.groupBy(_._1).foreach { kv =>
+ if (kv._2.size > 1) {
+ throw new ParseException(s"The definition of window '${kv._1}' is repetitive", ctx)
+ }
+ }
+ val baseWindowMap = baseWindowTuples.toMap
+
+ // Handle cases like
+ // window w1 as (partition by p_mfgr order by p_name
+ // range between 2 preceding and 2 following),
+ // w2 as w1
+ val windowMapView = baseWindowMap.mapValues {
+ case WindowSpecReference(name) =>
+ baseWindowMap.get(name) match {
+ case Some(spec: WindowSpecDefinition) =>
+ spec
+ case Some(ref) =>
+ throw new ParseException(s"Window reference '$name' is not a window specification", ctx)
+ case None =>
+ throw new ParseException(s"Cannot resolve window reference '$name'", ctx)
+ }
+ case spec: WindowSpecDefinition => spec
+ }
+
+ // Note that mapValues creates a view instead of materialized map. We force materialization by
+ // mapping over identity.
+ WithWindowDefinition(windowMapView.map(identity).toMap, query)
+ }
+
+ /**
+ * Add an [[Aggregate]] or [[GroupingSets]] to a logical plan.
+ */
+ private def withAggregationClause(
+ ctx: AggregationClauseContext,
+ selectExpressions: Seq[NamedExpression],
+ query: LogicalPlan): LogicalPlan = withOrigin(ctx) {
+ val groupByExpressions = expressionList(ctx.groupingExpressions)
+
+ if (ctx.GROUPING != null) {
+ // GROUP BY .... GROUPING SETS (...)
+ val selectedGroupByExprs =
+ ctx.groupingSet.asScala.map(_.expression.asScala.map(e => expression(e)).toSeq)
+ GroupingSets(selectedGroupByExprs.toSeq, groupByExpressions, query, selectExpressions)
+ } else {
+ // GROUP BY .... (WITH CUBE | WITH ROLLUP)?
+ val mappedGroupByExpressions = if (ctx.CUBE != null) {
+ Seq(Cube(groupByExpressions))
+ } else if (ctx.ROLLUP != null) {
+ Seq(Rollup(groupByExpressions))
+ } else {
+ groupByExpressions
+ }
+ Aggregate(mappedGroupByExpressions, selectExpressions, query)
+ }
+ }
+
+ /**
+ * Add [[UnresolvedHint]]s to a logical plan.
+ */
+ private def withHints(
+ ctx: HintContext,
+ query: LogicalPlan): LogicalPlan = withOrigin(ctx) {
+ var plan = query
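+ // Illustrative example: SELECT /*+ REPARTITION(3) */ ... becomes UnresolvedHint("REPARTITION", [3], plan);
+ // the hint statements are applied in reverse so the leftmost hint ends up outermost in the plan.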
+ ctx.hintStatements.asScala.reverse.foreach { stmt =>
+ plan = UnresolvedHint(stmt.hintName.getText,
+ stmt.parameters.asScala.map(expression).toSeq, plan)
+ }
+ plan
+ }
+
+ /**
+ * Add a [[Pivot]] to a logical plan.
+ */
+ private def withPivot(
+ ctx: PivotClauseContext,
+ query: LogicalPlan): LogicalPlan = withOrigin(ctx) {
+ val aggregates = Option(ctx.aggregates).toSeq
+ .flatMap(_.namedExpression.asScala)
+ .map(typedVisit[Expression])
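+ // Illustrative example: FROM t PIVOT (sum(v) FOR k IN ('a', 'b')) yields aggregates = [sum(v)],
+ // pivotColumn = k and pivotValues = ['a', 'b'] for the Pivot node built below.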
+ val pivotColumn = if (ctx.pivotColumn.identifiers.size == 1) {
+ UnresolvedAttribute.quoted(ctx.pivotColumn.identifier.getText)
+ } else {
+ CreateStruct(
+ ctx.pivotColumn.identifiers.asScala.map(
+ identifier => UnresolvedAttribute.quoted(identifier.getText)).toSeq)
+ }
+ val pivotValues = ctx.pivotValues.asScala.map(visitPivotValue)
+ Pivot(None, pivotColumn, pivotValues.toSeq, aggregates, query)
+ }
+
+ /**
+ * Create a Pivot column value with or without an alias.
+ */
+ override def visitPivotValue(ctx: PivotValueContext): Expression = withOrigin(ctx) {
+ val e = expression(ctx.expression)
+ if (ctx.identifier != null) {
+ Alias(e, ctx.identifier.getText)()
+ } else {
+ e
+ }
+ }
+
+ /**
+ * Add a [[Generate]] (Lateral View) to a logical plan.
+ */
+ private def withGenerate(
+ query: LogicalPlan,
+ ctx: LateralViewContext): LogicalPlan = withOrigin(ctx) {
+ val expressions = expressionList(ctx.expression)
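+ // Illustrative example: LATERAL VIEW OUTER explode(items) tbl AS item maps to a Generate node with
+ // an UnresolvedGenerator for `explode`, outer = true, the lowercased alias "tbl" and column `item`.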
+ Generate(
+ UnresolvedGenerator(visitFunctionName(ctx.qualifiedName), expressions),
+ unrequiredChildIndex = Nil,
+ outer = ctx.OUTER != null,
+ // scalastyle:off caselocale
+ Some(ctx.tblName.getText.toLowerCase),
+ // scalastyle:on caselocale
+ ctx.colName.asScala.map(_.getText).map(UnresolvedAttribute.apply).toSeq,
+ query)
+ }
+
+ /**
+ * Create a single relation referenced in a FROM clause. This method is used when a part of the
+ * join condition is nested, for example:
+ * {{{
+ * select * from t1 join (t2 cross join t3) on col1 = col2
+ * }}}
+ */
+ override def visitRelation(ctx: RelationContext): LogicalPlan = withOrigin(ctx) {
+ withJoinRelations(plan(ctx.relationPrimary), ctx)
+ }
+
+ /**
+ * Join one or more [[LogicalPlan]]s to the current logical plan.
+ */
+ private def withJoinRelations(base: LogicalPlan, ctx: RelationContext): LogicalPlan = {
+ ctx.joinRelation.asScala.foldLeft(base) { (left, join) =>
+ withOrigin(join) {
+ val baseJoinType = join.joinType match {
+ case null => Inner
+ case jt if jt.CROSS != null => Cross
+ case jt if jt.FULL != null => FullOuter
+ case jt if jt.SEMI != null => LeftSemi
+ case jt if jt.ANTI != null => LeftAnti
+ case jt if jt.LEFT != null => LeftOuter
+ case jt if jt.RIGHT != null => RightOuter
+ case _ => Inner
+ }
+
+ // Resolve the join type and join condition
+ val (joinType, condition) = Option(join.joinCriteria) match {
+ case Some(c) if c.USING != null =>
+ (UsingJoin(baseJoinType, visitIdentifierList(c.identifierList)), None)
+ case Some(c) if c.booleanExpression != null =>
+ (baseJoinType, Option(expression(c.booleanExpression)))
+ case Some(c) =>
+ throw new ParseException(s"Unimplemented joinCriteria: $c", ctx)
+ case None if join.NATURAL != null =>
+ if (baseJoinType == Cross) {
+ throw new ParseException("NATURAL CROSS JOIN is not supported", ctx)
+ }
+ (NaturalJoin(baseJoinType), None)
+ case None =>
+ (baseJoinType, None)
+ }
+ Join(left, plan(join.right), joinType, condition, JoinHint.NONE)
+ }
+ }
+ }
+
+ /**
+ * Add a [[Sample]] to a logical plan.
+ *
+ * This currently supports the following sampling methods:
+ * - TABLESAMPLE(x ROWS): Sample the table down to the given number of rows.
+ * - TABLESAMPLE(x PERCENT): Sample the table down to the given percentage. Note that percentages
+ * are defined as a number between 0 and 100.
+ * - TABLESAMPLE(BUCKET x OUT OF y): Sample the table down to an 'x' divided by 'y' fraction.
+ */
+ private def withSample(ctx: SampleContext, query: LogicalPlan): LogicalPlan = withOrigin(ctx) {
+ // Create a sampled plan if we need one.
+ def sample(fraction: Double): Sample = {
+ // The range of fraction accepted by Sample is [0, 1]. Because Hive's block sampling
+ // function takes X PERCENT as the input and the range of X is [0, 100], we need to
+ // adjust the fraction.
+ val eps = RandomSampler.roundingEpsilon
+ validate(fraction >= 0.0 - eps && fraction <= 1.0 + eps,
+ s"Sampling fraction ($fraction) must be on interval [0, 1]",
+ ctx)
+ Sample(0.0, fraction, withReplacement = false, (math.random * 1000).toInt, query)
+ }
+
+ if (ctx.sampleMethod() == null) {
+ throw new ParseException("TABLESAMPLE does not accept empty inputs.", ctx)
+ }
+
+ ctx.sampleMethod() match {
+ case ctx: SampleByRowsContext =>
+ Limit(expression(ctx.expression), query)
+
+ case ctx: SampleByPercentileContext =>
+ val fraction = ctx.percentage.getText.toDouble
+ val sign = if (ctx.negativeSign == null) 1 else -1
+ sample(sign * fraction / 100.0d)
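+ // e.g. TABLESAMPLE(30 PERCENT) arrives here with percentage = 30 and is sampled with fraction 0.30.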
+
+ case ctx: SampleByBytesContext =>
+ val bytesStr = ctx.bytes.getText
+ if (bytesStr.matches("[0-9]+[bBkKmMgG]")) {
+ throw new ParseException("TABLESAMPLE(byteLengthLiteral) is not supported", ctx)
+ } else {
+ throw new ParseException(
+ bytesStr + " is not a valid byte length literal, " +
+ "expected syntax: DIGIT+ ('B' | 'K' | 'M' | 'G')", ctx)
+ }
+
+ case ctx: SampleByBucketContext if ctx.ON() != null =>
+ if (ctx.identifier != null) {
+ throw new ParseException(
+ "TABLESAMPLE(BUCKET x OUT OF y ON colname) is not supported", ctx)
+ } else {
+ throw new ParseException(
+ "TABLESAMPLE(BUCKET x OUT OF y ON function) is not supported", ctx)
+ }
+
+ case ctx: SampleByBucketContext =>
+ sample(ctx.numerator.getText.toDouble / ctx.denominator.getText.toDouble)
+ }
+ }
+
+ /**
+ * Create a logical plan for a sub-query.
+ */
+ override def visitSubquery(ctx: SubqueryContext): LogicalPlan = withOrigin(ctx) {
+ plan(ctx.query)
+ }
+
+ /**
+ * Create an un-aliased table reference. This is typically used for top-level table references,
+ * for example:
+ * {{{
+ * INSERT INTO db.tbl2
+ * TABLE db.tbl1
+ * }}}
+ */
+ override def visitTable(ctx: TableContext): LogicalPlan = withOrigin(ctx) {
+ UnresolvedRelation(visitMultipartIdentifier(ctx.multipartIdentifier))
+ }
+
+ /**
+ * Create a table-valued function call with arguments, e.g. range(1000)
+ */
+ override def visitTableValuedFunction(ctx: TableValuedFunctionContext)
+ : LogicalPlan = withOrigin(ctx) {
+ val func = ctx.functionTable
+ val aliases = if (func.tableAlias.identifierList != null) {
+ visitIdentifierList(func.tableAlias.identifierList)
+ } else {
+ Seq.empty
+ }
+
+ val tvf = UnresolvedTableValuedFunction(
+ func.funcName.getText, func.expression.asScala.map(expression).toSeq, aliases)
+ tvf.optionalMap(func.tableAlias.strictIdentifier)(aliasPlan)
+ }
+
+ /**
+ * Create an inline table (a virtual table in Hive parlance).
+ */
+ override def visitInlineTable(ctx: InlineTableContext): LogicalPlan = withOrigin(ctx) {
+ // Get the backing expressions.
+ val rows = ctx.expression.asScala.map { e =>
+ expression(e) match {
+ // inline table comes in two styles:
+ // style 1: values (1), (2), (3) -- multiple columns are supported
+ // style 2: values 1, 2, 3 -- only a single column is supported here
+ case struct: CreateNamedStruct => struct.valExprs // style 1
+ case child => Seq(child) // style 2
+ }
+ }
+
+ val aliases = if (ctx.tableAlias.identifierList != null) {
+ visitIdentifierList(ctx.tableAlias.identifierList)
+ } else {
+ Seq.tabulate(rows.head.size)(i => s"col${i + 1}")
+ }
+
+ val table = UnresolvedInlineTable(aliases, rows.toSeq)
+ table.optionalMap(ctx.tableAlias.strictIdentifier)(aliasPlan)
+ }
+
+ /**
+ * Create an alias (SubqueryAlias) for a join relation. This is practically the same as
+ * visitAliasedQuery and visitNamedExpression; ANTLR4, however, requires us to use 3 different
+ * hooks. We could add alias names for output columns, for example:
+ * {{{
+ * SELECT a, b, c, d FROM (src1 s1 INNER JOIN src2 s2 ON s1.id = s2.id) dst(a, b, c, d)
+ * }}}
+ */
+ override def visitAliasedRelation(ctx: AliasedRelationContext): LogicalPlan = withOrigin(ctx) {
+ val relation = plan(ctx.relation).optionalMap(ctx.sample)(withSample)
+ mayApplyAliasPlan(ctx.tableAlias, relation)
+ }
+
+ /**
+ * Create an alias (SubqueryAlias) for a sub-query. This is practically the same as
+ * visitAliasedRelation and visitNamedExpression; ANTLR4, however, requires us to use 3 different
+ * hooks. We could add alias names for output columns, for example:
+ * {{{
+ * SELECT col1, col2 FROM testData AS t(col1, col2)
+ * }}}
+ */
+ override def visitAliasedQuery(ctx: AliasedQueryContext): LogicalPlan = withOrigin(ctx) {
+ val relation = plan(ctx.query).optionalMap(ctx.sample)(withSample)
+ if (ctx.tableAlias.strictIdentifier == null) {
+ // For un-aliased subqueries, use a default alias name that is not likely to conflict with
+ // normal subquery names, so that parent operators can only access the columns in subquery by
+ // unqualified names. Users can still use this special qualifier to access columns if they
+ // know it, but that's not recommended.
+ SubqueryAlias("__auto_generated_subquery_name", relation)
+ } else {
+ mayApplyAliasPlan(ctx.tableAlias, relation)
+ }
+ }
+
+ /**
+ * Create an alias ([[SubqueryAlias]]) for a [[LogicalPlan]].
+ */
+ private def aliasPlan(alias: ParserRuleContext, plan: LogicalPlan): LogicalPlan = {
+ SubqueryAlias(alias.getText, plan)
+ }
+
+ /**
+ * If aliases specified in a FROM clause, create a subquery alias ([[SubqueryAlias]]) and
+ * column aliases for a [[LogicalPlan]].
+ */
+ private def mayApplyAliasPlan(tableAlias: TableAliasContext, plan: LogicalPlan): LogicalPlan = {
+ if (tableAlias.strictIdentifier != null) {
+ val subquery = SubqueryAlias(tableAlias.strictIdentifier.getText, plan)
+ if (tableAlias.identifierList != null) {
+ val columnNames = visitIdentifierList(tableAlias.identifierList)
+ UnresolvedSubqueryColumnAliases(columnNames, subquery)
+ } else {
+ subquery
+ }
+ } else {
+ plan
+ }
+ }
+
+ /**
+ * Create a Sequence of Strings for a parenthesis enclosed alias list.
+ */
+ override def visitIdentifierList(ctx: IdentifierListContext): Seq[String] = withOrigin(ctx) {
+ visitIdentifierSeq(ctx.identifierSeq)
+ }
+
+ /**
+ * Create a Sequence of Strings for an identifier list.
+ */
+ override def visitIdentifierSeq(ctx: IdentifierSeqContext): Seq[String] = withOrigin(ctx) {
+ ctx.ident.asScala.map(_.getText).toSeq
+ }
+
+ /* ********************************************************************************************
+ * Table Identifier parsing
+ * ******************************************************************************************** */
+ /**
+ * Create a [[TableIdentifier]] from a 'tableName' or 'databaseName'.'tableName' pattern.
+ */
+ override def visitTableIdentifier(
+ ctx: TableIdentifierContext): TableIdentifier = withOrigin(ctx) {
+ TableIdentifier(ctx.table.getText, Option(ctx.db).map(_.getText))
+ }
+
+ /**
+ * Create a [[FunctionIdentifier]] from a 'functionName' or 'databaseName'.'functionName' pattern.
+ */
+ override def visitFunctionIdentifier(
+ ctx: FunctionIdentifierContext): FunctionIdentifier = withOrigin(ctx) {
+ FunctionIdentifier(ctx.function.getText, Option(ctx.db).map(_.getText))
+ }
+
+ /**
+ * Create a multi-part identifier.
+ */
+ override def visitMultipartIdentifier(ctx: MultipartIdentifierContext): Seq[String] =
+ withOrigin(ctx) {
+ ctx.parts.asScala.map(_.getText).toSeq
+ }
+
+ /* ********************************************************************************************
+ * Expression parsing
+ * ******************************************************************************************** */
+ /**
+ * Create an expression from the given context. This method just passes the context on to the
+ * visitor and only takes care of typing (We assume that the visitor returns an Expression here).
+ */
+ protected def expression(ctx: ParserRuleContext): Expression = typedVisit(ctx)
+
+ /**
+ * Create sequence of expressions from the given sequence of contexts.
+ */
+ private def expressionList(trees: java.util.List[ExpressionContext]): Seq[Expression] = {
+ trees.asScala.map(expression).toSeq
+ }
+
+ /**
+ * Create a star (i.e. all) expression; this selects all elements (in the specified object).
+ * Both un-targeted (global) and targeted aliases are supported.
+ */
+ override def visitStar(ctx: StarContext): Expression = withOrigin(ctx) {
+ UnresolvedStar(Option(ctx.qualifiedName()).map(_.identifier.asScala.map(_.getText).toSeq))
+ }
+
+ /**
+ * Create an aliased expression if an alias is specified. Both single and multi-aliases are
+ * supported.
+ */
+ override def visitNamedExpression(ctx: NamedExpressionContext): Expression = withOrigin(ctx) {
+ val e = expression(ctx.expression)
+ if (ctx.name != null) {
+ Alias(e, ctx.name.getText)()
+ } else if (ctx.identifierList != null) {
+ MultiAlias(e, visitIdentifierList(ctx.identifierList))
+ } else {
+ e
+ }
+ }
+
+ /**
+ * Combine a number of boolean expressions into a balanced expression tree. These expressions are
+ * either combined by a logical [[And]] or a logical [[Or]].
+ *
+ * A balanced binary tree is created because regular left recursive trees cause considerable
+ * performance degradations and can cause stack overflows.
+ */
+ override def visitLogicalBinary(ctx: LogicalBinaryContext): Expression = withOrigin(ctx) {
+ val expressionType = ctx.operator.getType
+ val expressionCombiner = expressionType match {
+ case HoodieSqlBaseParser.AND => And.apply _
+ case HoodieSqlBaseParser.OR => Or.apply _
+ }
+
+ // Collect all similar left hand contexts.
+ val contexts = ArrayBuffer(ctx.right)
+ var current = ctx.left
+ def collectContexts: Boolean = current match {
+ case lbc: LogicalBinaryContext if lbc.operator.getType == expressionType =>
+ contexts += lbc.right
+ current = lbc.left
+ true
+ case _ =>
+ contexts += current
+ false
+ }
+ while (collectContexts) {
+ // No body - all updates take place in the collectContexts.
+ }
+
+ // Reverse the contexts to have them in the same sequence as in the SQL statement & turn them
+ // into expressions.
+ val expressions = contexts.reverseMap(expression)
+
+ // Create a balanced tree.
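+ // e.g. `a AND b AND c AND d` is rebuilt as And(And(a, b), And(c, d)) rather than a deeply
+ // left-nested chain, which avoids stack overflows on very long predicate lists.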
+ def reduceToExpressionTree(low: Int, high: Int): Expression = high - low match {
+ case 0 =>
+ expressions(low)
+ case 1 =>
+ expressionCombiner(expressions(low), expressions(high))
+ case x =>
+ val mid = low + x / 2
+ expressionCombiner(
+ reduceToExpressionTree(low, mid),
+ reduceToExpressionTree(mid + 1, high))
+ }
+ reduceToExpressionTree(0, expressions.size - 1)
+ }
+
+ /**
+ * Invert a boolean expression.
+ */
+ override def visitLogicalNot(ctx: LogicalNotContext): Expression = withOrigin(ctx) {
+ Not(expression(ctx.booleanExpression()))
+ }
+
+ /**
+ * Create a filtering correlated sub-query (EXISTS).
+ */
+ override def visitExists(ctx: ExistsContext): Expression = {
+ Exists(plan(ctx.query))
+ }
+
+ /**
+ * Create a comparison expression. This compares two expressions. The following comparison
+ * operators are supported:
+ * - Equal: '=' or '=='
+ * - Null-safe Equal: '<=>'
+ * - Not Equal: '<>' or '!='
+ * - Less than: '<'
+ * - Less than or Equal: '<='
+ * - Greater than: '>'
+ * - Greater than or Equal: '>='
+ */
+ override def visitComparison(ctx: ComparisonContext): Expression = withOrigin(ctx) {
+ val left = expression(ctx.left)
+ val right = expression(ctx.right)
+ val operator = ctx.comparisonOperator().getChild(0).asInstanceOf[TerminalNode]
+ operator.getSymbol.getType match {
+ case HoodieSqlBaseParser.EQ =>
+ EqualTo(left, right)
+ case HoodieSqlBaseParser.NSEQ =>
+ EqualNullSafe(left, right)
+ case HoodieSqlBaseParser.NEQ | HoodieSqlBaseParser.NEQJ =>
+ Not(EqualTo(left, right))
+ case HoodieSqlBaseParser.LT =>
+ LessThan(left, right)
+ case HoodieSqlBaseParser.LTE =>
+ LessThanOrEqual(left, right)
+ case HoodieSqlBaseParser.GT =>
+ GreaterThan(left, right)
+ case HoodieSqlBaseParser.GTE =>
+ GreaterThanOrEqual(left, right)
+ }
+ }
+
+ /**
+ * Create a predicated expression. A predicated expression is a normal expression with a
+ * predicate attached to it, for example:
+ * {{{
+ * a + 1 IS NULL
+ * }}}
+ */
+ override def visitPredicated(ctx: PredicatedContext): Expression = withOrigin(ctx) {
+ val e = expression(ctx.valueExpression)
+ if (ctx.predicate != null) {
+ withPredicate(e, ctx.predicate)
+ } else {
+ e
+ }
+ }
+
+ /**
+ * Add a predicate to the given expression. Supported expressions are:
+ * - (NOT) BETWEEN
+ * - (NOT) IN
+ * - (NOT) LIKE (ANY | SOME | ALL)
+ * - (NOT) RLIKE
+ * - IS (NOT) NULL.
+ * - IS (NOT) (TRUE | FALSE | UNKNOWN)
+ * - IS (NOT) DISTINCT FROM
+ */
+ private def withPredicate(e: Expression, ctx: PredicateContext): Expression = withOrigin(ctx) {
+ // Invert a predicate if it has a valid NOT clause.
+ def invertIfNotDefined(e: Expression): Expression = ctx.NOT match {
+ case null => e
+ case not => Not(e)
+ }
+
+ def getValueExpressions(e: Expression): Seq[Expression] = e match {
+ case c: CreateNamedStruct => c.valExprs
+ case other => Seq(other)
+ }
+
+ // Create the predicate.
+ ctx.kind.getType match {
+ case HoodieSqlBaseParser.BETWEEN =>
+ // BETWEEN is translated to lower <= e && e <= upper
+ invertIfNotDefined(And(
+ GreaterThanOrEqual(e, expression(ctx.lower)),
+ LessThanOrEqual(e, expression(ctx.upper))))
+ case HoodieSqlBaseParser.IN if ctx.query != null =>
+ invertIfNotDefined(InSubquery(getValueExpressions(e), ListQuery(plan(ctx.query))))
+ case HoodieSqlBaseParser.IN =>
+ invertIfNotDefined(In(e, ctx.expression.asScala.map(expression).toSeq))
+ case HoodieSqlBaseParser.LIKE =>
+ Option(ctx.quantifier).map(_.getType) match {
+ case Some(HoodieSqlBaseParser.ANY) | Some(HoodieSqlBaseParser.SOME) =>
+ validate(!ctx.expression.isEmpty, "Expected something between '(' and ')'.", ctx)
+ val expressions = expressionList(ctx.expression)
+ if (expressions.forall(_.foldable) && expressions.forall(_.dataType == StringType)) {
+ // If there are many pattern expressions, will throw StackOverflowError.
+ // So we use LikeAny or NotLikeAny instead.
+ val patterns = expressions.map(_.eval(EmptyRow).asInstanceOf[UTF8String])
+ ctx.NOT match {
+ case null => LikeAny(e, patterns.toSeq)
+ case _ => NotLikeAny(e, patterns.toSeq)
+ }
+ } else {
+ ctx.expression.asScala.map(expression)
+ .map(p => invertIfNotDefined(new Like(e, p))).toSeq.reduceLeft(Or)
+ }
+ case Some(HoodieSqlBaseParser.ALL) =>
+ validate(!ctx.expression.isEmpty, "Expected something between '(' and ')'.", ctx)
+ val expressions = expressionList(ctx.expression)
+ if (expressions.forall(_.foldable) && expressions.forall(_.dataType == StringType)) {
+ // If there are many pattern expressions, will throw StackOverflowError.
+ // So we use LikeAll or NotLikeAll instead.
+ val patterns = expressions.map(_.eval(EmptyRow).asInstanceOf[UTF8String])
+ ctx.NOT match {
+ case null => LikeAll(e, patterns.toSeq)
+ case _ => NotLikeAll(e, patterns.toSeq)
+ }
+ } else {
+ ctx.expression.asScala.map(expression)
+ .map(p => invertIfNotDefined(new Like(e, p))).toSeq.reduceLeft(And)
+ }
+ case _ =>
+ val escapeChar = Option(ctx.escapeChar).map(string).map { str =>
+ if (str.length != 1) {
+ throw new ParseException("Invalid escape string. " +
+ "Escape string must contain only one character.", ctx)
+ }
+ str.charAt(0)
+ }.getOrElse('\\')
+ invertIfNotDefined(Like(e, expression(ctx.pattern), escapeChar))
+ }
+ case HoodieSqlBaseParser.RLIKE =>
+ invertIfNotDefined(RLike(e, expression(ctx.pattern)))
+ case HoodieSqlBaseParser.NULL if ctx.NOT != null =>
+ IsNotNull(e)
+ case HoodieSqlBaseParser.NULL =>
+ IsNull(e)
+ case HoodieSqlBaseParser.TRUE => ctx.NOT match {
+ case null => EqualNullSafe(e, Literal(true))
+ case _ => Not(EqualNullSafe(e, Literal(true)))
+ }
+ case HoodieSqlBaseParser.FALSE => ctx.NOT match {
+ case null => EqualNullSafe(e, Literal(false))
+ case _ => Not(EqualNullSafe(e, Literal(false)))
+ }
+ case HoodieSqlBaseParser.UNKNOWN => ctx.NOT match {
+ case null => IsUnknown(e)
+ case _ => IsNotUnknown(e)
+ }
+ case HoodieSqlBaseParser.DISTINCT if ctx.NOT != null =>
+ EqualNullSafe(e, expression(ctx.right))
+ case HoodieSqlBaseParser.DISTINCT =>
+ Not(EqualNullSafe(e, expression(ctx.right)))
+ }
+ }
+
+ /**
+ * Create a binary arithmetic expression. The following arithmetic operators are supported:
+ * - Multiplication: '*'
+ * - Division: '/'
+ * - Hive Long Division: 'DIV'
+ * - Modulo: '%'
+ * - Addition: '+'
+ * - Subtraction: '-'
+ * - Binary AND: '&'
+ * - Binary XOR: '^'
+ * - Binary OR: '|'
+ * - String Concatenation: '||'
+ */
+ override def visitArithmeticBinary(ctx: ArithmeticBinaryContext): Expression = withOrigin(ctx) {
+ val left = expression(ctx.left)
+ val right = expression(ctx.right)
+ ctx.operator.getType match {
+ case HoodieSqlBaseParser.ASTERISK =>
+ Multiply(left, right)
+ case HoodieSqlBaseParser.SLASH =>
+ Divide(left, right)
+ case HoodieSqlBaseParser.PERCENT =>
+ Remainder(left, right)
+ case HoodieSqlBaseParser.DIV =>
+ IntegralDivide(left, right)
+ case HoodieSqlBaseParser.PLUS =>
+ Add(left, right)
+ case HoodieSqlBaseParser.MINUS =>
+ Subtract(left, right)
+ case HoodieSqlBaseParser.CONCAT_PIPE =>
+ Concat(left :: right :: Nil)
+ case HoodieSqlBaseParser.AMPERSAND =>
+ BitwiseAnd(left, right)
+ case HoodieSqlBaseParser.HAT =>
+ BitwiseXor(left, right)
+ case HoodieSqlBaseParser.PIPE =>
+ BitwiseOr(left, right)
+ }
+ }
+
+ /**
+ * Create a unary arithmetic expression. The following arithmetic operators are supported:
+ * - Plus: '+'
+ * - Minus: '-'
+ * - Bitwise Not: '~'
+ */
+ override def visitArithmeticUnary(ctx: ArithmeticUnaryContext): Expression = withOrigin(ctx) {
+ val value = expression(ctx.valueExpression)
+ ctx.operator.getType match {
+ case HoodieSqlBaseParser.PLUS =>
+ UnaryPositive(value)
+ case HoodieSqlBaseParser.MINUS =>
+ UnaryMinus(value)
+ case HoodieSqlBaseParser.TILDE =>
+ BitwiseNot(value)
+ }
+ }
+
+ override def visitCurrentDatetime(ctx: CurrentDatetimeContext): Expression = withOrigin(ctx) {
+ if (conf.ansiEnabled) {
+ ctx.name.getType match {
+ case HoodieSqlBaseParser.CURRENT_DATE =>
+ CurrentDate()
+ case HoodieSqlBaseParser.CURRENT_TIMESTAMP =>
+ CurrentTimestamp()
+ }
+ } else {
+ // If the parser is not in ansi mode, we should return `UnresolvedAttribute`, in case there
+ // are columns named `CURRENT_DATE` or `CURRENT_TIMESTAMP`.
+ UnresolvedAttribute.quoted(ctx.name.getText)
+ }
+ }
+
+ /**
+ * Create a [[Cast]] expression.
+ */
+ override def visitCast(ctx: CastContext): Expression = withOrigin(ctx) {
+ val rawDataType = typedVisit[DataType](ctx.dataType())
+ val dataType = CharVarcharUtils.replaceCharVarcharWithStringForCast(rawDataType)
+ Cast(expression(ctx.expression), dataType)
+ }
+
+ /**
+ * Create a [[CreateStruct]] expression.
+ */
+ override def visitStruct(ctx: StructContext): Expression = withOrigin(ctx) {
+ CreateStruct.create(ctx.argument.asScala.map(expression).toSeq)
+ }
+
+ /**
+ * Create a [[First]] expression.
+ */
+ override def visitFirst(ctx: FirstContext): Expression = withOrigin(ctx) {
+ val ignoreNullsExpr = ctx.IGNORE != null
+ First(expression(ctx.expression), ignoreNullsExpr).toAggregateExpression()
+ }
+
+ /**
+ * Create a [[Last]] expression.
+ */
+ override def visitLast(ctx: LastContext): Expression = withOrigin(ctx) {
+ val ignoreNullsExpr = ctx.IGNORE != null
+ Last(expression(ctx.expression), ignoreNullsExpr).toAggregateExpression()
+ }
+
+ /**
+ * Create a Position expression.
+ */
+ override def visitPosition(ctx: PositionContext): Expression = withOrigin(ctx) {
+ new StringLocate(expression(ctx.substr), expression(ctx.str))
+ }
+
+ /**
+ * Create an Extract expression.
+ */
+ override def visitExtract(ctx: ExtractContext): Expression = withOrigin(ctx) {
+ val arguments = Seq(Literal(ctx.field.getText), expression(ctx.source))
+ UnresolvedFunction("extract", arguments, isDistinct = false)
+ }
+
+ /**
+ * Create a Substring/Substr expression.
+ */
+ override def visitSubstring(ctx: SubstringContext): Expression = withOrigin(ctx) {
+ if (ctx.len != null) {
+ Substring(expression(ctx.str), expression(ctx.pos), expression(ctx.len))
+ } else {
+ new Substring(expression(ctx.str), expression(ctx.pos))
+ }
+ }
+
+ /**
+ * Create a Trim expression.
+ */
+ override def visitTrim(ctx: TrimContext): Expression = withOrigin(ctx) {
+ val srcStr = expression(ctx.srcStr)
+ val trimStr = Option(ctx.trimStr).map(expression)
+ Option(ctx.trimOption).map(_.getType).getOrElse(HoodieSqlBaseParser.BOTH) match {
+ case HoodieSqlBaseParser.BOTH =>
+ StringTrim(srcStr, trimStr)
+ case HoodieSqlBaseParser.LEADING =>
+ StringTrimLeft(srcStr, trimStr)
+ case HoodieSqlBaseParser.TRAILING =>
+ StringTrimRight(srcStr, trimStr)
+ case other =>
+ throw new ParseException("Function trim doesn't support trim type " +
+ s"$other. Please use BOTH, LEADING or TRAILING as the trim type", ctx)
+ }
+ }
+
+ /**
+ * Create an Overlay expression.
+ */
+ override def visitOverlay(ctx: OverlayContext): Expression = withOrigin(ctx) {
+ val input = expression(ctx.input)
+ val replace = expression(ctx.replace)
+ val position = expression(ctx.position)
+ val lengthOpt = Option(ctx.length).map(expression)
+ lengthOpt match {
+ case Some(length) => Overlay(input, replace, position, length)
+ case None => new Overlay(input, replace, position)
+ }
+ }
+
+ /**
+ * Create a (windowed) Function expression.
+ */
+ override def visitFunctionCall(ctx: FunctionCallContext): Expression = withOrigin(ctx) {
+ // Create the function call.
+ val name = ctx.functionName.getText
+ val isDistinct = Option(ctx.setQuantifier()).exists(_.DISTINCT != null)
+ // Call `toSeq`, otherwise `ctx.argument.asScala.map(expression)` is `Buffer` in Scala 2.13
+ val arguments = ctx.argument.asScala.map(expression).toSeq match {
+ case Seq(UnresolvedStar(None))
+ if name.toLowerCase(Locale.ROOT) == "count" && !isDistinct =>
+ // Transform COUNT(*) into COUNT(1).
+ Seq(Literal(1))
+ case expressions =>
+ expressions
+ }
+ val filter = Option(ctx.where).map(expression(_))
+ val function = UnresolvedFunction(
+ getFunctionIdentifier(ctx.functionName), arguments, isDistinct, filter)
+
+ // Check if the function is evaluated in a windowed context.
+ ctx.windowSpec match {
+ case spec: WindowRefContext =>
+ UnresolvedWindowExpression(function, visitWindowRef(spec))
+ case spec: WindowDefContext =>
+ WindowExpression(function, visitWindowDef(spec))
+ case _ => function
+ }
+ }
+
+ /**
+ * Create a function database (optional) and name pair, for multipartIdentifier.
+ * This is used in CREATE FUNCTION, DROP FUNCTION, and SHOW FUNCTIONS.
+ */
+ protected def visitFunctionName(ctx: MultipartIdentifierContext): FunctionIdentifier = {
+ visitFunctionName(ctx, ctx.parts.asScala.map(_.getText).toSeq)
+ }
+
+ /**
+ * Create a function database (optional) and name pair.
+ */
+ protected def visitFunctionName(ctx: QualifiedNameContext): FunctionIdentifier = {
+ visitFunctionName(ctx, ctx.identifier().asScala.map(_.getText).toSeq)
+ }
+
+ /**
+ * Create a function database (optional) and name pair.
+ */
+ private def visitFunctionName(ctx: ParserRuleContext, texts: Seq[String]): FunctionIdentifier = {
+ texts match {
+ case Seq(db, fn) => FunctionIdentifier(fn, Option(db))
+ case Seq(fn) => FunctionIdentifier(fn, None)
+ case other =>
+ throw new ParseException(s"Unsupported function name '${texts.mkString(".")}'", ctx)
+ }
+ }
+
+ /**
+ * Get a function identifier consisting of an optional database name and a function name.
+ */
+ protected def getFunctionIdentifier(ctx: FunctionNameContext): FunctionIdentifier = {
+ if (ctx.qualifiedName != null) {
+ visitFunctionName(ctx.qualifiedName)
+ } else {
+ FunctionIdentifier(ctx.getText, None)
+ }
+ }
+
+ /**
+ * Create a [[LambdaFunction]].
+ */
+ override def visitLambda(ctx: LambdaContext): Expression = withOrigin(ctx) {
+ val arguments = ctx.identifier().asScala.map { name =>
+ UnresolvedNamedLambdaVariable(UnresolvedAttribute.quoted(name.getText).nameParts)
+ }
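+ // Illustrative example: in transform(arr, x -> x + 1), the lambda `x -> x + 1` reaches this visitor;
+ // `x` becomes an UnresolvedNamedLambdaVariable both in the argument list and inside the body.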
+ val function = expression(ctx.expression).transformUp {
+ case a: UnresolvedAttribute => UnresolvedNamedLambdaVariable(a.nameParts)
+ }
+ LambdaFunction(function, arguments.toSeq)
+ }
+
+ /**
+ * Create a reference to a window frame, i.e. [[WindowSpecReference]].
+ */
+ override def visitWindowRef(ctx: WindowRefContext): WindowSpecReference = withOrigin(ctx) {
+ WindowSpecReference(ctx.name.getText)
+ }
+
+ /**
+ * Create a window definition, i.e. [[WindowSpecDefinition]].
+ */
+ override def visitWindowDef(ctx: WindowDefContext): WindowSpecDefinition = withOrigin(ctx) {
+ // CLUSTER BY ... | PARTITION BY ... ORDER BY ...
+ val partition = ctx.partition.asScala.map(expression)
+ val order = ctx.sortItem.asScala.map(visitSortItem)
+
+ // RANGE/ROWS BETWEEN ...
+ val frameSpecOption = Option(ctx.windowFrame).map { frame =>
+ val frameType = frame.frameType.getType match {
+ case HoodieSqlBaseParser.RANGE => RangeFrame
+ case HoodieSqlBaseParser.ROWS => RowFrame
+ }
+
+ SpecifiedWindowFrame(
+ frameType,
+ visitFrameBound(frame.start),
+ Option(frame.end).map(visitFrameBound).getOrElse(CurrentRow))
+ }
+
+ WindowSpecDefinition(
+ partition.toSeq,
+ order.toSeq,
+ frameSpecOption.getOrElse(UnspecifiedFrame))
+ }
+
+ /**
+ * Create or resolve a frame boundary expression.
+ */
+ override def visitFrameBound(ctx: FrameBoundContext): Expression = withOrigin(ctx) {
+ def value: Expression = {
+ val e = expression(ctx.expression)
+ validate(e.resolved && e.foldable, "Frame bound value must be a literal.", ctx)
+ e
+ }
+
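+ // Illustrative examples: `2 PRECEDING` becomes UnaryMinus(Literal(2)), `CURRENT ROW` maps to
+ // CurrentRow, and `UNBOUNDED FOLLOWING` maps to UnboundedFollowing.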
+ ctx.boundType.getType match {
+ case HoodieSqlBaseParser.PRECEDING if ctx.UNBOUNDED != null =>
+ UnboundedPreceding
+ case HoodieSqlBaseParser.PRECEDING =>
+ UnaryMinus(value)
+ case HoodieSqlBaseParser.CURRENT =>
+ CurrentRow
+ case HoodieSqlBaseParser.FOLLOWING if ctx.UNBOUNDED != null =>
+ UnboundedFollowing
+ case HoodieSqlBaseParser.FOLLOWING =>
+ value
+ }
+ }
+
+ /**
+ * Create a [[CreateStruct]] expression.
+ */
+ override def visitRowConstructor(ctx: RowConstructorContext): Expression = withOrigin(ctx) {
+ CreateStruct(ctx.namedExpression().asScala.map(expression).toSeq)
+ }
+
+ /**
+ * Create a [[ScalarSubquery]] expression.
+ */
+ override def visitSubqueryExpression(
+ ctx: SubqueryExpressionContext): Expression = withOrigin(ctx) {
+ ScalarSubquery(plan(ctx.query))
+ }
+
+ /**
+ * Create a value based [[CaseWhen]] expression. This has the following SQL form:
+ * {{{
+ * CASE [expression]
+ * WHEN [value] THEN [expression]
+ * ...
+ * ELSE [expression]
+ * END
+ * }}}
+ */
+ override def visitSimpleCase(ctx: SimpleCaseContext): Expression = withOrigin(ctx) {
+ val e = expression(ctx.value)
+ val branches = ctx.whenClause.asScala.map { wCtx =>
+ (EqualTo(e, expression(wCtx.condition)), expression(wCtx.result))
+ }
+ CaseWhen(branches.toSeq, Option(ctx.elseExpression).map(expression))
+ }
+
+ /**
+ * Create a condition based [[CaseWhen]] expression. This has the following SQL syntax:
+ * {{{
+ * CASE
+ * WHEN [predicate] THEN [expression]
+ * ...
+ * ELSE [expression]
+ * END
+ * }}}
+ *
+ * @param ctx the parse tree
+ */
+ override def visitSearchedCase(ctx: SearchedCaseContext): Expression = withOrigin(ctx) {
+ val branches = ctx.whenClause.asScala.map { wCtx =>
+ (expression(wCtx.condition), expression(wCtx.result))
+ }
+ CaseWhen(branches.toSeq, Option(ctx.elseExpression).map(expression))
+ }
+
+ /**
+ * Currently, regexes are only supported in expressions of SELECT statements; in other
+ * places, e.g. WHERE `(a)?+.+` = 2, regexes are not meaningful.
+ */
+ private def canApplyRegex(ctx: ParserRuleContext): Boolean = withOrigin(ctx) {
+ var parent = ctx.getParent
+ var rtn = false
+ while (parent != null) {
+ if (parent.isInstanceOf[NamedExpressionContext]) {
+ rtn = true
+ }
+ parent = parent.getParent
+ }
+ rtn
+ }
+
+ /**
+ * Create a dereference expression. The return type depends on the type of the parent.
+ * If the parent is an [[UnresolvedAttribute]], the result can be an [[UnresolvedAttribute]] or
+ * an [[UnresolvedRegex]] for a regex quoted in backticks; if the parent is some other expression,
+ * it can be an [[UnresolvedExtractValue]].
+ */
+ override def visitDereference(ctx: DereferenceContext): Expression = withOrigin(ctx) {
+ val attr = ctx.fieldName.getText
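+ // Illustrative examples: `s.b` on an unresolved attribute becomes UnresolvedAttribute(["s", "b"]),
+ // while `f(x).b` becomes UnresolvedExtractValue(f(x), Literal("b")).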
+ expression(ctx.base) match {
+ case unresolved_attr @ UnresolvedAttribute(nameParts) =>
+ ctx.fieldName.getStart.getText match {
+ case escapedIdentifier(columnNameRegex)
+ if conf.supportQuotedRegexColumnName && canApplyRegex(ctx) =>
+ UnresolvedRegex(columnNameRegex, Some(unresolved_attr.name),
+ conf.caseSensitiveAnalysis)
+ case _ =>
+ UnresolvedAttribute(nameParts :+ attr)
+ }
+ case e =>
+ UnresolvedExtractValue(e, Literal(attr))
+ }
+ }
+
+ /**
+ * Create an [[UnresolvedAttribute]] expression, or an [[UnresolvedRegex]] if it is a regex
+ * quoted in backticks.
+ */
+ override def visitColumnReference(ctx: ColumnReferenceContext): Expression = withOrigin(ctx) {
+ ctx.getStart.getText match {
+ case escapedIdentifier(columnNameRegex)
+ if conf.supportQuotedRegexColumnName && canApplyRegex(ctx) =>
+ UnresolvedRegex(columnNameRegex, None, conf.caseSensitiveAnalysis)
+ case _ =>
+ UnresolvedAttribute.quoted(ctx.getText)
+ }
+ }
+
+ /**
+ * Create an [[UnresolvedExtractValue]] expression; this is used for subscript access to an array.
+ */
+ override def visitSubscript(ctx: SubscriptContext): Expression = withOrigin(ctx) {
+ UnresolvedExtractValue(expression(ctx.value), expression(ctx.index))
+ }
+
+ /**
+ * Create an expression for an expression between parentheses. This is needed because the ANTLR
+ * visitor cannot automatically convert the nested context into an expression.
+ */
+ override def visitParenthesizedExpression(
+ ctx: ParenthesizedExpressionContext): Expression = withOrigin(ctx) {
+ expression(ctx.expression)
+ }
+
+ /**
+ * Create a [[SortOrder]] expression.
+ */
+ override def visitSortItem(ctx: SortItemContext): SortOrder = withOrigin(ctx) {
+ val direction = if (ctx.DESC != null) {
+ Descending
+ } else {
+ Ascending
+ }
+ val nullOrdering = if (ctx.FIRST != null) {
+ NullsFirst
+ } else if (ctx.LAST != null) {
+ NullsLast
+ } else {
+ direction.defaultNullOrdering
+ }
+ SortOrder(expression(ctx.expression), direction, nullOrdering, Seq.empty)
+ }
+
+ /**
+ * Create a typed Literal expression. A typed literal has the following SQL syntax:
+ * {{{
+ * [TYPE] '[VALUE]'
+ * }}}
+ * Currently Date, Timestamp, Interval and Binary typed literals are supported.
+ */
+ override def visitTypeConstructor(ctx: TypeConstructorContext): Literal = withOrigin(ctx) {
+ val value = string(ctx.STRING)
+ val valueType = ctx.identifier.getText.toUpperCase(Locale.ROOT)
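+ // Illustrative inputs: DATE '2019-01-01', TIMESTAMP '2019-01-01 10:00:00', INTERVAL '1 day',
+ // and X'0F' for a hex-encoded binary literal; any other type name is rejected below.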
+
+ def toLiteral[T](f: UTF8String => Option[T], t: DataType): Literal = {
+ f(UTF8String.fromString(value)).map(Literal(_, t)).getOrElse {
+ throw new ParseException(s"Cannot parse the $valueType value: $value", ctx)
+ }
+ }
+ try {
+ valueType match {
+ case "DATE" =>
+ toLiteral(stringToDate(_, getZoneId(SQLConf.get.sessionLocalTimeZone)), DateType)
+ case "TIMESTAMP" =>
+ val zoneId = getZoneId(SQLConf.get.sessionLocalTimeZone)
+ toLiteral(stringToTimestamp(_, zoneId), TimestampType)
+ case "INTERVAL" =>
+ val interval = try {
+ IntervalUtils.stringToInterval(UTF8String.fromString(value))
+ } catch {
+ case e: IllegalArgumentException =>
+ val ex = new ParseException("Cannot parse the INTERVAL value: " + value, ctx)
+ ex.setStackTrace(e.getStackTrace)
+ throw ex
+ }
+ Literal(interval, CalendarIntervalType)
+ case "X" =>
+ val padding = if (value.length % 2 != 0) "0" else ""
+ Literal(DatatypeConverter.parseHexBinary(padding + value))
+ case other =>
+ throw new ParseException(s"Literals of type '$other' are currently not" +
+ " supported.", ctx)
+ }
+ } catch {
+ case e: IllegalArgumentException =>
+ val message = Option(e.getMessage).getOrElse(s"Exception parsing $valueType")
+ throw new ParseException(message, ctx)
+ }
+ }
+
+ /**
+ * Create a NULL literal expression.
+ */
+ override def visitNullLiteral(ctx: NullLiteralContext): Literal = withOrigin(ctx) {
+ Literal(null)
+ }
+
+ /**
+ * Create a Boolean literal expression.
+ */
+ override def visitBooleanLiteral(ctx: BooleanLiteralContext): Literal = withOrigin(ctx) {
+ if (ctx.getText.toBoolean) {
+ Literal.TrueLiteral
+ } else {
+ Literal.FalseLiteral
+ }
+ }
+
+ /**
+ * Create an integral literal expression. The code selects the narrowest integral type
+ * possible; either an Integer, a Long or a BigDecimal is returned.
+ */
+ override def visitIntegerLiteral(ctx: IntegerLiteralContext): Literal = withOrigin(ctx) {
+ BigDecimal(ctx.getText) match {
+ case v if v.isValidInt =>
+ Literal(v.intValue)
+ case v if v.isValidLong =>
+ Literal(v.longValue)
+ case v => Literal(v.underlying())
+ }
+ }
+
+ /**
+ * Create a decimal literal for a regular decimal number.
+ */
+ override def visitDecimalLiteral(ctx: DecimalLiteralContext): Literal = withOrigin(ctx) {
+ Literal(BigDecimal(ctx.getText).underlying())
+ }
+
+ /**
+ * Create a decimal literal for a regular decimal number or a scientific decimal number.
+ */
+ override def visitLegacyDecimalLiteral(
+ ctx: LegacyDecimalLiteralContext): Literal = withOrigin(ctx) {
+ Literal(BigDecimal(ctx.getText).underlying())
+ }
+
+ /**
+ * Create a double literal for a number with an exponent, e.g. 1E-30.
+ */
+ override def visitExponentLiteral(ctx: ExponentLiteralContext): Literal = {
+ numericLiteral(ctx, ctx.getText, /* exponent values don't have a suffix */
+ Double.MinValue, Double.MaxValue, DoubleType.simpleString)(_.toDouble)
+ }
+
+ /** Create a numeric literal expression. */
+ private def numericLiteral(
+ ctx: NumberContext,
+ rawStrippedQualifier: String,
+ minValue: BigDecimal,
+ maxValue: BigDecimal,
+ typeName: String)(converter: String => Any): Literal = withOrigin(ctx) {
+ try {
+ val rawBigDecimal = BigDecimal(rawStrippedQualifier)
+ if (rawBigDecimal < minValue || rawBigDecimal > maxValue) {
+ throw new ParseException(s"Numeric literal ${rawStrippedQualifier} does not " +
+ s"fit in range [${minValue}, ${maxValue}] for type ${typeName}", ctx)
+ }
+ Literal(converter(rawStrippedQualifier))
+ } catch {
+ case e: NumberFormatException =>
+ throw new ParseException(e.getMessage, ctx)
+ }
+ }
+
+ /**
+ * Create a Byte Literal expression.
+ */
+ override def visitTinyIntLiteral(ctx: TinyIntLiteralContext): Literal = {
+ val rawStrippedQualifier = ctx.getText.substring(0, ctx.getText.length - 1)
+ numericLiteral(ctx, rawStrippedQualifier,
+ Byte.MinValue, Byte.MaxValue, ByteType.simpleString)(_.toByte)
+ }
+
+ /**
+ * Create a Short Literal expression.
+ */
+ override def visitSmallIntLiteral(ctx: SmallIntLiteralContext): Literal = {
+ val rawStrippedQualifier = ctx.getText.substring(0, ctx.getText.length - 1)
+ numericLiteral(ctx, rawStrippedQualifier,
+ Short.MinValue, Short.MaxValue, ShortType.simpleString)(_.toShort)
+ }
+
+ /**
+ * Create a Long Literal expression.
+ */
+ override def visitBigIntLiteral(ctx: BigIntLiteralContext): Literal = {
+ val rawStrippedQualifier = ctx.getText.substring(0, ctx.getText.length - 1)
+ numericLiteral(ctx, rawStrippedQualifier,
+ Long.MinValue, Long.MaxValue, LongType.simpleString)(_.toLong)
+ }
+
+ /**
+ * Create a Float Literal expression.
+ */
+ override def visitFloatLiteral(ctx: FloatLiteralContext): Literal = {
+ val rawStrippedQualifier = ctx.getText.substring(0, ctx.getText.length - 1)
+ numericLiteral(ctx, rawStrippedQualifier,
+ Float.MinValue, Float.MaxValue, FloatType.simpleString)(_.toFloat)
+ }
+
+ /**
+ * Create a Double Literal expression.
+ */
+ override def visitDoubleLiteral(ctx: DoubleLiteralContext): Literal = {
+ val rawStrippedQualifier = ctx.getText.substring(0, ctx.getText.length - 1)
+ numericLiteral(ctx, rawStrippedQualifier,
+ Double.MinValue, Double.MaxValue, DoubleType.simpleString)(_.toDouble)
+ }
+
+ /**
+ * Create a BigDecimal Literal expression.
+ */
+ override def visitBigDecimalLiteral(ctx: BigDecimalLiteralContext): Literal = {
+ val raw = ctx.getText.substring(0, ctx.getText.length - 2)
+ try {
+ Literal(BigDecimal(raw).underlying())
+ } catch {
+ case e: AnalysisException =>
+ throw new ParseException(e.message, ctx)
+ }
+ }
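+
+ // Illustrative mapping of the suffixed forms handled by the literal visitors above (not part of
+ // the original change): 127Y -> tinyint, 32767S -> smallint, 10L -> bigint, 1.5F -> float,
+ // 1.5D -> double and 1.5BD -> decimal; each visitor strips the suffix before converting.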
+
+ /**
+ * Create a String literal expression.
+ */
+ override def visitStringLiteral(ctx: StringLiteralContext): Literal = withOrigin(ctx) {
+ Literal(createString(ctx))
+ }
+
+ /**
+ * Create a String from a string literal context. This supports multiple consecutive string
+ * literals, which are concatenated; for example, the expression "'hello' 'world'" is
+ * converted into "helloworld".
+ *
+ * Special characters can be escaped by using Hive/C-style escaping.
+ */
+ private def createString(ctx: StringLiteralContext): String = {
+ if (conf.escapedStringLiterals) {
+ ctx.STRING().asScala.map(stringWithoutUnescape).mkString
+ } else {
+ ctx.STRING().asScala.map(string).mkString
+ }
+ }
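+
+ // Example (illustrative, not part of the original change): with spark.sql.parser.escapedStringLiterals
+ // disabled (the default), 'a\tb' is unescaped into a string containing a real tab; with it enabled
+ // the backslash sequence is kept verbatim. Consecutive literals such as 'hello' ' ' 'world' are
+ // concatenated either way.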
+
+ /**
+ * Create a [[CalendarInterval]] literal expression. Two syntaxes are supported:
+ * - multiple unit value pairs, for instance: interval 2 months 2 days.
+ * - from-to unit, for instance: interval '1-2' year to month.
+ */
+ override def visitInterval(ctx: IntervalContext): Literal = withOrigin(ctx) {
+ Literal(parseIntervalLiteral(ctx), CalendarIntervalType)
+ }
+
+ /**
+ * Create a [[CalendarInterval]] object
+ */
+ protected def parseIntervalLiteral(ctx: IntervalContext): CalendarInterval = withOrigin(ctx) {
+ if (ctx.errorCapturingMultiUnitsInterval != null) {
+ val innerCtx = ctx.errorCapturingMultiUnitsInterval
+ if (innerCtx.unitToUnitInterval != null) {
+ throw new ParseException(
+ "Can only have a single from-to unit in the interval literal syntax",
+ innerCtx.unitToUnitInterval)
+ }
+ visitMultiUnitsInterval(innerCtx.multiUnitsInterval)
+ } else if (ctx.errorCapturingUnitToUnitInterval != null) {
+ val innerCtx = ctx.errorCapturingUnitToUnitInterval
+ if (innerCtx.error1 != null || innerCtx.error2 != null) {
+ val errorCtx = if (innerCtx.error1 != null) innerCtx.error1 else innerCtx.error2
+ throw new ParseException(
+ "Can only have a single from-to unit in the interval literal syntax",
+ errorCtx)
+ }
+ visitUnitToUnitInterval(innerCtx.body)
+ } else {
+ throw new ParseException("at least one time unit should be given for interval literal", ctx)
+ }
+ }
+
+ /**
+ * Creates a [[CalendarInterval]] with multiple unit value pairs, e.g. 1 YEAR 2 DAYS.
+ */
+ override def visitMultiUnitsInterval(ctx: MultiUnitsIntervalContext): CalendarInterval = {
+ withOrigin(ctx) {
+ val units = ctx.unit.asScala
+ val values = ctx.intervalValue().asScala
+ try {
+ assert(units.length == values.length)
+ val kvs = units.indices.map { i =>
+ val u = units(i).getText
+ val v = if (values(i).STRING() != null) {
+ val value = string(values(i).STRING())
+ // SPARK-32840: For invalid cases, e.g. INTERVAL '1 day 2' hour,
+ // INTERVAL 'interval 1' day, we need to check ahead before they are concatenated with
+ // units and become valid ones, e.g. '1 day 2 hour'.
+ // Ideally, we only ensure the value parts don't contain any units here.
+ if (value.exists(Character.isLetter)) {
+ throw new ParseException("Can only use numbers in the interval value part for" +
+ s" multiple unit value pairs interval form, but got invalid value: $value", ctx)
+ }
+ value
+ } else {
+ values(i).getText
+ }
+ UTF8String.fromString(" " + v + " " + u)
+ }
+ IntervalUtils.stringToInterval(UTF8String.concat(kvs: _*))
+ } catch {
+ case i: IllegalArgumentException =>
+ val e = new ParseException(i.getMessage, ctx)
+ e.setStackTrace(i.getStackTrace)
+ throw e
+ }
+ }
+ }
+
+ /**
+ * Creates a [[CalendarInterval]] with from-to unit, e.g. '2-1' YEAR TO MONTH.
+ */
+ override def visitUnitToUnitInterval(ctx: UnitToUnitIntervalContext): CalendarInterval = {
+ withOrigin(ctx) {
+ val value = Option(ctx.intervalValue.STRING).map(string).getOrElse {
+ throw new ParseException("The value of from-to unit must be a string", ctx.intervalValue)
+ }
+ try {
+ val from = ctx.from.getText.toLowerCase(Locale.ROOT)
+ val to = ctx.to.getText.toLowerCase(Locale.ROOT)
+ (from, to) match {
+ case ("year", "month") =>
+ IntervalUtils.fromYearMonthString(value)
+ case ("day", "hour") =>
+ IntervalUtils.fromDayTimeString(value, IntervalUnit.DAY, IntervalUnit.HOUR)
+ case ("day", "minute") =>
+ IntervalUtils.fromDayTimeString(value, IntervalUnit.DAY, IntervalUnit.MINUTE)
+ case ("day", "second") =>
+ IntervalUtils.fromDayTimeString(value, IntervalUnit.DAY, IntervalUnit.SECOND)
+ case ("hour", "minute") =>
+ IntervalUtils.fromDayTimeString(value, IntervalUnit.HOUR, IntervalUnit.MINUTE)
+ case ("hour", "second") =>
+ IntervalUtils.fromDayTimeString(value, IntervalUnit.HOUR, IntervalUnit.SECOND)
+ case ("minute", "second") =>
+ IntervalUtils.fromDayTimeString(value, IntervalUnit.MINUTE, IntervalUnit.SECOND)
+ case _ =>
+ throw new ParseException(s"Intervals FROM $from TO $to are not supported.", ctx)
+ }
+ } catch {
+ // Handle Exceptions thrown by CalendarInterval
+ case e: IllegalArgumentException =>
+ val pe = new ParseException(e.getMessage, ctx)
+ pe.setStackTrace(e.getStackTrace)
+ throw pe
+ }
+ }
+ }
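+
+ // Illustrative from-to forms accepted above (not part of the original change):
+ //   INTERVAL '2-1' YEAR TO MONTH          -> 2 years 1 month
+ //   INTERVAL '10 11:12:13' DAY TO SECOND  -> 10 days 11 hours 12 minutes 13 seconds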
+
+ /* ********************************************************************************************
+ * DataType parsing
+ * ******************************************************************************************** */
+
+ /**
+ * Resolve/create a primitive type.
+ */
+ override def visitPrimitiveDataType(ctx: PrimitiveDataTypeContext): DataType = withOrigin(ctx) {
+ val dataType = ctx.identifier.getText.toLowerCase(Locale.ROOT)
+ (dataType, ctx.INTEGER_VALUE().asScala.toList) match {
+ case ("boolean", Nil) => BooleanType
+ case ("tinyint" | "byte", Nil) => ByteType
+ case ("smallint" | "short", Nil) => ShortType
+ case ("int" | "integer", Nil) => IntegerType
+ case ("bigint" | "long", Nil) => LongType
+ case ("float" | "real", Nil) => FloatType
+ case ("double", Nil) => DoubleType
+ case ("date", Nil) => DateType
+ case ("timestamp", Nil) => TimestampType
+ case ("string", Nil) => StringType
+ case ("character" | "char", length :: Nil) => CharType(length.getText.toInt)
+ case ("varchar", length :: Nil) => VarcharType(length.getText.toInt)
+ case ("binary", Nil) => BinaryType
+ case ("decimal" | "dec" | "numeric", Nil) => DecimalType.USER_DEFAULT
+ case ("decimal" | "dec" | "numeric", precision :: Nil) =>
+ DecimalType(precision.getText.toInt, 0)
+ case ("decimal" | "dec" | "numeric", precision :: scale :: Nil) =>
+ DecimalType(precision.getText.toInt, scale.getText.toInt)
+ case ("void", Nil) => NullType
+ case ("interval", Nil) => CalendarIntervalType
+ case (dt, params) =>
+ val dtStr = if (params.nonEmpty) s"$dt(${params.mkString(",")})" else dt
+ throw new ParseException(s"DataType $dtStr is not supported.", ctx)
+ }
+ }
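+
+ // For example (illustrative, not part of the original change): "int", "bigint", "decimal(10, 2)"
+ // and "varchar(64)" resolve to IntegerType, LongType, DecimalType(10, 2) and VarcharType(64),
+ // while an unknown name such as "point(2)" raises "DataType point(2) is not supported.".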
+
+ /**
+ * Create a complex DataType. Arrays, Maps and Structures are supported.
+ */
+ override def visitComplexDataType(ctx: ComplexDataTypeContext): DataType = withOrigin(ctx) {
+ ctx.complex.getType match {
+ case HoodieSqlBaseParser.ARRAY =>
+ ArrayType(typedVisit(ctx.dataType(0)))
+ case HoodieSqlBaseParser.MAP =>
+ MapType(typedVisit(ctx.dataType(0)), typedVisit(ctx.dataType(1)))
+ case HoodieSqlBaseParser.STRUCT =>
+ StructType(Option(ctx.complexColTypeList).toSeq.flatMap(visitComplexColTypeList))
+ }
+ }
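+
+ // Illustrative nested forms (not part of the original change): "array<int>", "map<string, double>"
+ // and "struct<id: bigint, tags: array<string>>" recurse through typedVisit to build the
+ // corresponding ArrayType, MapType and StructType.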
+
+ /**
+ * Create top level table schema.
+ */
+ protected def createSchema(ctx: ColTypeListContext): StructType = {
+ StructType(Option(ctx).toSeq.flatMap(visitColTypeList))
+ }
+
+ /**
+ * Create a [[StructType]] from a number of column definitions.
+ */
+ override def visitColTypeList(ctx: ColTypeListContext): Seq[StructField] = withOrigin(ctx) {
+ ctx.colType().asScala.map(visitColType).toSeq
+ }
+
+ /**
+ * Create a top level [[StructField]] from a column definition.
+ */
+ override def visitColType(ctx: ColTypeContext): StructField = withOrigin(ctx) {
+ import ctx._
+
+ val builder = new MetadataBuilder
+ // Add comment to metadata
+ Option(commentSpec()).map(visitCommentSpec).foreach {
+ builder.putString("comment", _)
+ }
+
+ StructField(
+ name = colName.getText,
+ dataType = typedVisit[DataType](ctx.dataType),
+ nullable = NULL == null,
+ metadata = builder.build())
+ }
+
+ /**
+ * Create a [[StructType]] from a sequence of [[StructField]]s.
+ */
+ protected def createStructType(ctx: ComplexColTypeListContext): StructType = {
+ StructType(Option(ctx).toSeq.flatMap(visitComplexColTypeList))
+ }
+
+ /**
+ * Create a [[StructType]] from a number of column definitions.
+ */
+ override def visitComplexColTypeList(
+ ctx: ComplexColTypeListContext): Seq[StructField] = withOrigin(ctx) {
+ ctx.complexColType().asScala.map(visitComplexColType).toSeq
+ }
+
+ /**
+ * Create a [[StructField]] from a column definition.
+ */
+ override def visitComplexColType(ctx: ComplexColTypeContext): StructField = withOrigin(ctx) {
+ import ctx._
+ val structField = StructField(
+ name = identifier.getText,
+ dataType = typedVisit(dataType()),
+ nullable = NULL == null)
+ Option(commentSpec).map(visitCommentSpec).map(structField.withComment).getOrElse(structField)
+ }
+
+ /**
+ * Create a location string.
+ */
+ override def visitLocationSpec(ctx: LocationSpecContext): String = withOrigin(ctx) {
+ string(ctx.STRING)
+ }
+
+ /**
+ * Create an optional location string.
+ */
+ protected def visitLocationSpecList(ctx: java.util.List[LocationSpecContext]): Option[String] = {
+ ctx.asScala.headOption.map(visitLocationSpec)
+ }
+
+ /**
+ * Create a comment string.
+ */
+ override def visitCommentSpec(ctx: CommentSpecContext): String = withOrigin(ctx) {
+ string(ctx.STRING)
+ }
+
+ /**
+ * Create an optional comment string.
+ */
+ protected def visitCommentSpecList(ctx: java.util.List[CommentSpecContext]): Option[String] = {
+ ctx.asScala.headOption.map(visitCommentSpec)
+ }
+
+ /**
+ * Create a [[BucketSpec]].
+ */
+ override def visitBucketSpec(ctx: BucketSpecContext): BucketSpec = withOrigin(ctx) {
+ BucketSpec(
+ ctx.INTEGER_VALUE.getText.toInt,
+ visitIdentifierList(ctx.identifierList),
+ Option(ctx.orderedIdentifierList)
+ .toSeq
+ .flatMap(_.orderedIdentifier.asScala)
+ .map { orderedIdCtx =>
+ Option(orderedIdCtx.ordering).map(_.getText).foreach { dir =>
+ if (dir.toLowerCase(Locale.ROOT) != "asc") {
+ operationNotAllowed(s"Column ordering must be ASC, was '$dir'", ctx)
+ }
+ }
+
+ orderedIdCtx.ident.getText
+ })
+ }
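+
+ // Example clause (illustrative, not part of the original change):
+ //   CLUSTERED BY (id) SORTED BY (ts ASC) INTO 8 BUCKETS  ->  BucketSpec(8, Seq("id"), Seq("ts"));
+ // an explicit DESC ordering is rejected by the check above.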
+
+ /**
+ * Convert a table property list into a key-value map.
+ * This should be called through [[visitPropertyKeyValues]] or [[visitPropertyKeys]].
+ */
+ override def visitTablePropertyList(
+ ctx: TablePropertyListContext): Map[String, String] = withOrigin(ctx) {
+ val properties = ctx.tableProperty.asScala.map { property =>
+ val key = visitTablePropertyKey(property.key)
+ val value = visitTablePropertyValue(property.value)
+ key -> value
+ }
+ // Check for duplicate property names.
+ checkDuplicateKeys(properties.toSeq, ctx)
+ properties.toMap
+ }
+
+ /**
+ * Parse a key-value map from a [[TablePropertyListContext]], assuming all values are specified.
+ */
+ def visitPropertyKeyValues(ctx: TablePropertyListContext): Map[String, String] = {
+ val props = visitTablePropertyList(ctx)
+ val badKeys = props.collect { case (key, null) => key }
+ if (badKeys.nonEmpty) {
+ operationNotAllowed(
+ s"Values must be specified for key(s): ${badKeys.mkString("[", ",", "]")}", ctx)
+ }
+ props
+ }
+
+ /**
+ * Parse a list of keys from a [[TablePropertyListContext]], assuming no values are specified.
+ */
+ def visitPropertyKeys(ctx: TablePropertyListContext): Seq[String] = {
+ val props = visitTablePropertyList(ctx)
+ val badKeys = props.filter { case (_, v) => v != null }.keys
+ if (badKeys.nonEmpty) {
+ operationNotAllowed(
+ s"Values should not be specified for key(s): ${badKeys.mkString("[", ",", "]")}", ctx)
+ }
+ props.keys.toSeq
+ }
+
+ /**
+ * A table property key can either be a String or a collection of dot-separated elements. This
+ * function extracts the property key based on whether it is a string literal or a table property
+ * identifier.
+ */
+ override def visitTablePropertyKey(key: TablePropertyKeyContext): String = {
+ if (key.STRING != null) {
+ string(key.STRING)
+ } else {
+ key.getText
+ }
+ }
+
+ /**
+ * A table property value can be a String, Integer, Boolean or Decimal. This function extracts
+ * the property value based on whether it is a string, integer, boolean or decimal literal.
+ */
+ override def visitTablePropertyValue(value: TablePropertyValueContext): String = {
+ if (value == null) {
+ null
+ } else if (value.STRING != null) {
+ string(value.STRING)
+ } else if (value.booleanValue != null) {
+ value.getText.toLowerCase(Locale.ROOT)
+ } else {
+ value.getText
+ }
+ }
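+
+ // Illustrative property list (not part of the original change):
+ //   TBLPROPERTIES ('comment' = 'demo', compacted = true, retention = 7)
+ // yields Map("comment" -> "demo", "compacted" -> "true", "retention" -> "7"); boolean values are
+ // lower-cased and all other values are kept as raw text.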
+
+ /**
+ * Type to keep track of a table header: (identifier, isTemporary, ifNotExists, isExternal).
+ */
+ type TableHeader = (Seq[String], Boolean, Boolean, Boolean)
+
+ /**
+ * Type to keep track of table clauses:
+ * - partition transforms
+ * - partition columns
+ * - bucketSpec
+ * - properties
+ * - options
+ * - location
+ * - comment
+ * - serde
+ *
+ * Note: Partition transforms are based on the existing table schema definition. They can be simple
+ * column names or functions like `year(date_col)`. Partition columns are column names with data
+ * types like `i INT`, which should be appended to the existing table schema.
+ */
+ type TableClauses = (
+ Seq[Transform], Seq[StructField], Option[BucketSpec], Map[String, String],
+ Map[String, String], Option[String], Option[String], Option[SerdeInfo])
+
+ /**
+ * Validate a create table statement and return the [[TableIdentifier]].
+ */
+ override def visitCreateTableHeader(
+ ctx: CreateTableHeaderContext): TableHeader = withOrigin(ctx) {
+ val temporary = ctx.TEMPORARY != null
+ val ifNotExists = ctx.EXISTS != null
+ if (temporary && ifNotExists) {
+ operationNotAllowed("CREATE TEMPORARY TABLE ... IF NOT EXISTS", ctx)
+ }
+ val multipartIdentifier = ctx.multipartIdentifier.parts.asScala.map(_.getText).toSeq
+ (multipartIdentifier, temporary, ifNotExists, ctx.EXTERNAL != null)
+ }
+
+ /**
+ * Validate a replace table statement and return the [[TableIdentifier]].
+ */
+ override def visitReplaceTableHeader(
+ ctx: ReplaceTableHeaderContext): TableHeader = withOrigin(ctx) {
+ val multipartIdentifier = ctx.multipartIdentifier.parts.asScala.map(_.getText).toSeq
+ (multipartIdentifier, false, false, false)
+ }
+
+ /**
+ * Parse a qualified name to a multipart name.
+ */
+ override def visitQualifiedName(ctx: QualifiedNameContext): Seq[String] = withOrigin(ctx) {
+ ctx.identifier.asScala.map(_.getText).toSeq
+ }
+
+ /**
+ * Parse a list of transforms or columns.
+ */
+ override def visitPartitionFieldList(
+ ctx: PartitionFieldListContext): (Seq[Transform], Seq[StructField]) = withOrigin(ctx) {
+ val (transforms, columns) = ctx.fields.asScala.map {
+ case transform: PartitionTransformContext =>
+ (Some(visitPartitionTransform(transform)), None)
+ case field: PartitionColumnContext =>
+ (None, Some(visitColType(field.colType)))
+ }.unzip
+
+ (transforms.flatten.toSeq, columns.flatten.toSeq)
+ }
+
+ override def visitPartitionTransform(
+ ctx: PartitionTransformContext): Transform = withOrigin(ctx) {
+ def getFieldReference(
+ ctx: ApplyTransformContext,
+ arg: V2Expression): FieldReference = {
+ lazy val name: String = ctx.identifier.getText
+ arg match {
+ case ref: FieldReference =>
+ ref
+ case nonRef =>
+ throw new ParseException(
+ s"Expected a column reference for transform $name: ${nonRef.describe}", ctx)
+ }
+ }
+
+ def getSingleFieldReference(
+ ctx: ApplyTransformContext,
+ arguments: Seq[V2Expression]): FieldReference = {
+ lazy val name: String = ctx.identifier.getText
+ if (arguments.size > 1) {
+ throw new ParseException(s"Too many arguments for transform $name", ctx)
+ } else if (arguments.isEmpty) {
+ throw new ParseException(s"Not enough arguments for transform $name", ctx)
+ } else {
+ getFieldReference(ctx, arguments.head)
+ }
+ }
+
+ ctx.transform match {
+ case identityCtx: IdentityTransformContext =>
+ IdentityTransform(FieldReference(typedVisit[Seq[String]](identityCtx.qualifiedName)))
+
+ case applyCtx: ApplyTransformContext =>
+ val arguments = applyCtx.argument.asScala.map(visitTransformArgument).toSeq
+
+ applyCtx.identifier.getText match {
+ case "bucket" =>
+ val numBuckets: Int = arguments.head match {
+ case LiteralValue(shortValue, ShortType) =>
+ shortValue.asInstanceOf[Short].toInt
+ case LiteralValue(intValue, IntegerType) =>
+ intValue.asInstanceOf[Int]
+ case LiteralValue(longValue, LongType) =>
+ longValue.asInstanceOf[Long].toInt
+ case lit =>
+ throw new ParseException(s"Invalid number of buckets: ${lit.describe}", applyCtx)
+ }
+
+ val fields = arguments.tail.map(arg => getFieldReference(applyCtx, arg))
+
+ BucketTransform(LiteralValue(numBuckets, IntegerType), fields.toSeq)
+
+ case "years" =>
+ YearsTransform(getSingleFieldReference(applyCtx, arguments))
+
+ case "months" =>
+ MonthsTransform(getSingleFieldReference(applyCtx, arguments))
+
+ case "days" =>
+ DaysTransform(getSingleFieldReference(applyCtx, arguments))
+
+ case "hours" =>
+ HoursTransform(getSingleFieldReference(applyCtx, arguments))
+
+ case name =>
+ ApplyTransform(name, arguments)
+ }
+ }
+ }
+
+ /**
+ * Parse an argument to a transform. An argument may be a field reference (qualified name) or
+ * a value literal.
+ */
+ override def visitTransformArgument(ctx: TransformArgumentContext): V2Expression = {
+ withOrigin(ctx) {
+ val reference = Option(ctx.qualifiedName)
+ .map(typedVisit[Seq[String]])
+ .map(FieldReference(_))
+ val literal = Option(ctx.constant)
+ .map(typedVisit[Literal])
+ .map(lit => LiteralValue(lit.value, lit.dataType))
+ reference.orElse(literal)
+ .getOrElse(throw new ParseException(s"Invalid transform argument", ctx))
+ }
+ }
+
+ def cleanTableProperties(
+ ctx: ParserRuleContext, properties: Map[String, String]): Map[String, String] = {
+ import TableCatalog._
+ val legacyOn = conf.getConf(SQLConf.LEGACY_PROPERTY_NON_RESERVED)
+ properties.filter {
+ case (PROP_PROVIDER, _) if !legacyOn =>
+ throw new ParseException(s"$PROP_PROVIDER is a reserved table property, please use" +
+ s" the USING clause to specify it.", ctx)
+ case (PROP_PROVIDER, _) => false
+ case (PROP_LOCATION, _) if !legacyOn =>
+ throw new ParseException(s"$PROP_LOCATION is a reserved table property, please use" +
+ s" the LOCATION clause to specify it.", ctx)
+ case (PROP_LOCATION, _) => false
+ case (PROP_OWNER, _) if !legacyOn =>
+ throw new ParseException(s"$PROP_OWNER is a reserved table property, it will be" +
+ s" set to the current user", ctx)
+ case (PROP_OWNER, _) => false
+ case _ => true
+ }
+ }
+
+ def cleanTableOptions(
+ ctx: ParserRuleContext,
+ options: Map[String, String],
+ location: Option[String]): (Map[String, String], Option[String]) = {
+ var path = location
+ val filtered = cleanTableProperties(ctx, options).filter {
+ case (k, v) if k.equalsIgnoreCase("path") && path.nonEmpty =>
+ throw new ParseException(s"Duplicated table paths found: '${path.get}' and '$v'. LOCATION" +
+ s" and the case insensitive key 'path' in OPTIONS are all used to indicate the custom" +
+ s" table path, you can only specify one of them.", ctx)
+ case (k, v) if k.equalsIgnoreCase("path") =>
+ path = Some(v)
+ false
+ case _ => true
+ }
+ (filtered, path)
+ }
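+
+ // For example (illustrative, not part of the original change): OPTIONS (path '/tmp/a') together
+ // with LOCATION '/tmp/b' fails with the "Duplicated table paths" error above, while
+ // OPTIONS (path '/tmp/a') alone is removed from the options map and returned as the table location.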
+
+ /**
+ * Create a [[SerdeInfo]] for creating tables.
+ *
+ * Format: STORED AS (name | INPUTFORMAT input_format OUTPUTFORMAT output_format)
+ */
+ override def visitCreateFileFormat(ctx: CreateFileFormatContext): SerdeInfo = withOrigin(ctx) {
+ (ctx.fileFormat, ctx.storageHandler) match {
+ // Expected format: INPUTFORMAT input_format OUTPUTFORMAT output_format
+ case (c: TableFileFormatContext, null) =>
+ SerdeInfo(formatClasses = Some(FormatClasses(string(c.inFmt), string(c.outFmt))))
+ // Expected format: SEQUENCEFILE | TEXTFILE | RCFILE | ORC | PARQUET | AVRO
+ case (c: GenericFileFormatContext, null) =>
+ SerdeInfo(storedAs = Some(c.identifier.getText))
+ case (null, storageHandler) =>
+ operationNotAllowed("STORED BY", ctx)
+ case _ =>
+ throw new ParseException("Expected either STORED AS or STORED BY, not both", ctx)
+ }
+ }
+
+ /**
+ * Create a [[SerdeInfo]] used for creating tables.
+ *
+ * Example format:
+ * {{{
+ * SERDE serde_name [WITH SERDEPROPERTIES (k1=v1, k2=v2, ...)]
+ * }}}
+ *
+ * OR
+ *
+ * {{{
+ * DELIMITED [FIELDS TERMINATED BY char [ESCAPED BY char]]
+ * [COLLECTION ITEMS TERMINATED BY char]
+ * [MAP KEYS TERMINATED BY char]
+ * [LINES TERMINATED BY char]
+ * [NULL DEFINED AS char]
+ * }}}
+ */
+ def visitRowFormat(ctx: RowFormatContext): SerdeInfo = withOrigin(ctx) {
+ ctx match {
+ case serde: RowFormatSerdeContext => visitRowFormatSerde(serde)
+ case delimited: RowFormatDelimitedContext => visitRowFormatDelimited(delimited)
+ }
+ }
+
+ /**
+ * Create SERDE row format name and properties pair.
+ */
+ override def visitRowFormatSerde(ctx: RowFormatSerdeContext): SerdeInfo = withOrigin(ctx) {
+ import ctx._
+ SerdeInfo(
+ serde = Some(string(name)),
+ serdeProperties = Option(tablePropertyList).map(visitPropertyKeyValues).getOrElse(Map.empty))
+ }
+
+ /**
+ * Create a delimited row format properties object.
+ */
+ override def visitRowFormatDelimited(
+ ctx: RowFormatDelimitedContext): SerdeInfo = withOrigin(ctx) {
+ // Collect the entries if any.
+ def entry(key: String, value: Token): Seq[(String, String)] = {
+ Option(value).toSeq.map(x => key -> string(x))
+ }
+ // TODO we need proper support for the NULL format.
+ val entries =
+ entry("field.delim", ctx.fieldsTerminatedBy) ++
+ entry("serialization.format", ctx.fieldsTerminatedBy) ++
+ entry("escape.delim", ctx.escapedBy) ++
+ // The following typo is inherited from Hive...
+ entry("colelction.delim", ctx.collectionItemsTerminatedBy) ++
+ entry("mapkey.delim", ctx.keysTerminatedBy) ++
+ Option(ctx.linesSeparatedBy).toSeq.map { token =>
+ val value = string(token)
+ validate(
+ value == "\n",
+ s"LINES TERMINATED BY only supports newline '\\n' right now: $value",
+ ctx)
+ "line.delim" -> value
+ }
+ SerdeInfo(serdeProperties = entries.toMap)
+ }
+
+ /**
+ * Throw a [[ParseException]] if the user specified incompatible SerDes through ROW FORMAT
+ * and STORED AS.
+ *
+ * The following are allowed. Anything else is not:
+ * ROW FORMAT SERDE ... STORED AS [SEQUENCEFILE | RCFILE | TEXTFILE]
+ * ROW FORMAT DELIMITED ... STORED AS TEXTFILE
+ * ROW FORMAT ... STORED AS INPUTFORMAT ... OUTPUTFORMAT ...
+ */
+ protected def validateRowFormatFileFormat(
+ rowFormatCtx: RowFormatContext,
+ createFileFormatCtx: CreateFileFormatContext,
+ parentCtx: ParserRuleContext): Unit = {
+ if (!(rowFormatCtx == null || createFileFormatCtx == null)) {
+ (rowFormatCtx, createFileFormatCtx.fileFormat) match {
+ case (_, ffTable: TableFileFormatContext) => // OK
+ case (rfSerde: RowFormatSerdeContext, ffGeneric: GenericFileFormatContext) =>
+ ffGeneric.identifier.getText.toLowerCase(Locale.ROOT) match {
+ case ("sequencefile" | "textfile" | "rcfile") => // OK
+ case fmt =>
+ operationNotAllowed(
+ s"ROW FORMAT SERDE is incompatible with format '$fmt', which also specifies a serde",
+ parentCtx)
+ }
+ case (rfDelimited: RowFormatDelimitedContext, ffGeneric: GenericFileFormatContext) =>
+ ffGeneric.identifier.getText.toLowerCase(Locale.ROOT) match {
+ case "textfile" => // OK
+ case fmt => operationNotAllowed(
+ s"ROW FORMAT DELIMITED is only compatible with 'textfile', not '$fmt'", parentCtx)
+ }
+ case _ =>
+ // should never happen
+ def str(ctx: ParserRuleContext): String = {
+ (0 until ctx.getChildCount).map { i => ctx.getChild(i).getText }.mkString(" ")
+ }
+
+ operationNotAllowed(
+ s"Unexpected combination of ${str(rowFormatCtx)} and ${str(createFileFormatCtx)}",
+ parentCtx)
+ }
+ }
+ }
+
+ protected def validateRowFormatFileFormat(
+ rowFormatCtx: Seq[RowFormatContext],
+ createFileFormatCtx: Seq[CreateFileFormatContext],
+ parentCtx: ParserRuleContext): Unit = {
+ if (rowFormatCtx.size == 1 && createFileFormatCtx.size == 1) {
+ validateRowFormatFileFormat(rowFormatCtx.head, createFileFormatCtx.head, parentCtx)
+ }
+ }
+
+ override def visitCreateTableClauses(ctx: CreateTableClausesContext): TableClauses = {
+ checkDuplicateClauses(ctx.TBLPROPERTIES, "TBLPROPERTIES", ctx)
+ checkDuplicateClauses(ctx.OPTIONS, "OPTIONS", ctx)
+ checkDuplicateClauses(ctx.PARTITIONED, "PARTITIONED BY", ctx)
+ checkDuplicateClauses(ctx.createFileFormat, "STORED AS/BY", ctx)
+ checkDuplicateClauses(ctx.rowFormat, "ROW FORMAT", ctx)
+ checkDuplicateClauses(ctx.commentSpec(), "COMMENT", ctx)
+ checkDuplicateClauses(ctx.bucketSpec(), "CLUSTERED BY", ctx)
+ checkDuplicateClauses(ctx.locationSpec, "LOCATION", ctx)
+
+ if (ctx.skewSpec.size > 0) {
+ operationNotAllowed("CREATE TABLE ... SKEWED BY", ctx)
+ }
+
+ val (partTransforms, partCols) =
+ Option(ctx.partitioning).map(visitPartitionFieldList).getOrElse((Nil, Nil))
+ val bucketSpec = ctx.bucketSpec().asScala.headOption.map(visitBucketSpec)
+ val properties = Option(ctx.tableProps).map(visitPropertyKeyValues).getOrElse(Map.empty)
+ val cleanedProperties = cleanTableProperties(ctx, properties)
+ val options = Option(ctx.options).map(visitPropertyKeyValues).getOrElse(Map.empty)
+ val location = visitLocationSpecList(ctx.locationSpec())
+ val (cleanedOptions, newLocation) = cleanTableOptions(ctx, options, location)
+ val comment = visitCommentSpecList(ctx.commentSpec())
+ val serdeInfo =
+ getSerdeInfo(ctx.rowFormat.asScala.toSeq, ctx.createFileFormat.asScala.toSeq, ctx)
+ (partTransforms, partCols, bucketSpec, cleanedProperties, cleanedOptions, newLocation, comment,
+ serdeInfo)
+ }
+
+ protected def getSerdeInfo(
+ rowFormatCtx: Seq[RowFormatContext],
+ createFileFormatCtx: Seq[CreateFileFormatContext],
+ ctx: ParserRuleContext,
+ skipCheck: Boolean = false): Option[SerdeInfo] = {
+ if (!skipCheck) validateRowFormatFileFormat(rowFormatCtx, createFileFormatCtx, ctx)
+ val rowFormatSerdeInfo = rowFormatCtx.map(visitRowFormat)
+ val fileFormatSerdeInfo = createFileFormatCtx.map(visitCreateFileFormat)
+ (fileFormatSerdeInfo ++ rowFormatSerdeInfo).reduceLeftOption((l, r) => l.merge(r))
+ }
+
+ private def partitionExpressions(
+ partTransforms: Seq[Transform],
+ partCols: Seq[StructField],
+ ctx: ParserRuleContext): Seq[Transform] = {
+ if (partTransforms.nonEmpty) {
+ if (partCols.nonEmpty) {
+ val references = partTransforms.map(_.describe()).mkString(", ")
+ val columns = partCols
+ .map(field => s"${field.name} ${field.dataType.simpleString}")
+ .mkString(", ")
+ operationNotAllowed(
+ s"""PARTITION BY: Cannot mix partition expressions and partition columns:
+ |Expressions: $references
+ |Columns: $columns""".stripMargin, ctx)
+
+ }
+ partTransforms
+ } else {
+ // Columns were added to create the schema; convert them to column references.
+ partCols.map { column =>
+ IdentityTransform(FieldReference(Seq(column.name)))
+ }
+ }
+ }
+
+ /**
+ * Create a table, returning a [[CreateTableStatement]] logical plan.
+ *
+ * Expected format:
+ * {{{
+ * CREATE [TEMPORARY] TABLE [IF NOT EXISTS] [db_name.]table_name
+ * [USING table_provider]
+ * create_table_clauses
+ * [[AS] select_statement];
+ *
+ * create_table_clauses (order insensitive):
+ * [PARTITIONED BY (partition_fields)]
+ * [OPTIONS table_property_list]
+ * [ROW FORMAT row_format]
+ * [STORED AS file_format]
+ * [CLUSTERED BY (col_name, col_name, ...)
+ * [SORTED BY (col_name [ASC|DESC], ...)]
+ * INTO num_buckets BUCKETS
+ * ]
+ * [LOCATION path]
+ * [COMMENT table_comment]
+ * [TBLPROPERTIES (property_name=property_value, ...)]
+ *
+ * partition_fields:
+ * col_name, transform(col_name), transform(constant, col_name), ... |
+ * col_name data_type [NOT NULL] [COMMENT col_comment], ...
+ * }}}
+ */
+ override def visitCreateTable(ctx: CreateTableContext): LogicalPlan = withOrigin(ctx) {
+ val (table, temp, ifNotExists, external) = visitCreateTableHeader(ctx.createTableHeader)
+
+ val columns = Option(ctx.colTypeList()).map(visitColTypeList).getOrElse(Nil)
+ val provider = Option(ctx.tableProvider).map(_.multipartIdentifier.getText)
+ val (partTransforms, partCols, bucketSpec, properties, options, location, comment, serdeInfo) =
+ visitCreateTableClauses(ctx.createTableClauses())
+
+ if (provider.isDefined && serdeInfo.isDefined) {
+ operationNotAllowed(s"CREATE TABLE ... USING ... ${serdeInfo.get.describe}", ctx)
+ }
+
+ if (temp) {
+ val asSelect = if (ctx.query == null) "" else " AS ..."
+ operationNotAllowed(
+ s"CREATE TEMPORARY TABLE ...$asSelect, use CREATE TEMPORARY VIEW instead", ctx)
+ }
+
+ val partitioning = partitionExpressions(partTransforms, partCols, ctx)
+
+ Option(ctx.query).map(plan) match {
+ case Some(_) if columns.nonEmpty =>
+ operationNotAllowed(
+ "Schema may not be specified in a Create Table As Select (CTAS) statement",
+ ctx)
+
+ case Some(_) if partCols.nonEmpty =>
+ // non-reference partition columns are not allowed because schema can't be specified
+ operationNotAllowed(
+ "Partition column types may not be specified in Create Table As Select (CTAS)",
+ ctx)
+
+ case Some(query) =>
+ CreateTableAsSelectStatement(
+ table, query, partitioning, bucketSpec, properties, provider, options, location, comment,
+ writeOptions = Map.empty, serdeInfo, external = external, ifNotExists = ifNotExists)
+
+ case _ =>
+ // Note: table schema includes both the table columns list and the partition columns
+ // with data type.
+ val schema = StructType(columns ++ partCols)
+ CreateTableStatement(table, schema, partitioning, bucketSpec, properties, provider,
+ options, location, comment, serdeInfo, external = external, ifNotExists = ifNotExists)
+ }
+ }
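+
+ // A minimal statement exercising this path (illustrative, assuming the Hudi datasource is on the
+ // classpath; not part of the original change):
+ //   CREATE TABLE IF NOT EXISTS h0 (id INT, name STRING, ts BIGINT)
+ //   USING hudi
+ //   PARTITIONED BY (dt STRING)
+ //   LOCATION '/tmp/h0'
+ // The typed partition column dt is appended to the table schema and converted into an
+ // IdentityTransform by partitionExpressions above.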
+
+ /**
+ * Replace a table, returning a [[ReplaceTableStatement]] logical plan.
+ *
+ * Expected format:
+ * {{{
+ * [CREATE OR] REPLACE TABLE [db_name.]table_name
+ * [USING table_provider]
+ * replace_table_clauses
+ * [[AS] select_statement];
+ *
+ * replace_table_clauses (order insensitive):
+ * [OPTIONS table_property_list]
+ * [PARTITIONED BY (partition_fields)]
+ * [CLUSTERED BY (col_name, col_name, ...)
+ * [SORTED BY (col_name [ASC|DESC], ...)]
+ * INTO num_buckets BUCKETS
+ * ]
+ * [LOCATION path]
+ * [COMMENT table_comment]
+ * [TBLPROPERTIES (property_name=property_value, ...)]
+ *
+ * partition_fields:
+ * col_name, transform(col_name), transform(constant, col_name), ... |
+ * col_name data_type [NOT NULL] [COMMENT col_comment], ...
+ * }}}
+ */
+ override def visitReplaceTable(ctx: ReplaceTableContext): LogicalPlan = withOrigin(ctx) {
+ val (table, temp, ifNotExists, external) = visitReplaceTableHeader(ctx.replaceTableHeader)
+ val orCreate = ctx.replaceTableHeader().CREATE() != null
+
+ if (temp) {
+ val action = if (orCreate) "CREATE OR REPLACE" else "REPLACE"
+ operationNotAllowed(s"$action TEMPORARY TABLE ..., use $action TEMPORARY VIEW instead.", ctx)
+ }
+
+ if (external) {
+ operationNotAllowed("REPLACE EXTERNAL TABLE ...", ctx)
+ }
+
+ if (ifNotExists) {
+ operationNotAllowed("REPLACE ... IF NOT EXISTS, use CREATE IF NOT EXISTS instead", ctx)
+ }
+
+ val (partTransforms, partCols, bucketSpec, properties, options, location, comment, serdeInfo) =
+ visitCreateTableClauses(ctx.createTableClauses())
+ val columns = Option(ctx.colTypeList()).map(visitColTypeList).getOrElse(Nil)
+ val provider = Option(ctx.tableProvider).map(_.multipartIdentifier.getText)
+
+ if (provider.isDefined && serdeInfo.isDefined) {
+ operationNotAllowed(s"CREATE TABLE ... USING ... ${serdeInfo.get.describe}", ctx)
+ }
+
+ val partitioning = partitionExpressions(partTransforms, partCols, ctx)
+
+ Option(ctx.query).map(plan) match {
+ case Some(_) if columns.nonEmpty =>
+ operationNotAllowed(
+ "Schema may not be specified in a Replace Table As Select (RTAS) statement",
+ ctx)
+
+ case Some(_) if partCols.nonEmpty =>
+ // non-reference partition columns are not allowed because schema can't be specified
+ operationNotAllowed(
+ "Partition column types may not be specified in Replace Table As Select (RTAS)",
+ ctx)
+
+ case Some(query) =>
+ ReplaceTableAsSelectStatement(table, query, partitioning, bucketSpec, properties,
+ provider, options, location, comment, writeOptions = Map.empty, serdeInfo,
+ orCreate = orCreate)
+
+ case _ =>
+ // Note: table schema includes both the table columns list and the partition columns
+ // with data type.
+ val schema = StructType(columns ++ partCols)
+ ReplaceTableStatement(table, schema, partitioning, bucketSpec, properties, provider,
+ options, location, comment, serdeInfo, orCreate = orCreate)
+ }
+ }
+
+ /**
+ * Parse new column info from ADD COLUMN into a QualifiedColType.
+ */
+ override def visitQualifiedColTypeWithPosition(
+ ctx: QualifiedColTypeWithPositionContext): QualifiedColType = withOrigin(ctx) {
+ QualifiedColType(
+ name = typedVisit[Seq[String]](ctx.name),
+ dataType = typedVisit[DataType](ctx.dataType),
+ nullable = ctx.NULL == null,
+ comment = Option(ctx.commentSpec()).map(visitCommentSpec),
+ position = Option(ctx.colPosition).map(typedVisit[ColumnPosition]))
+ }
+
+ /**
+ * Create or replace a view. This creates a [[CreateViewStatement]]
+ *
+ * For example:
+ * {{{
+ * CREATE [OR REPLACE] [[GLOBAL] TEMPORARY] VIEW [IF NOT EXISTS] multi_part_name
+ * [(column_name [COMMENT column_comment], ...) ]
+ * create_view_clauses
+ *
+ * AS SELECT ...;
+ *
+ * create_view_clauses (order insensitive):
+ * [COMMENT view_comment]
+ * [TBLPROPERTIES (property_name = property_value, ...)]
+ * }}}
+ */
+ override def visitCreateView(ctx: CreateViewContext): LogicalPlan = withOrigin(ctx) {
+ if (!ctx.identifierList.isEmpty) {
+ operationNotAllowed("CREATE VIEW ... PARTITIONED ON", ctx)
+ }
+
+ checkDuplicateClauses(ctx.commentSpec(), "COMMENT", ctx)
+ checkDuplicateClauses(ctx.PARTITIONED, "PARTITIONED ON", ctx)
+ checkDuplicateClauses(ctx.TBLPROPERTIES, "TBLPROPERTIES", ctx)
+
+ val userSpecifiedColumns = Option(ctx.identifierCommentList).toSeq.flatMap { icl =>
+ icl.identifierComment.asScala.map { ic =>
+ ic.identifier.getText -> Option(ic.commentSpec()).map(visitCommentSpec)
+ }
+ }
+
+ val properties = ctx.tablePropertyList.asScala.headOption.map(visitPropertyKeyValues)
+ .getOrElse(Map.empty)
+ if (ctx.TEMPORARY != null && properties.nonEmpty) {
+ operationNotAllowed("TBLPROPERTIES can't coexist with CREATE TEMPORARY VIEW", ctx)
+ }
+
+ val viewType = if (ctx.TEMPORARY == null) {
+ PersistedView
+ } else if (ctx.GLOBAL != null) {
+ GlobalTempView
+ } else {
+ LocalTempView
+ }
+ CreateViewStatement(
+ visitMultipartIdentifier(ctx.multipartIdentifier),
+ userSpecifiedColumns,
+ visitCommentSpecList(ctx.commentSpec()),
+ properties,
+ Option(source(ctx.query)),
+ plan(ctx.query),
+ ctx.EXISTS != null,
+ ctx.REPLACE != null,
+ viewType)
+ }
+
+ /**
+ * Alter the query of a view. This creates a [[AlterViewAsStatement]]
+ *
+ * For example:
+ * {{{
+ * ALTER VIEW multi_part_name AS SELECT ...;
+ * }}}
+ */
+ override def visitAlterViewQuery(ctx: AlterViewQueryContext): LogicalPlan = withOrigin(ctx) {
+ AlterViewAsStatement(
+ visitMultipartIdentifier(ctx.multipartIdentifier),
+ originalText = source(ctx.query),
+ query = plan(ctx.query))
+ }
+}
+
+/**
+ * A container for holding named common table expressions (CTEs) and a query plan.
+ * This operator will be removed during analysis and the relations will be substituted into child.
+ *
+ * @param child The final query of this CTE.
+ * @param cteRelations A sequence of (alias, CTE definition) pairs that this CTE defines.
+ *                     Each CTE can see the base tables and the previously defined CTEs only.
+ */
+case class With(child: LogicalPlan, cteRelations: Seq[(String, SubqueryAlias)]) extends UnaryNode {
+ override def output: Seq[Attribute] = child.output
+
+ override def simpleString(maxFields: Int): String = {
+ val cteAliases = truncatedString(cteRelations.map(_._1), "[", ", ", "]", maxFields)
+ s"CTE $cteAliases"
+ }
+
+ override def innerChildren: Seq[LogicalPlan] = cteRelations.map(_._2)
+
+ def withNewChildInternal(newChild: LogicalPlan): LogicalPlan = this
+}
+
diff --git a/hudi-spark-datasource/hudi-spark3.1.x/src/main/scala/org/apache/spark/sql/parser/HoodieSpark3_1ExtendedSqlParser.scala b/hudi-spark-datasource/hudi-spark3.1.x/src/main/scala/org/apache/spark/sql/parser/HoodieSpark3_1ExtendedSqlParser.scala
new file mode 100644
index 0000000000000..0d07ad12f1025
--- /dev/null
+++ b/hudi-spark-datasource/hudi-spark3.1.x/src/main/scala/org/apache/spark/sql/parser/HoodieSpark3_1ExtendedSqlParser.scala
@@ -0,0 +1,176 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.parser
+
+import org.antlr.v4.runtime._
+import org.antlr.v4.runtime.atn.PredictionMode
+import org.antlr.v4.runtime.misc.{Interval, ParseCancellationException}
+import org.antlr.v4.runtime.tree.TerminalNodeImpl
+import org.apache.hudi.spark.sql.parser.HoodieSqlBaseParser.{NonReservedContext, QuotedIdentifierContext}
+import org.apache.hudi.spark.sql.parser.{HoodieSqlBaseBaseListener, HoodieSqlBaseLexer, HoodieSqlBaseParser}
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.parser.{ParseErrorListener, ParseException, ParserInterface}
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.trees.Origin
+import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
+import org.apache.spark.sql.types._
+import org.apache.spark.sql.{AnalysisException, SparkSession}
+
+class HoodieSpark3_1ExtendedSqlParser(session: SparkSession, delegate: ParserInterface)
+ extends ParserInterface with Logging {
+
+ private lazy val conf = session.sqlContext.conf
+ private lazy val builder = new HoodieSpark3_1ExtendedSqlAstBuilder(conf, delegate)
+
+ override def parsePlan(sqlText: String): LogicalPlan = parse(sqlText) { parser =>
+ builder.visit(parser.singleStatement()) match {
+ case plan: LogicalPlan => plan
+ case _=> delegate.parsePlan(sqlText)
+ }
+ }
+
+ override def parseExpression(sqlText: String): Expression = delegate.parseExpression(sqlText)
+
+ override def parseTableIdentifier(sqlText: String): TableIdentifier =
+ delegate.parseTableIdentifier(sqlText)
+
+ override def parseFunctionIdentifier(sqlText: String): FunctionIdentifier =
+ delegate.parseFunctionIdentifier(sqlText)
+
+ override def parseTableSchema(sqlText: String): StructType = delegate.parseTableSchema(sqlText)
+
+ override def parseDataType(sqlText: String): DataType = delegate.parseDataType(sqlText)
+
+ protected def parse[T](command: String)(toResult: HoodieSqlBaseParser => T): T = {
+ logDebug(s"Parsing command: $command")
+
+ val lexer = new HoodieSqlBaseLexer(new UpperCaseCharStream(CharStreams.fromString(command)))
+ lexer.removeErrorListeners()
+ lexer.addErrorListener(ParseErrorListener)
+
+ val tokenStream = new CommonTokenStream(lexer)
+ val parser = new HoodieSqlBaseParser(tokenStream)
+ parser.addParseListener(PostProcessor)
+ parser.removeErrorListeners()
+ parser.addErrorListener(ParseErrorListener)
+// parser.legacy_setops_precedence_enabled = conf.setOpsPrecedenceEnforced
+ parser.legacy_exponent_literal_as_decimal_enabled = conf.exponentLiteralAsDecimalEnabled
+ parser.SQL_standard_keyword_behavior = conf.ansiEnabled
+
+ try {
+ try {
+ // first, try parsing with potentially faster SLL mode
+ parser.getInterpreter.setPredictionMode(PredictionMode.SLL)
+ toResult(parser)
+ }
+ catch {
+ case e: ParseCancellationException =>
+ // if we fail, parse with LL mode
+ tokenStream.seek(0) // rewind input stream
+ parser.reset()
+
+ // Try Again.
+ parser.getInterpreter.setPredictionMode(PredictionMode.LL)
+ toResult(parser)
+ }
+ }
+ catch {
+ case e: ParseException if e.command.isDefined =>
+ throw e
+ case e: ParseException =>
+ throw e.withCommand(command)
+ case e: AnalysisException =>
+ val position = Origin(e.line, e.startPosition)
+ throw new ParseException(Option(command), e.message, position, position)
+ }
+ }
+
+ override def parseMultipartIdentifier(sqlText: String): Seq[String] = {
+ delegate.parseMultipartIdentifier(sqlText)
+ }
+}
+
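+// A minimal sketch (illustrative, not part of this change) of how a parser such as this one is
+// typically chained in front of Spark's own parser via SparkSessionExtensions, assuming a
+// hypothetical extensions class registered through spark.sql.extensions and an import of
+// org.apache.spark.sql.SparkSessionExtensions:
+//
+//   class ExampleHoodieExtensions extends (SparkSessionExtensions => Unit) {
+//     override def apply(extensions: SparkSessionExtensions): Unit = {
+//       extensions.injectParser { (session, delegate) =>
+//         new HoodieSpark3_1ExtendedSqlParser(session, delegate)
+//       }
+//     }
+//   }
+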
+/**
+ * Forked from `org.apache.spark.sql.catalyst.parser.UpperCaseCharStream`.
+ */
+class UpperCaseCharStream(wrapped: CodePointCharStream) extends CharStream {
+ override def consume(): Unit = wrapped.consume
+ override def getSourceName(): String = wrapped.getSourceName
+ override def index(): Int = wrapped.index
+ override def mark(): Int = wrapped.mark
+ override def release(marker: Int): Unit = wrapped.release(marker)
+ override def seek(where: Int): Unit = wrapped.seek(where)
+ override def size(): Int = wrapped.size
+
+ override def getText(interval: Interval): String = {
+ // ANTLR 4.7's CodePointCharStream implementations have bugs when
+ // getText() is called with an empty stream, or intervals where
+ // the start > end. See
+ // https://github.com/antlr/antlr4/commit/ac9f7530 for one fix
+ // that is not yet in a released ANTLR artifact.
+ if (size() > 0 && (interval.b - interval.a >= 0)) {
+ wrapped.getText(interval)
+ } else {
+ ""
+ }
+ }
+ // scalastyle:off
+ override def LA(i: Int): Int = {
+ // scalastyle:on
+ val la = wrapped.LA(i)
+ if (la == 0 || la == IntStream.EOF) la
+ else Character.toUpperCase(la)
+ }
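+
+ // Illustrative effect (not part of the original change): for the input "select 1", LA() feeds the
+ // lexer upper-cased code points so keywords match case-insensitively, while getText() above still
+ // returns the original, case-preserved source for error messages.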
+}
+
+/**
+ * Forked from `org.apache.spark.sql.catalyst.parser.PostProcessor`.
+ */
+case object PostProcessor extends HoodieSqlBaseBaseListener {
+
+ /** Remove the back ticks from an Identifier. */
+ override def exitQuotedIdentifier(ctx: QuotedIdentifierContext): Unit = {
+ replaceTokenByIdentifier(ctx, 1) { token =>
+ // Remove the double back ticks in the string.
+ token.setText(token.getText.replace("``", "`"))
+ token
+ }
+ }
+
+ /** Treat non-reserved keywords as Identifiers. */
+ override def exitNonReserved(ctx: NonReservedContext): Unit = {
+ replaceTokenByIdentifier(ctx, 0)(identity)
+ }
+
+ private def replaceTokenByIdentifier(
+ ctx: ParserRuleContext,
+ stripMargins: Int)(
+ f: CommonToken => CommonToken = identity): Unit = {
+ val parent = ctx.getParent
+ parent.removeLastChild()
+ val token = ctx.getChild(0).getPayload.asInstanceOf[Token]
+ val newToken = new CommonToken(
+ new org.antlr.v4.runtime.misc.Pair(token.getTokenSource, token.getInputStream),
+ HoodieSqlBaseParser.IDENTIFIER,
+ token.getChannel,
+ token.getStartIndex + stripMargins,
+ token.getStopIndex - stripMargins)
+ parent.addChild(new TerminalNodeImpl(f(newToken)))
+ }
+}
diff --git a/hudi-spark-datasource/hudi-spark3/pom.xml b/hudi-spark-datasource/hudi-spark3/pom.xml
index d8dba8384886c..1900f82474999 100644
--- a/hudi-spark-datasource/hudi-spark3/pom.xml
+++ b/hudi-spark-datasource/hudi-spark3/pom.xml
@@ -144,6 +144,24 @@
          <groupId>org.jacoco</groupId>
          <artifactId>jacoco-maven-plugin</artifactId>
        </plugin>
+        <plugin>
+          <groupId>org.antlr</groupId>
+          <artifactId>antlr4-maven-plugin</artifactId>
+          <version>${antlr.version}</version>
+          <executions>
+            <execution>
+              <goals>
+                <goal>antlr4</goal>
+              </goals>
+            </execution>
+          </executions>
+          <configuration>
+            <listener>true</listener>
+            <visitor>true</visitor>
+            <sourceDirectory>../hudi-spark3/src/main/antlr4</sourceDirectory>
+            <libDirectory>../hudi-spark3/src/main/antlr4/imports</libDirectory>
+          </configuration>
+        </plugin>
@@ -157,7 +175,7 @@
      <groupId>org.apache.spark</groupId>
      <artifactId>spark-sql_2.12</artifactId>
-      <version>${spark3.version}</version>
+      <version>${spark3.2.version}</version>
true
diff --git a/hudi-spark-datasource/hudi-spark3/src/main/antlr4/imports/SqlBase.g4 b/hudi-spark-datasource/hudi-spark3/src/main/antlr4/imports/SqlBase.g4
new file mode 100644
index 0000000000000..d4e1e48351ccc
--- /dev/null
+++ b/hudi-spark-datasource/hudi-spark3/src/main/antlr4/imports/SqlBase.g4
@@ -0,0 +1,1908 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ * This file is an adaptation of Presto's presto-parser/src/main/antlr4/com/facebook/presto/sql/parser/SqlBase.g4 grammar.
+ */
+
+// This parser file is forked from Spark 3.2.0's SqlBase.g4.
+grammar SqlBase;
+
+@parser::members {
+ /**
+ * When false, INTERSECT is given greater precedence over the other set
+ * operations (UNION, EXCEPT and MINUS) as per the SQL standard.
+ */
+ public boolean legacy_setops_precedence_enabled = false;
+
+ /**
+ * When false, a literal with an exponent would be converted into
+ * double type rather than decimal type.
+ */
+ public boolean legacy_exponent_literal_as_decimal_enabled = false;
+
+ /**
+ * When true, the behavior of keywords follows ANSI SQL standard.
+ */
+ public boolean SQL_standard_keyword_behavior = false;
+}
+
+@lexer::members {
+ /**
+ * Verify whether current token is a valid decimal token (which contains dot).
+ * Returns true if the character that follows the token is not a digit or letter or underscore.
+ *
+ * For example:
+ * For char stream "2.3", "2." is not a valid decimal token, because it is followed by digit '3'.
+ * For char stream "2.3_", "2.3" is not a valid decimal token, because it is followed by '_'.
+ * For char stream "2.3W", "2.3" is not a valid decimal token, because it is followed by 'W'.
+ * For char stream "12.0D 34.E2+0.12 " 12.0D is a valid decimal token because it is followed
+ * by a space. 34.E2 is a valid decimal token because it is followed by symbol '+'
+ * which is not a digit or letter or underscore.
+ */
+ public boolean isValidDecimal() {
+ int nextChar = _input.LA(1);
+ if (nextChar >= 'A' && nextChar <= 'Z' || nextChar >= '0' && nextChar <= '9' ||
+ nextChar == '_') {
+ return false;
+ } else {
+ return true;
+ }
+ }
+
+ /**
+ * This method will be called when we see '/*' and try to match it as a bracketed comment.
+ * If the next character is '+', it should be parsed as hint later, and we cannot match
+ * it as a bracketed comment.
+ *
+ * Returns true if the next character is '+'.
+ */
+ public boolean isHint() {
+ int nextChar = _input.LA(1);
+ if (nextChar == '+') {
+ return true;
+ } else {
+ return false;
+ }
+ }
+}
+
+singleStatement
+ : statement ';'* EOF
+ ;
+
+singleExpression
+ : namedExpression EOF
+ ;
+
+singleTableIdentifier
+ : tableIdentifier EOF
+ ;
+
+singleMultipartIdentifier
+ : multipartIdentifier EOF
+ ;
+
+singleFunctionIdentifier
+ : functionIdentifier EOF
+ ;
+
+singleDataType
+ : dataType EOF
+ ;
+
+singleTableSchema
+ : colTypeList EOF
+ ;
+
+statement
+ : query #statementDefault
+ | ctes? dmlStatementNoWith #dmlStatement
+ | USE NAMESPACE? multipartIdentifier #use
+ | CREATE namespace (IF NOT EXISTS)? multipartIdentifier
+ (commentSpec |
+ locationSpec |
+ (WITH (DBPROPERTIES | PROPERTIES) tablePropertyList))* #createNamespace
+ | ALTER namespace multipartIdentifier
+ SET (DBPROPERTIES | PROPERTIES) tablePropertyList #setNamespaceProperties
+ | ALTER namespace multipartIdentifier
+ SET locationSpec #setNamespaceLocation
+ | DROP namespace (IF EXISTS)? multipartIdentifier
+ (RESTRICT | CASCADE)? #dropNamespace
+ | SHOW (DATABASES | NAMESPACES) ((FROM | IN) multipartIdentifier)?
+ (LIKE? pattern=STRING)? #showNamespaces
+ | createTableHeader ('(' colTypeList ')')? tableProvider?
+ createTableClauses
+ (AS? query)? #createTable
+ | CREATE TABLE (IF NOT EXISTS)? target=tableIdentifier
+ LIKE source=tableIdentifier
+ (tableProvider |
+ rowFormat |
+ createFileFormat |
+ locationSpec |
+ (TBLPROPERTIES tableProps=tablePropertyList))* #createTableLike
+ | replaceTableHeader ('(' colTypeList ')')? tableProvider?
+ createTableClauses
+ (AS? query)? #replaceTable
+ | ANALYZE TABLE multipartIdentifier partitionSpec? COMPUTE STATISTICS
+ (identifier | FOR COLUMNS identifierSeq | FOR ALL COLUMNS)? #analyze
+ | ANALYZE TABLES ((FROM | IN) multipartIdentifier)? COMPUTE STATISTICS
+ (identifier)? #analyzeTables
+ | ALTER TABLE multipartIdentifier
+ ADD (COLUMN | COLUMNS)
+ columns=qualifiedColTypeWithPositionList #addTableColumns
+ | ALTER TABLE multipartIdentifier
+ ADD (COLUMN | COLUMNS)
+ '(' columns=qualifiedColTypeWithPositionList ')' #addTableColumns
+ | ALTER TABLE table=multipartIdentifier
+ RENAME COLUMN
+ from=multipartIdentifier TO to=errorCapturingIdentifier #renameTableColumn
+ | ALTER TABLE multipartIdentifier
+ DROP (COLUMN | COLUMNS)
+ '(' columns=multipartIdentifierList ')' #dropTableColumns
+ | ALTER TABLE multipartIdentifier
+ DROP (COLUMN | COLUMNS) columns=multipartIdentifierList #dropTableColumns
+ | ALTER (TABLE | VIEW) from=multipartIdentifier
+ RENAME TO to=multipartIdentifier #renameTable
+ | ALTER (TABLE | VIEW) multipartIdentifier
+ SET TBLPROPERTIES tablePropertyList #setTableProperties
+ | ALTER (TABLE | VIEW) multipartIdentifier
+ UNSET TBLPROPERTIES (IF EXISTS)? tablePropertyList #unsetTableProperties
+ | ALTER TABLE table=multipartIdentifier
+ (ALTER | CHANGE) COLUMN? column=multipartIdentifier
+ alterColumnAction? #alterTableAlterColumn
+ | ALTER TABLE table=multipartIdentifier partitionSpec?
+ CHANGE COLUMN?
+ colName=multipartIdentifier colType colPosition? #hiveChangeColumn
+ | ALTER TABLE table=multipartIdentifier partitionSpec?
+ REPLACE COLUMNS
+ '(' columns=qualifiedColTypeWithPositionList ')' #hiveReplaceColumns
+ | ALTER TABLE multipartIdentifier (partitionSpec)?
+ SET SERDE STRING (WITH SERDEPROPERTIES tablePropertyList)? #setTableSerDe
+ | ALTER TABLE multipartIdentifier (partitionSpec)?
+ SET SERDEPROPERTIES tablePropertyList #setTableSerDe
+ | ALTER (TABLE | VIEW) multipartIdentifier ADD (IF NOT EXISTS)?
+ partitionSpecLocation+ #addTablePartition
+ | ALTER TABLE multipartIdentifier
+ from=partitionSpec RENAME TO to=partitionSpec #renameTablePartition
+ | ALTER (TABLE | VIEW) multipartIdentifier
+ DROP (IF EXISTS)? partitionSpec (',' partitionSpec)* PURGE? #dropTablePartitions
+ | ALTER TABLE multipartIdentifier
+ (partitionSpec)? SET locationSpec #setTableLocation
+ | ALTER TABLE multipartIdentifier RECOVER PARTITIONS #recoverPartitions
+ | DROP TABLE (IF EXISTS)? multipartIdentifier PURGE? #dropTable
+ | DROP VIEW (IF EXISTS)? multipartIdentifier #dropView
+ | CREATE (OR REPLACE)? (GLOBAL? TEMPORARY)?
+ VIEW (IF NOT EXISTS)? multipartIdentifier
+ identifierCommentList?
+ (commentSpec |
+ (PARTITIONED ON identifierList) |
+ (TBLPROPERTIES tablePropertyList))*
+ AS query #createView
+ | CREATE (OR REPLACE)? GLOBAL? TEMPORARY VIEW
+ tableIdentifier ('(' colTypeList ')')? tableProvider
+ (OPTIONS tablePropertyList)? #createTempViewUsing
+ | ALTER VIEW multipartIdentifier AS? query #alterViewQuery
+ | CREATE (OR REPLACE)? TEMPORARY? FUNCTION (IF NOT EXISTS)?
+ multipartIdentifier AS className=STRING
+ (USING resource (',' resource)*)? #createFunction
+ | DROP TEMPORARY? FUNCTION (IF EXISTS)? multipartIdentifier #dropFunction
+ | EXPLAIN (LOGICAL | FORMATTED | EXTENDED | CODEGEN | COST)?
+ statement #explain
+ | SHOW TABLES ((FROM | IN) multipartIdentifier)?
+ (LIKE? pattern=STRING)? #showTables
+ | SHOW TABLE EXTENDED ((FROM | IN) ns=multipartIdentifier)?
+ LIKE pattern=STRING partitionSpec? #showTableExtended
+ | SHOW TBLPROPERTIES table=multipartIdentifier
+ ('(' key=tablePropertyKey ')')? #showTblProperties
+ | SHOW COLUMNS (FROM | IN) table=multipartIdentifier
+ ((FROM | IN) ns=multipartIdentifier)? #showColumns
+ | SHOW VIEWS ((FROM | IN) multipartIdentifier)?
+ (LIKE? pattern=STRING)? #showViews
+ | SHOW PARTITIONS multipartIdentifier partitionSpec? #showPartitions
+ | SHOW identifier? FUNCTIONS
+ (LIKE? (multipartIdentifier | pattern=STRING))? #showFunctions
+ | SHOW CREATE TABLE multipartIdentifier (AS SERDE)? #showCreateTable
+ | SHOW CURRENT NAMESPACE #showCurrentNamespace
+ | (DESC | DESCRIBE) FUNCTION EXTENDED? describeFuncName #describeFunction
+ | (DESC | DESCRIBE) namespace EXTENDED?
+ multipartIdentifier #describeNamespace
+ | (DESC | DESCRIBE) TABLE? option=(EXTENDED | FORMATTED)?
+ multipartIdentifier partitionSpec? describeColName? #describeRelation
+ | (DESC | DESCRIBE) QUERY? query #describeQuery
+ | COMMENT ON namespace multipartIdentifier IS
+ comment=(STRING | NULL) #commentNamespace
+ | COMMENT ON TABLE multipartIdentifier IS comment=(STRING | NULL) #commentTable
+ | REFRESH TABLE multipartIdentifier #refreshTable
+ | REFRESH FUNCTION multipartIdentifier #refreshFunction
+ | REFRESH (STRING | .*?) #refreshResource
+ | CACHE LAZY? TABLE multipartIdentifier
+ (OPTIONS options=tablePropertyList)? (AS? query)? #cacheTable
+ | UNCACHE TABLE (IF EXISTS)? multipartIdentifier #uncacheTable
+ | CLEAR CACHE #clearCache
+ | LOAD DATA LOCAL? INPATH path=STRING OVERWRITE? INTO TABLE
+ multipartIdentifier partitionSpec? #loadData
+ | TRUNCATE TABLE multipartIdentifier partitionSpec? #truncateTable
+ | MSCK REPAIR TABLE multipartIdentifier
+ (option=(ADD|DROP|SYNC) PARTITIONS)? #repairTable
+ | op=(ADD | LIST) identifier .*? #manageResource
+ | SET ROLE .*? #failNativeCommand
+ | SET TIME ZONE interval #setTimeZone
+ | SET TIME ZONE timezone=(STRING | LOCAL) #setTimeZone
+ | SET TIME ZONE .*? #setTimeZone
+ | SET configKey EQ configValue #setQuotedConfiguration
+ | SET configKey (EQ .*?)? #setQuotedConfiguration
+ | SET .*? EQ configValue #setQuotedConfiguration
+ | SET .*? #setConfiguration
+ | RESET configKey #resetQuotedConfiguration
+ | RESET .*? #resetConfiguration
+ | unsupportedHiveNativeCommands .*? #failNativeCommand
+ ;
+
+configKey
+ : quotedIdentifier
+ ;
+
+configValue
+ : quotedIdentifier
+ ;
+
+unsupportedHiveNativeCommands
+ : kw1=CREATE kw2=ROLE
+ | kw1=DROP kw2=ROLE
+ | kw1=GRANT kw2=ROLE?
+ | kw1=REVOKE kw2=ROLE?
+ | kw1=SHOW kw2=GRANT
+ | kw1=SHOW kw2=ROLE kw3=GRANT?
+ | kw1=SHOW kw2=PRINCIPALS
+ | kw1=SHOW kw2=ROLES
+ | kw1=SHOW kw2=CURRENT kw3=ROLES
+ | kw1=EXPORT kw2=TABLE
+ | kw1=IMPORT kw2=TABLE
+ | kw1=SHOW kw2=COMPACTIONS
+ | kw1=SHOW kw2=CREATE kw3=TABLE
+ | kw1=SHOW kw2=TRANSACTIONS
+ | kw1=SHOW kw2=INDEXES
+ | kw1=SHOW kw2=LOCKS
+ | kw1=CREATE kw2=INDEX
+ | kw1=DROP kw2=INDEX
+ | kw1=ALTER kw2=INDEX
+ | kw1=LOCK kw2=TABLE
+ | kw1=LOCK kw2=DATABASE
+ | kw1=UNLOCK kw2=TABLE
+ | kw1=UNLOCK kw2=DATABASE
+ | kw1=CREATE kw2=TEMPORARY kw3=MACRO
+ | kw1=DROP kw2=TEMPORARY kw3=MACRO
+ | kw1=ALTER kw2=TABLE tableIdentifier kw3=NOT kw4=CLUSTERED
+ | kw1=ALTER kw2=TABLE tableIdentifier kw3=CLUSTERED kw4=BY
+ | kw1=ALTER kw2=TABLE tableIdentifier kw3=NOT kw4=SORTED
+ | kw1=ALTER kw2=TABLE tableIdentifier kw3=SKEWED kw4=BY
+ | kw1=ALTER kw2=TABLE tableIdentifier kw3=NOT kw4=SKEWED
+ | kw1=ALTER kw2=TABLE tableIdentifier kw3=NOT kw4=STORED kw5=AS kw6=DIRECTORIES
+ | kw1=ALTER kw2=TABLE tableIdentifier kw3=SET kw4=SKEWED kw5=LOCATION
+ | kw1=ALTER kw2=TABLE tableIdentifier kw3=EXCHANGE kw4=PARTITION
+ | kw1=ALTER kw2=TABLE tableIdentifier kw3=ARCHIVE kw4=PARTITION
+ | kw1=ALTER kw2=TABLE tableIdentifier kw3=UNARCHIVE kw4=PARTITION
+ | kw1=ALTER kw2=TABLE tableIdentifier kw3=TOUCH
+ | kw1=ALTER kw2=TABLE tableIdentifier partitionSpec? kw3=COMPACT
+ | kw1=ALTER kw2=TABLE tableIdentifier partitionSpec? kw3=CONCATENATE
+ | kw1=ALTER kw2=TABLE tableIdentifier partitionSpec? kw3=SET kw4=FILEFORMAT
+ | kw1=ALTER kw2=TABLE tableIdentifier partitionSpec? kw3=REPLACE kw4=COLUMNS
+ | kw1=START kw2=TRANSACTION
+ | kw1=COMMIT
+ | kw1=ROLLBACK
+ | kw1=DFS
+ ;
+
+createTableHeader
+ : CREATE TEMPORARY? EXTERNAL? TABLE (IF NOT EXISTS)? multipartIdentifier
+ ;
+
+replaceTableHeader
+ : (CREATE OR)? REPLACE TABLE multipartIdentifier
+ ;
+
+bucketSpec
+ : CLUSTERED BY identifierList
+ (SORTED BY orderedIdentifierList)?
+ INTO INTEGER_VALUE BUCKETS
+ ;
+
+skewSpec
+ : SKEWED BY identifierList
+ ON (constantList | nestedConstantList)
+ (STORED AS DIRECTORIES)?
+ ;
+
+locationSpec
+ : LOCATION STRING
+ ;
+
+commentSpec
+ : COMMENT STRING
+ ;
+
+query
+ : ctes? queryTerm queryOrganization
+ ;
+
+insertInto
+ : INSERT OVERWRITE TABLE? multipartIdentifier (partitionSpec (IF NOT EXISTS)?)? identifierList? #insertOverwriteTable
+ | INSERT INTO TABLE? multipartIdentifier partitionSpec? (IF NOT EXISTS)? identifierList? #insertIntoTable
+ | INSERT OVERWRITE LOCAL? DIRECTORY path=STRING rowFormat? createFileFormat? #insertOverwriteHiveDir
+ | INSERT OVERWRITE LOCAL? DIRECTORY (path=STRING)? tableProvider (OPTIONS options=tablePropertyList)? #insertOverwriteDir
+ ;
+
+partitionSpecLocation
+ : partitionSpec locationSpec?
+ ;
+
+partitionSpec
+ : PARTITION '(' partitionVal (',' partitionVal)* ')'
+ ;
+
+partitionVal
+ : identifier (EQ constant)?
+ ;
+
+namespace
+ : NAMESPACE
+ | DATABASE
+ | SCHEMA
+ ;
+
+describeFuncName
+ : qualifiedName
+ | STRING
+ | comparisonOperator
+ | arithmeticOperator
+ | predicateOperator
+ ;
+
+describeColName
+ : nameParts+=identifier ('.' nameParts+=identifier)*
+ ;
+
+ctes
+ : WITH namedQuery (',' namedQuery)*
+ ;
+
+namedQuery
+ : name=errorCapturingIdentifier (columnAliases=identifierList)? AS? '(' query ')'
+ ;
+
+tableProvider
+ : USING multipartIdentifier
+ ;
+
+createTableClauses
+ :((OPTIONS options=tablePropertyList) |
+ (PARTITIONED BY partitioning=partitionFieldList) |
+ skewSpec |
+ bucketSpec |
+ rowFormat |
+ createFileFormat |
+ locationSpec |
+ commentSpec |
+ (TBLPROPERTIES tableProps=tablePropertyList))*
+ ;
+
+tablePropertyList
+ : '(' tableProperty (',' tableProperty)* ')'
+ ;
+
+tableProperty
+ : key=tablePropertyKey (EQ? value=tablePropertyValue)?
+ ;
+
+tablePropertyKey
+ : identifier ('.' identifier)*
+ | STRING
+ ;
+
+tablePropertyValue
+ : INTEGER_VALUE
+ | DECIMAL_VALUE
+ | booleanValue
+ | STRING
+ ;
+
+constantList
+ : '(' constant (',' constant)* ')'
+ ;
+
+nestedConstantList
+ : '(' constantList (',' constantList)* ')'
+ ;
+
+createFileFormat
+ : STORED AS fileFormat
+ | STORED BY storageHandler
+ ;
+
+fileFormat
+ : INPUTFORMAT inFmt=STRING OUTPUTFORMAT outFmt=STRING #tableFileFormat
+ | identifier #genericFileFormat
+ ;
+
+storageHandler
+ : STRING (WITH SERDEPROPERTIES tablePropertyList)?
+ ;
+
+resource
+ : identifier STRING
+ ;
+
+dmlStatementNoWith
+ : insertInto queryTerm queryOrganization #singleInsertQuery
+ | fromClause multiInsertQueryBody+ #multiInsertQuery
+ | DELETE FROM multipartIdentifier tableAlias whereClause? #deleteFromTable
+ | UPDATE multipartIdentifier tableAlias setClause whereClause? #updateTable
+ | MERGE INTO target=multipartIdentifier targetAlias=tableAlias
+ USING (source=multipartIdentifier |
+ '(' sourceQuery=query')') sourceAlias=tableAlias
+ ON mergeCondition=booleanExpression
+ matchedClause*
+ notMatchedClause* #mergeIntoTable
+ ;
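+
+// Illustration only (table names are made up, not part of the grammar): one statement shape the
+// #mergeIntoTable alternative above accepts is
+//   MERGE INTO target_tbl t USING source_tbl s ON t.id = s.id
+//   WHEN MATCHED THEN UPDATE SET *
+//   WHEN NOT MATCHED THEN INSERT *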
+
+queryOrganization
+ : (ORDER BY order+=sortItem (',' order+=sortItem)*)?
+ (CLUSTER BY clusterBy+=expression (',' clusterBy+=expression)*)?
+ (DISTRIBUTE BY distributeBy+=expression (',' distributeBy+=expression)*)?
+ (SORT BY sort+=sortItem (',' sort+=sortItem)*)?
+ windowClause?
+ (LIMIT (ALL | limit=expression))?
+ ;
+
+multiInsertQueryBody
+ : insertInto fromStatementBody
+ ;
+
+queryTerm
+ : queryPrimary #queryTermDefault
+ | left=queryTerm {legacy_setops_precedence_enabled}?
+ operator=(INTERSECT | UNION | EXCEPT | SETMINUS) setQuantifier? right=queryTerm #setOperation
+ | left=queryTerm {!legacy_setops_precedence_enabled}?
+ operator=INTERSECT setQuantifier? right=queryTerm #setOperation
+ | left=queryTerm {!legacy_setops_precedence_enabled}?
+ operator=(UNION | EXCEPT | SETMINUS) setQuantifier? right=queryTerm #setOperation
+ ;
+
+queryPrimary
+ : querySpecification #queryPrimaryDefault
+ | fromStatement #fromStmt
+ | TABLE multipartIdentifier #table
+ | inlineTable #inlineTableDefault1
+ | '(' query ')' #subquery
+ ;
+
+sortItem
+ : expression ordering=(ASC | DESC)? (NULLS nullOrder=(LAST | FIRST))?
+ ;
+
+fromStatement
+ : fromClause fromStatementBody+
+ ;
+
+fromStatementBody
+ : transformClause
+ whereClause?
+ queryOrganization
+ | selectClause
+ lateralView*
+ whereClause?
+ aggregationClause?
+ havingClause?
+ windowClause?
+ queryOrganization
+ ;
+
+querySpecification
+ : transformClause
+ fromClause?
+ lateralView*
+ whereClause?
+ aggregationClause?
+ havingClause?
+ windowClause? #transformQuerySpecification
+ | selectClause
+ fromClause?
+ lateralView*
+ whereClause?
+ aggregationClause?
+ havingClause?
+ windowClause? #regularQuerySpecification
+ ;
+
+transformClause
+ : (SELECT kind=TRANSFORM '(' setQuantifier? expressionSeq ')'
+ | kind=MAP setQuantifier? expressionSeq
+ | kind=REDUCE setQuantifier? expressionSeq)
+ inRowFormat=rowFormat?
+ (RECORDWRITER recordWriter=STRING)?
+ USING script=STRING
+ (AS (identifierSeq | colTypeList | ('(' (identifierSeq | colTypeList) ')')))?
+ outRowFormat=rowFormat?
+ (RECORDREADER recordReader=STRING)?
+ ;
+
+selectClause
+ : SELECT (hints+=hint)* setQuantifier? namedExpressionSeq
+ ;
+
+setClause
+ : SET assignmentList
+ ;
+
+matchedClause
+ : WHEN MATCHED (AND matchedCond=booleanExpression)? THEN matchedAction
+ ;
+notMatchedClause
+ : WHEN NOT MATCHED (AND notMatchedCond=booleanExpression)? THEN notMatchedAction
+ ;
+
+matchedAction
+ : DELETE
+ | UPDATE SET ASTERISK
+ | UPDATE SET assignmentList
+ ;
+
+notMatchedAction
+ : INSERT ASTERISK
+ | INSERT '(' columns=multipartIdentifierList ')'
+ VALUES '(' expression (',' expression)* ')'
+ ;
+
+assignmentList
+ : assignment (',' assignment)*
+ ;
+
+assignment
+ : key=multipartIdentifier EQ value=expression
+ ;
+
+whereClause
+ : WHERE booleanExpression
+ ;
+
+havingClause
+ : HAVING booleanExpression
+ ;
+
+hint
+ : '/*+' hintStatements+=hintStatement (','? hintStatements+=hintStatement)* '*/'
+ ;
+
+hintStatement
+ : hintName=identifier
+ | hintName=identifier '(' parameters+=primaryExpression (',' parameters+=primaryExpression)* ')'
+ ;
+
+fromClause
+ : FROM relation (',' relation)* lateralView* pivotClause?
+ ;
+
+temporalClause
+ : FOR? (SYSTEM_TIME | TIMESTAMP) AS OF timestamp=valueExpression
+ | FOR? (SYSTEM_VERSION | VERSION) AS OF version=(INTEGER_VALUE | STRING)
+ ;
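+
+// Illustration only (table name and literal values are made up): under this rule a time-travel
+// read can be written as either of
+//   SELECT * FROM hudi_tbl TIMESTAMP AS OF '2021-07-28 14:11:08'
+//   SELECT * FROM hudi_tbl VERSION AS OF 20210728141108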
+
+aggregationClause
+ : GROUP BY groupingExpressionsWithGroupingAnalytics+=groupByClause
+ (',' groupingExpressionsWithGroupingAnalytics+=groupByClause)*
+ | GROUP BY groupingExpressions+=expression (',' groupingExpressions+=expression)* (
+ WITH kind=ROLLUP
+ | WITH kind=CUBE
+ | kind=GROUPING SETS '(' groupingSet (',' groupingSet)* ')')?
+ ;
+
+groupByClause
+ : groupingAnalytics
+ | expression
+ ;
+
+groupingAnalytics
+ : (ROLLUP | CUBE) '(' groupingSet (',' groupingSet)* ')'
+ | GROUPING SETS '(' groupingElement (',' groupingElement)* ')'
+ ;
+
+groupingElement
+ : groupingAnalytics
+ | groupingSet
+ ;
+
+groupingSet
+ : '(' (expression (',' expression)*)? ')'
+ | expression
+ ;
+
+pivotClause
+ : PIVOT '(' aggregates=namedExpressionSeq FOR pivotColumn IN '(' pivotValues+=pivotValue (',' pivotValues+=pivotValue)* ')' ')'
+ ;
+
+pivotColumn
+ : identifiers+=identifier
+ | '(' identifiers+=identifier (',' identifiers+=identifier)* ')'
+ ;
+
+pivotValue
+ : expression (AS? identifier)?
+ ;
+
+lateralView
+ : LATERAL VIEW (OUTER)? qualifiedName '(' (expression (',' expression)*)? ')' tblName=identifier (AS? colName+=identifier (',' colName+=identifier)*)?
+ ;
+
+setQuantifier
+ : DISTINCT
+ | ALL
+ ;
+
+relation
+ : LATERAL? relationPrimary joinRelation*
+ ;
+
+joinRelation
+ : (joinType) JOIN LATERAL? right=relationPrimary joinCriteria?
+ | NATURAL joinType JOIN LATERAL? right=relationPrimary
+ ;
+
+joinType
+ : INNER?
+ | CROSS
+ | LEFT OUTER?
+ | LEFT? SEMI
+ | RIGHT OUTER?
+ | FULL OUTER?
+ | LEFT? ANTI
+ ;
+
+joinCriteria
+ : ON booleanExpression
+ | USING identifierList
+ ;
+
+sample
+ : TABLESAMPLE '(' sampleMethod? ')'
+ ;
+
+sampleMethod
+ : negativeSign=MINUS? percentage=(INTEGER_VALUE | DECIMAL_VALUE) PERCENTLIT #sampleByPercentile
+ | expression ROWS #sampleByRows
+ | sampleType=BUCKET numerator=INTEGER_VALUE OUT OF denominator=INTEGER_VALUE
+ (ON (identifier | qualifiedName '(' ')'))? #sampleByBucket
+ | bytes=expression #sampleByBytes
+ ;
+
+identifierList
+ : '(' identifierSeq ')'
+ ;
+
+identifierSeq
+ : ident+=errorCapturingIdentifier (',' ident+=errorCapturingIdentifier)*
+ ;
+
+orderedIdentifierList
+ : '(' orderedIdentifier (',' orderedIdentifier)* ')'
+ ;
+
+orderedIdentifier
+ : ident=errorCapturingIdentifier ordering=(ASC | DESC)?
+ ;
+
+identifierCommentList
+ : '(' identifierComment (',' identifierComment)* ')'
+ ;
+
+identifierComment
+ : identifier commentSpec?
+ ;
+
+relationPrimary
+ : multipartIdentifier temporalClause?
+ sample? tableAlias #tableName
+ | '(' query ')' sample? tableAlias #aliasedQuery
+ | '(' relation ')' sample? tableAlias #aliasedRelation
+ | inlineTable #inlineTableDefault2
+ | functionTable #tableValuedFunction
+ ;
+
+inlineTable
+ : VALUES expression (',' expression)* tableAlias
+ ;
+
+functionTable
+ : funcName=functionName '(' (expression (',' expression)*)? ')' tableAlias
+ ;
+
+tableAlias
+ : (AS? strictIdentifier identifierList?)?
+ ;
+
+rowFormat
+ : ROW FORMAT SERDE name=STRING (WITH SERDEPROPERTIES props=tablePropertyList)? #rowFormatSerde
+ | ROW FORMAT DELIMITED
+ (FIELDS TERMINATED BY fieldsTerminatedBy=STRING (ESCAPED BY escapedBy=STRING)?)?
+ (COLLECTION ITEMS TERMINATED BY collectionItemsTerminatedBy=STRING)?
+ (MAP KEYS TERMINATED BY keysTerminatedBy=STRING)?
+ (LINES TERMINATED BY linesSeparatedBy=STRING)?
+ (NULL DEFINED AS nullDefinedAs=STRING)? #rowFormatDelimited
+ ;
+
+multipartIdentifierList
+ : multipartIdentifier (',' multipartIdentifier)*
+ ;
+
+multipartIdentifier
+ : parts+=errorCapturingIdentifier ('.' parts+=errorCapturingIdentifier)*
+ ;
+
+tableIdentifier
+ : (db=errorCapturingIdentifier '.')? table=errorCapturingIdentifier
+ ;
+
+functionIdentifier
+ : (db=errorCapturingIdentifier '.')? function=errorCapturingIdentifier
+ ;
+
+namedExpression
+ : expression (AS? (name=errorCapturingIdentifier | identifierList))?
+ ;
+
+namedExpressionSeq
+ : namedExpression (',' namedExpression)*
+ ;
+
+partitionFieldList
+ : '(' fields+=partitionField (',' fields+=partitionField)* ')'
+ ;
+
+partitionField
+ : transform #partitionTransform
+ | colType #partitionColumn
+ ;
+
+transform
+ : qualifiedName #identityTransform
+ | transformName=identifier
+ '(' argument+=transformArgument (',' argument+=transformArgument)* ')' #applyTransform
+ ;
+
+transformArgument
+ : qualifiedName
+ | constant
+ ;
+
+expression
+ : booleanExpression
+ ;
+
+expressionSeq
+ : expression (',' expression)*
+ ;
+
+booleanExpression
+ : NOT booleanExpression #logicalNot
+ | EXISTS '(' query ')' #exists
+ | valueExpression predicate? #predicated
+ | left=booleanExpression operator=AND right=booleanExpression #logicalBinary
+ | left=booleanExpression operator=OR right=booleanExpression #logicalBinary
+ ;
+
+predicate
+ : NOT? kind=BETWEEN lower=valueExpression AND upper=valueExpression
+ | NOT? kind=IN '(' expression (',' expression)* ')'
+ | NOT? kind=IN '(' query ')'
+ | NOT? kind=RLIKE pattern=valueExpression
+ | NOT? kind=LIKE quantifier=(ANY | SOME | ALL) ('('')' | '(' expression (',' expression)* ')')
+ | NOT? kind=LIKE pattern=valueExpression (ESCAPE escapeChar=STRING)?
+ | IS NOT? kind=NULL
+ | IS NOT? kind=(TRUE | FALSE | UNKNOWN)
+ | IS NOT? kind=DISTINCT FROM right=valueExpression
+ ;
+
+valueExpression
+ : primaryExpression #valueExpressionDefault
+ | operator=(MINUS | PLUS | TILDE) valueExpression #arithmeticUnary
+ | left=valueExpression operator=(ASTERISK | SLASH | PERCENT | DIV) right=valueExpression #arithmeticBinary
+ | left=valueExpression operator=(PLUS | MINUS | CONCAT_PIPE) right=valueExpression #arithmeticBinary
+ | left=valueExpression operator=AMPERSAND right=valueExpression #arithmeticBinary
+ | left=valueExpression operator=HAT right=valueExpression #arithmeticBinary
+ | left=valueExpression operator=PIPE right=valueExpression #arithmeticBinary
+ | left=valueExpression comparisonOperator right=valueExpression #comparison
+ ;
+
+primaryExpression
+ : name=(CURRENT_DATE | CURRENT_TIMESTAMP | CURRENT_USER) #currentLike
+ | CASE whenClause+ (ELSE elseExpression=expression)? END #searchedCase
+ | CASE value=expression whenClause+ (ELSE elseExpression=expression)? END #simpleCase
+ | name=(CAST | TRY_CAST) '(' expression AS dataType ')' #cast
+ | STRUCT '(' (argument+=namedExpression (',' argument+=namedExpression)*)? ')' #struct
+ | FIRST '(' expression (IGNORE NULLS)? ')' #first
+ | LAST '(' expression (IGNORE NULLS)? ')' #last
+ | POSITION '(' substr=valueExpression IN str=valueExpression ')' #position
+ | constant #constantDefault
+ | ASTERISK #star
+ | qualifiedName '.' ASTERISK #star
+ | '(' namedExpression (',' namedExpression)+ ')' #rowConstructor
+ | '(' query ')' #subqueryExpression
+ | functionName '(' (setQuantifier? argument+=expression (',' argument+=expression)*)? ')'
+ (FILTER '(' WHERE where=booleanExpression ')')?
+ (nullsOption=(IGNORE | RESPECT) NULLS)? ( OVER windowSpec)? #functionCall
+ | identifier '->' expression #lambda
+ | '(' identifier (',' identifier)+ ')' '->' expression #lambda
+ | value=primaryExpression '[' index=valueExpression ']' #subscript
+ | identifier #columnReference
+ | base=primaryExpression '.' fieldName=identifier #dereference
+ | '(' expression ')' #parenthesizedExpression
+ | EXTRACT '(' field=identifier FROM source=valueExpression ')' #extract
+ | (SUBSTR | SUBSTRING) '(' str=valueExpression (FROM | ',') pos=valueExpression
+ ((FOR | ',') len=valueExpression)? ')' #substring
+ | TRIM '(' trimOption=(BOTH | LEADING | TRAILING)? (trimStr=valueExpression)?
+ FROM srcStr=valueExpression ')' #trim
+ | OVERLAY '(' input=valueExpression PLACING replace=valueExpression
+ FROM position=valueExpression (FOR length=valueExpression)? ')' #overlay
+ ;
+
+constant
+ : NULL #nullLiteral
+ | interval #intervalLiteral
+ | identifier STRING #typeConstructor
+ | number #numericLiteral
+ | booleanValue #booleanLiteral
+ | STRING+ #stringLiteral
+ ;
+
+comparisonOperator
+ : EQ | NEQ | NEQJ | LT | LTE | GT | GTE | NSEQ
+ ;
+
+arithmeticOperator
+ : PLUS | MINUS | ASTERISK | SLASH | PERCENT | DIV | TILDE | AMPERSAND | PIPE | CONCAT_PIPE | HAT
+ ;
+
+predicateOperator
+ : OR | AND | IN | NOT
+ ;
+
+booleanValue
+ : TRUE | FALSE
+ ;
+
+interval
+ : INTERVAL (errorCapturingMultiUnitsInterval | errorCapturingUnitToUnitInterval)?
+ ;
+
+errorCapturingMultiUnitsInterval
+ : body=multiUnitsInterval unitToUnitInterval?
+ ;
+
+multiUnitsInterval
+ : (intervalValue unit+=identifier)+
+ ;
+
+errorCapturingUnitToUnitInterval
+ : body=unitToUnitInterval (error1=multiUnitsInterval | error2=unitToUnitInterval)?
+ ;
+
+unitToUnitInterval
+ : value=intervalValue from=identifier TO to=identifier
+ ;
+
+intervalValue
+ : (PLUS | MINUS)? (INTEGER_VALUE | DECIMAL_VALUE | STRING)
+ ;
+
+colPosition
+ : position=FIRST | position=AFTER afterCol=errorCapturingIdentifier
+ ;
+
+dataType
+ : complex=ARRAY '<' dataType '>' #complexDataType
+ | complex=MAP '<' dataType ',' dataType '>' #complexDataType
+ | complex=STRUCT ('<' complexColTypeList? '>' | NEQ) #complexDataType
+ | INTERVAL from=(YEAR | MONTH) (TO to=MONTH)? #yearMonthIntervalDataType
+ | INTERVAL from=(DAY | HOUR | MINUTE | SECOND)
+ (TO to=(HOUR | MINUTE | SECOND))? #dayTimeIntervalDataType
+ | identifier ('(' INTEGER_VALUE (',' INTEGER_VALUE)* ')')? #primitiveDataType
+ ;
+
+qualifiedColTypeWithPositionList
+ : qualifiedColTypeWithPosition (',' qualifiedColTypeWithPosition)*
+ ;
+
+qualifiedColTypeWithPosition
+ : name=multipartIdentifier dataType (NOT NULL)? commentSpec? colPosition?
+ ;
+
+colTypeList
+ : colType (',' colType)*
+ ;
+
+colType
+ : colName=errorCapturingIdentifier dataType (NOT NULL)? commentSpec?
+ ;
+
+complexColTypeList
+ : complexColType (',' complexColType)*
+ ;
+
+complexColType
+ : identifier ':'? dataType (NOT NULL)? commentSpec?
+ ;
+
+whenClause
+ : WHEN condition=expression THEN result=expression
+ ;
+
+windowClause
+ : WINDOW namedWindow (',' namedWindow)*
+ ;
+
+namedWindow
+ : name=errorCapturingIdentifier AS windowSpec
+ ;
+
+windowSpec
+ : name=errorCapturingIdentifier #windowRef
+ | '('name=errorCapturingIdentifier')' #windowRef
+ | '('
+ ( CLUSTER BY partition+=expression (',' partition+=expression)*
+ | ((PARTITION | DISTRIBUTE) BY partition+=expression (',' partition+=expression)*)?
+ ((ORDER | SORT) BY sortItem (',' sortItem)*)?)
+ windowFrame?
+ ')' #windowDef
+ ;
+
+windowFrame
+ : frameType=RANGE start=frameBound
+ | frameType=ROWS start=frameBound
+ | frameType=RANGE BETWEEN start=frameBound AND end=frameBound
+ | frameType=ROWS BETWEEN start=frameBound AND end=frameBound
+ ;
+
+frameBound
+ : UNBOUNDED boundType=(PRECEDING | FOLLOWING)
+ | boundType=CURRENT ROW
+ | expression boundType=(PRECEDING | FOLLOWING)
+ ;
+
+qualifiedNameList
+ : qualifiedName (',' qualifiedName)*
+ ;
+
+functionName
+ : qualifiedName
+ | FILTER
+ | LEFT
+ | RIGHT
+ ;
+
+qualifiedName
+ : identifier ('.' identifier)*
+ ;
+
+// this rule is used for explicitly capturing wrong identifiers such as test-table, which should actually be `test-table`
+// replace identifier with errorCapturingIdentifier where the immediate follow symbol is not an expression, otherwise
+// valid expressions such as "a-b" can be recognized as an identifier
+errorCapturingIdentifier
+ : identifier errorCapturingIdentifierExtra
+ ;
+
+// extra left-factoring grammar
+errorCapturingIdentifierExtra
+ : (MINUS identifier)+ #errorIdent
+ | #realIdent
+ ;
+
+identifier
+ : strictIdentifier
+ | {!SQL_standard_keyword_behavior}? strictNonReserved
+ ;
+
+strictIdentifier
+ : IDENTIFIER #unquotedIdentifier
+ | quotedIdentifier #quotedIdentifierAlternative
+ | {SQL_standard_keyword_behavior}? ansiNonReserved #unquotedIdentifier
+ | {!SQL_standard_keyword_behavior}? nonReserved #unquotedIdentifier
+ ;
+
+quotedIdentifier
+ : BACKQUOTED_IDENTIFIER
+ ;
+
+number
+ : {!legacy_exponent_literal_as_decimal_enabled}? MINUS? EXPONENT_VALUE #exponentLiteral
+ | {!legacy_exponent_literal_as_decimal_enabled}? MINUS? DECIMAL_VALUE #decimalLiteral
+ | {legacy_exponent_literal_as_decimal_enabled}? MINUS? (EXPONENT_VALUE | DECIMAL_VALUE) #legacyDecimalLiteral
+ | MINUS? INTEGER_VALUE #integerLiteral
+ | MINUS? BIGINT_LITERAL #bigIntLiteral
+ | MINUS? SMALLINT_LITERAL #smallIntLiteral
+ | MINUS? TINYINT_LITERAL #tinyIntLiteral
+ | MINUS? DOUBLE_LITERAL #doubleLiteral
+ | MINUS? FLOAT_LITERAL #floatLiteral
+ | MINUS? BIGDECIMAL_LITERAL #bigDecimalLiteral
+ ;
+
+alterColumnAction
+ : TYPE dataType
+ | commentSpec
+ | colPosition
+ | setOrDrop=(SET | DROP) NOT NULL
+ ;
+
+// When `SQL_standard_keyword_behavior=true`, there are 2 kinds of keywords in Spark SQL.
+// - Reserved keywords:
+// Keywords that are reserved and can't be used as identifiers for table, view, column,
+// function, alias, etc.
+// - Non-reserved keywords:
+// Keywords that have a special meaning only in particular contexts and can be used as
+// identifiers in other contexts. For example, `EXPLAIN SELECT ...` is a command, but EXPLAIN
+// can be used as identifiers in other places.
+// You can find the full keywords list by searching "Start of the keywords list" in this file.
+// The non-reserved keywords are listed below. Keywords not in this list are reserved keywords.
+ansiNonReserved
+//--ANSI-NON-RESERVED-START
+ : ADD
+ | AFTER
+ | ALTER
+ | ANALYZE
+ | ANTI
+ | ARCHIVE
+ | ARRAY
+ | ASC
+ | AT
+ | BETWEEN
+ | BUCKET
+ | BUCKETS
+ | BY
+ | CACHE
+ | CASCADE
+ | CHANGE
+ | CLEAR
+ | CLUSTER
+ | CLUSTERED
+ | CODEGEN
+ | COLLECTION
+ | COLUMNS
+ | COMMENT
+ | COMMIT
+ | COMPACT
+ | COMPACTIONS
+ | COMPUTE
+ | CONCATENATE
+ | COST
+ | CUBE
+ | CURRENT
+ | DATA
+ | DATABASE
+ | DATABASES
+ | DAY
+ | DBPROPERTIES
+ | DEFINED
+ | DELETE
+ | DELIMITED
+ | DESC
+ | DESCRIBE
+ | DFS
+ | DIRECTORIES
+ | DIRECTORY
+ | DISTRIBUTE
+ | DIV
+ | DROP
+ | ESCAPED
+ | EXCHANGE
+ | EXISTS
+ | EXPLAIN
+ | EXPORT
+ | EXTENDED
+ | EXTERNAL
+ | EXTRACT
+ | FIELDS
+ | FILEFORMAT
+ | FIRST
+ | FOLLOWING
+ | FORMAT
+ | FORMATTED
+ | FUNCTION
+ | FUNCTIONS
+ | GLOBAL
+ | GROUPING
+ | HOUR
+ | IF
+ | IGNORE
+ | IMPORT
+ | INDEX
+ | INDEXES
+ | INPATH
+ | INPUTFORMAT
+ | INSERT
+ | INTERVAL
+ | ITEMS
+ | KEYS
+ | LAST
+ | LAZY
+ | LIKE
+ | LIMIT
+ | LINES
+ | LIST
+ | LOAD
+ | LOCAL
+ | LOCATION
+ | LOCK
+ | LOCKS
+ | LOGICAL
+ | MACRO
+ | MAP
+ | MATCHED
+ | MERGE
+ | MINUTE
+ | MONTH
+ | MSCK
+ | NAMESPACE
+ | NAMESPACES
+ | NO
+ | NULLS
+ | OF
+ | OPTION
+ | OPTIONS
+ | OUT
+ | OUTPUTFORMAT
+ | OVER
+ | OVERLAY
+ | OVERWRITE
+ | PARTITION
+ | PARTITIONED
+ | PARTITIONS
+ | PERCENTLIT
+ | PIVOT
+ | PLACING
+ | POSITION
+ | PRECEDING
+ | PRINCIPALS
+ | PROPERTIES
+ | PURGE
+ | QUERY
+ | RANGE
+ | RECORDREADER
+ | RECORDWRITER
+ | RECOVER
+ | REDUCE
+ | REFRESH
+ | RENAME
+ | REPAIR
+ | REPLACE
+ | RESET
+ | RESPECT
+ | RESTRICT
+ | REVOKE
+ | RLIKE
+ | ROLE
+ | ROLES
+ | ROLLBACK
+ | ROLLUP
+ | ROW
+ | ROWS
+ | SCHEMA
+ | SECOND
+ | SEMI
+ | SEPARATED
+ | SERDE
+ | SERDEPROPERTIES
+ | SET
+ | SETMINUS
+ | SETS
+ | SHOW
+ | SKEWED
+ | SORT
+ | SORTED
+ | START
+ | STATISTICS
+ | STORED
+ | STRATIFY
+ | STRUCT
+ | SUBSTR
+ | SUBSTRING
+ | SYNC
+ | TABLES
+ | TABLESAMPLE
+ | TBLPROPERTIES
+ | TEMPORARY
+ | TERMINATED
+ | TOUCH
+ | TRANSACTION
+ | TRANSACTIONS
+ | TRANSFORM
+ | TRIM
+ | TRUE
+ | TRUNCATE
+ | TRY_CAST
+ | TYPE
+ | UNARCHIVE
+ | UNBOUNDED
+ | UNCACHE
+ | UNLOCK
+ | UNSET
+ | UPDATE
+ | USE
+ | VALUES
+ | VIEW
+ | VIEWS
+ | WINDOW
+ | YEAR
+ | ZONE
+//--ANSI-NON-RESERVED-END
+ ;
+
+// When `SQL_standard_keyword_behavior=false`, there are 2 kinds of keywords in Spark SQL.
+// - Non-reserved keywords:
+// Same definition as the one when `SQL_standard_keyword_behavior=true`.
+// - Strict-non-reserved keywords:
+// A strict version of non-reserved keywords, which can not be used as table alias.
+// You can find the full keywords list by searching "Start of the keywords list" in this file.
+// The strict-non-reserved keywords are listed in `strictNonReserved`.
+// The non-reserved keywords are listed in `nonReserved`.
+// These 2 together contain all the keywords.
+strictNonReserved
+ : ANTI
+ | CROSS
+ | EXCEPT
+ | FULL
+ | INNER
+ | INTERSECT
+ | JOIN
+ | LATERAL
+ | LEFT
+ | NATURAL
+ | ON
+ | RIGHT
+ | SEMI
+ | SETMINUS
+ | UNION
+ | USING
+ ;
+
+nonReserved
+//--DEFAULT-NON-RESERVED-START
+ : ADD
+ | AFTER
+ | ALL
+ | ALTER
+ | ANALYZE
+ | AND
+ | ANY
+ | ARCHIVE
+ | ARRAY
+ | AS
+ | ASC
+ | AT
+ | AUTHORIZATION
+ | BETWEEN
+ | BOTH
+ | BUCKET
+ | BUCKETS
+ | BY
+ | CACHE
+ | CASCADE
+ | CASE
+ | CAST
+ | CHANGE
+ | CHECK
+ | CLEAR
+ | CLUSTER
+ | CLUSTERED
+ | CODEGEN
+ | COLLATE
+ | COLLECTION
+ | COLUMN
+ | COLUMNS
+ | COMMENT
+ | COMMIT
+ | COMPACT
+ | COMPACTIONS
+ | COMPUTE
+ | CONCATENATE
+ | CONSTRAINT
+ | COST
+ | CREATE
+ | CUBE
+ | CURRENT
+ | CURRENT_DATE
+ | CURRENT_TIME
+ | CURRENT_TIMESTAMP
+ | CURRENT_USER
+ | DATA
+ | DATABASE
+ | DATABASES
+ | DAY
+ | DBPROPERTIES
+ | DEFINED
+ | DELETE
+ | DELIMITED
+ | DESC
+ | DESCRIBE
+ | DFS
+ | DIRECTORIES
+ | DIRECTORY
+ | DISTINCT
+ | DISTRIBUTE
+ | DIV
+ | DROP
+ | ELSE
+ | END
+ | ESCAPE
+ | ESCAPED
+ | EXCHANGE
+ | EXISTS
+ | EXPLAIN
+ | EXPORT
+ | EXTENDED
+ | EXTERNAL
+ | EXTRACT
+ | FALSE
+ | FETCH
+ | FILTER
+ | FIELDS
+ | FILEFORMAT
+ | FIRST
+ | FOLLOWING
+ | FOR
+ | FOREIGN
+ | FORMAT
+ | FORMATTED
+ | FROM
+ | FUNCTION
+ | FUNCTIONS
+ | GLOBAL
+ | GRANT
+ | GROUP
+ | GROUPING
+ | HAVING
+ | HOUR
+ | IF
+ | IGNORE
+ | IMPORT
+ | IN
+ | INDEX
+ | INDEXES
+ | INPATH
+ | INPUTFORMAT
+ | INSERT
+ | INTERVAL
+ | INTO
+ | IS
+ | ITEMS
+ | KEYS
+ | LAST
+ | LAZY
+ | LEADING
+ | LIKE
+ | LIMIT
+ | LINES
+ | LIST
+ | LOAD
+ | LOCAL
+ | LOCATION
+ | LOCK
+ | LOCKS
+ | LOGICAL
+ | MACRO
+ | MAP
+ | MATCHED
+ | MERGE
+ | MINUTE
+ | MONTH
+ | MSCK
+ | NAMESPACE
+ | NAMESPACES
+ | NO
+ | NOT
+ | NULL
+ | NULLS
+ | OF
+ | ONLY
+ | OPTION
+ | OPTIONS
+ | OR
+ | ORDER
+ | OUT
+ | OUTER
+ | OUTPUTFORMAT
+ | OVER
+ | OVERLAPS
+ | OVERLAY
+ | OVERWRITE
+ | PARTITION
+ | PARTITIONED
+ | PARTITIONS
+ | PERCENTLIT
+ | PIVOT
+ | PLACING
+ | POSITION
+ | PRECEDING
+ | PRIMARY
+ | PRINCIPALS
+ | PROPERTIES
+ | PURGE
+ | QUERY
+ | RANGE
+ | RECORDREADER
+ | RECORDWRITER
+ | RECOVER
+ | REDUCE
+ | REFERENCES
+ | REFRESH
+ | RENAME
+ | REPAIR
+ | REPLACE
+ | RESET
+ | RESPECT
+ | RESTRICT
+ | REVOKE
+ | RLIKE
+ | ROLE
+ | ROLES
+ | ROLLBACK
+ | ROLLUP
+ | ROW
+ | ROWS
+ | SCHEMA
+ | SECOND
+ | SELECT
+ | SEPARATED
+ | SERDE
+ | SERDEPROPERTIES
+ | SESSION_USER
+ | SET
+ | SETS
+ | SHOW
+ | SKEWED
+ | SOME
+ | SORT
+ | SORTED
+ | START
+ | STATISTICS
+ | STORED
+ | STRATIFY
+ | STRUCT
+ | SUBSTR
+ | SUBSTRING
+ | SYNC
+ | TABLE
+ | TABLES
+ | TABLESAMPLE
+ | TBLPROPERTIES
+ | TEMPORARY
+ | TERMINATED
+ | THEN
+ | TIME
+ | TO
+ | TOUCH
+ | TRAILING
+ | TRANSACTION
+ | TRANSACTIONS
+ | TRANSFORM
+ | TRIM
+ | TRUE
+ | TRUNCATE
+ | TRY_CAST
+ | TYPE
+ | UNARCHIVE
+ | UNBOUNDED
+ | UNCACHE
+ | UNIQUE
+ | UNKNOWN
+ | UNLOCK
+ | UNSET
+ | UPDATE
+ | USE
+ | USER
+ | VALUES
+ | VIEW
+ | VIEWS
+ | WHEN
+ | WHERE
+ | WINDOW
+ | WITH
+ | YEAR
+ | ZONE
+ | SYSTEM_VERSION
+ | VERSION
+ | SYSTEM_TIME
+ | TIMESTAMP
+//--DEFAULT-NON-RESERVED-END
+ ;
+
+// NOTE: If you add a new token in the list below, you should update the list of keywords
+// and reserved tag in `docs/sql-ref-ansi-compliance.md#sql-keywords`.
+
+//============================
+// Start of the keywords list
+//============================
+//--SPARK-KEYWORD-LIST-START
+ADD: 'ADD';
+AFTER: 'AFTER';
+ALL: 'ALL';
+ALTER: 'ALTER';
+ANALYZE: 'ANALYZE';
+AND: 'AND';
+ANTI: 'ANTI';
+ANY: 'ANY';
+ARCHIVE: 'ARCHIVE';
+ARRAY: 'ARRAY';
+AS: 'AS';
+ASC: 'ASC';
+AT: 'AT';
+AUTHORIZATION: 'AUTHORIZATION';
+BETWEEN: 'BETWEEN';
+BOTH: 'BOTH';
+BUCKET: 'BUCKET';
+BUCKETS: 'BUCKETS';
+BY: 'BY';
+CACHE: 'CACHE';
+CASCADE: 'CASCADE';
+CASE: 'CASE';
+CAST: 'CAST';
+CHANGE: 'CHANGE';
+CHECK: 'CHECK';
+CLEAR: 'CLEAR';
+CLUSTER: 'CLUSTER';
+CLUSTERED: 'CLUSTERED';
+CODEGEN: 'CODEGEN';
+COLLATE: 'COLLATE';
+COLLECTION: 'COLLECTION';
+COLUMN: 'COLUMN';
+COLUMNS: 'COLUMNS';
+COMMENT: 'COMMENT';
+COMMIT: 'COMMIT';
+COMPACT: 'COMPACT';
+COMPACTIONS: 'COMPACTIONS';
+COMPUTE: 'COMPUTE';
+CONCATENATE: 'CONCATENATE';
+CONSTRAINT: 'CONSTRAINT';
+COST: 'COST';
+CREATE: 'CREATE';
+CROSS: 'CROSS';
+CUBE: 'CUBE';
+CURRENT: 'CURRENT';
+CURRENT_DATE: 'CURRENT_DATE';
+CURRENT_TIME: 'CURRENT_TIME';
+CURRENT_TIMESTAMP: 'CURRENT_TIMESTAMP';
+CURRENT_USER: 'CURRENT_USER';
+DAY: 'DAY';
+DATA: 'DATA';
+DATABASE: 'DATABASE';
+DATABASES: 'DATABASES' | 'SCHEMAS';
+DBPROPERTIES: 'DBPROPERTIES';
+DEFINED: 'DEFINED';
+DELETE: 'DELETE';
+DELIMITED: 'DELIMITED';
+DESC: 'DESC';
+DESCRIBE: 'DESCRIBE';
+DFS: 'DFS';
+DIRECTORIES: 'DIRECTORIES';
+DIRECTORY: 'DIRECTORY';
+DISTINCT: 'DISTINCT';
+DISTRIBUTE: 'DISTRIBUTE';
+DIV: 'DIV';
+DROP: 'DROP';
+ELSE: 'ELSE';
+END: 'END';
+ESCAPE: 'ESCAPE';
+ESCAPED: 'ESCAPED';
+EXCEPT: 'EXCEPT';
+EXCHANGE: 'EXCHANGE';
+EXISTS: 'EXISTS';
+EXPLAIN: 'EXPLAIN';
+EXPORT: 'EXPORT';
+EXTENDED: 'EXTENDED';
+EXTERNAL: 'EXTERNAL';
+EXTRACT: 'EXTRACT';
+FALSE: 'FALSE';
+FETCH: 'FETCH';
+FIELDS: 'FIELDS';
+FILTER: 'FILTER';
+FILEFORMAT: 'FILEFORMAT';
+FIRST: 'FIRST';
+FOLLOWING: 'FOLLOWING';
+FOR: 'FOR';
+FOREIGN: 'FOREIGN';
+FORMAT: 'FORMAT';
+FORMATTED: 'FORMATTED';
+FROM: 'FROM';
+FULL: 'FULL';
+FUNCTION: 'FUNCTION';
+FUNCTIONS: 'FUNCTIONS';
+GLOBAL: 'GLOBAL';
+GRANT: 'GRANT';
+GROUP: 'GROUP';
+GROUPING: 'GROUPING';
+HAVING: 'HAVING';
+HOUR: 'HOUR';
+IF: 'IF';
+IGNORE: 'IGNORE';
+IMPORT: 'IMPORT';
+IN: 'IN';
+INDEX: 'INDEX';
+INDEXES: 'INDEXES';
+INNER: 'INNER';
+INPATH: 'INPATH';
+INPUTFORMAT: 'INPUTFORMAT';
+INSERT: 'INSERT';
+INTERSECT: 'INTERSECT';
+INTERVAL: 'INTERVAL';
+INTO: 'INTO';
+IS: 'IS';
+ITEMS: 'ITEMS';
+JOIN: 'JOIN';
+KEYS: 'KEYS';
+LAST: 'LAST';
+LATERAL: 'LATERAL';
+LAZY: 'LAZY';
+LEADING: 'LEADING';
+LEFT: 'LEFT';
+LIKE: 'LIKE';
+LIMIT: 'LIMIT';
+LINES: 'LINES';
+LIST: 'LIST';
+LOAD: 'LOAD';
+LOCAL: 'LOCAL';
+LOCATION: 'LOCATION';
+LOCK: 'LOCK';
+LOCKS: 'LOCKS';
+LOGICAL: 'LOGICAL';
+MACRO: 'MACRO';
+MAP: 'MAP';
+MATCHED: 'MATCHED';
+MERGE: 'MERGE';
+MINUTE: 'MINUTE';
+MONTH: 'MONTH';
+MSCK: 'MSCK';
+NAMESPACE: 'NAMESPACE';
+NAMESPACES: 'NAMESPACES';
+NATURAL: 'NATURAL';
+NO: 'NO';
+NOT: 'NOT' | '!';
+NULL: 'NULL';
+NULLS: 'NULLS';
+OF: 'OF';
+ON: 'ON';
+ONLY: 'ONLY';
+OPTION: 'OPTION';
+OPTIONS: 'OPTIONS';
+OR: 'OR';
+ORDER: 'ORDER';
+OUT: 'OUT';
+OUTER: 'OUTER';
+OUTPUTFORMAT: 'OUTPUTFORMAT';
+OVER: 'OVER';
+OVERLAPS: 'OVERLAPS';
+OVERLAY: 'OVERLAY';
+OVERWRITE: 'OVERWRITE';
+PARTITION: 'PARTITION';
+PARTITIONED: 'PARTITIONED';
+PARTITIONS: 'PARTITIONS';
+PERCENTLIT: 'PERCENT';
+PIVOT: 'PIVOT';
+PLACING: 'PLACING';
+POSITION: 'POSITION';
+PRECEDING: 'PRECEDING';
+PRIMARY: 'PRIMARY';
+PRINCIPALS: 'PRINCIPALS';
+PROPERTIES: 'PROPERTIES';
+PURGE: 'PURGE';
+QUERY: 'QUERY';
+RANGE: 'RANGE';
+RECORDREADER: 'RECORDREADER';
+RECORDWRITER: 'RECORDWRITER';
+RECOVER: 'RECOVER';
+REDUCE: 'REDUCE';
+REFERENCES: 'REFERENCES';
+REFRESH: 'REFRESH';
+RENAME: 'RENAME';
+REPAIR: 'REPAIR';
+REPLACE: 'REPLACE';
+RESET: 'RESET';
+RESPECT: 'RESPECT';
+RESTRICT: 'RESTRICT';
+REVOKE: 'REVOKE';
+RIGHT: 'RIGHT';
+RLIKE: 'RLIKE' | 'REGEXP';
+ROLE: 'ROLE';
+ROLES: 'ROLES';
+ROLLBACK: 'ROLLBACK';
+ROLLUP: 'ROLLUP';
+ROW: 'ROW';
+ROWS: 'ROWS';
+SECOND: 'SECOND';
+SCHEMA: 'SCHEMA';
+SELECT: 'SELECT';
+SEMI: 'SEMI';
+SEPARATED: 'SEPARATED';
+SERDE: 'SERDE';
+SERDEPROPERTIES: 'SERDEPROPERTIES';
+SESSION_USER: 'SESSION_USER';
+SET: 'SET';
+SETMINUS: 'MINUS';
+SETS: 'SETS';
+SHOW: 'SHOW';
+SKEWED: 'SKEWED';
+SOME: 'SOME';
+SORT: 'SORT';
+SORTED: 'SORTED';
+START: 'START';
+STATISTICS: 'STATISTICS';
+STORED: 'STORED';
+STRATIFY: 'STRATIFY';
+STRUCT: 'STRUCT';
+SUBSTR: 'SUBSTR';
+SUBSTRING: 'SUBSTRING';
+SYNC: 'SYNC';
+TABLE: 'TABLE';
+TABLES: 'TABLES';
+TABLESAMPLE: 'TABLESAMPLE';
+TBLPROPERTIES: 'TBLPROPERTIES';
+TEMPORARY: 'TEMPORARY' | 'TEMP';
+TERMINATED: 'TERMINATED';
+THEN: 'THEN';
+TIME: 'TIME';
+TO: 'TO';
+TOUCH: 'TOUCH';
+TRAILING: 'TRAILING';
+TRANSACTION: 'TRANSACTION';
+TRANSACTIONS: 'TRANSACTIONS';
+TRANSFORM: 'TRANSFORM';
+TRIM: 'TRIM';
+TRUE: 'TRUE';
+TRUNCATE: 'TRUNCATE';
+TRY_CAST: 'TRY_CAST';
+TYPE: 'TYPE';
+UNARCHIVE: 'UNARCHIVE';
+UNBOUNDED: 'UNBOUNDED';
+UNCACHE: 'UNCACHE';
+UNION: 'UNION';
+UNIQUE: 'UNIQUE';
+UNKNOWN: 'UNKNOWN';
+UNLOCK: 'UNLOCK';
+UNSET: 'UNSET';
+UPDATE: 'UPDATE';
+USE: 'USE';
+USER: 'USER';
+USING: 'USING';
+VALUES: 'VALUES';
+VIEW: 'VIEW';
+VIEWS: 'VIEWS';
+WHEN: 'WHEN';
+WHERE: 'WHERE';
+WINDOW: 'WINDOW';
+WITH: 'WITH';
+YEAR: 'YEAR';
+ZONE: 'ZONE';
+
+SYSTEM_VERSION: 'SYSTEM_VERSION';
+VERSION: 'VERSION';
+SYSTEM_TIME: 'SYSTEM_TIME';
+TIMESTAMP: 'TIMESTAMP';
+//--SPARK-KEYWORD-LIST-END
+//============================
+// End of the keywords list
+//============================
+
+EQ : '=' | '==';
+NSEQ: '<=>';
+NEQ : '<>';
+NEQJ: '!=';
+LT : '<';
+LTE : '<=' | '!>';
+GT : '>';
+GTE : '>=' | '!<';
+
+PLUS: '+';
+MINUS: '-';
+ASTERISK: '*';
+SLASH: '/';
+PERCENT: '%';
+TILDE: '~';
+AMPERSAND: '&';
+PIPE: '|';
+CONCAT_PIPE: '||';
+HAT: '^';
+
+STRING
+ : '\'' ( ~('\''|'\\') | ('\\' .) )* '\''
+ | '"' ( ~('"'|'\\') | ('\\' .) )* '"'
+ ;
+
+BIGINT_LITERAL
+ : DIGIT+ 'L'
+ ;
+
+SMALLINT_LITERAL
+ : DIGIT+ 'S'
+ ;
+
+TINYINT_LITERAL
+ : DIGIT+ 'Y'
+ ;
+
+INTEGER_VALUE
+ : DIGIT+
+ ;
+
+EXPONENT_VALUE
+ : DIGIT+ EXPONENT
+ | DECIMAL_DIGITS EXPONENT {isValidDecimal()}?
+ ;
+
+DECIMAL_VALUE
+ : DECIMAL_DIGITS {isValidDecimal()}?
+ ;
+
+FLOAT_LITERAL
+ : DIGIT+ EXPONENT? 'F'
+ | DECIMAL_DIGITS EXPONENT? 'F' {isValidDecimal()}?
+ ;
+
+DOUBLE_LITERAL
+ : DIGIT+ EXPONENT? 'D'
+ | DECIMAL_DIGITS EXPONENT? 'D' {isValidDecimal()}?
+ ;
+
+BIGDECIMAL_LITERAL
+ : DIGIT+ EXPONENT? 'BD'
+ | DECIMAL_DIGITS EXPONENT? 'BD' {isValidDecimal()}?
+ ;
+
+IDENTIFIER
+ : (LETTER | DIGIT | '_')+
+ ;
+
+BACKQUOTED_IDENTIFIER
+ : '`' ( ~'`' | '``' )* '`'
+ ;
+
+fragment DECIMAL_DIGITS
+ : DIGIT+ '.' DIGIT*
+ | '.' DIGIT+
+ ;
+
+fragment EXPONENT
+ : 'E' [+-]? DIGIT+
+ ;
+
+fragment DIGIT
+ : [0-9]
+ ;
+
+fragment LETTER
+ : [A-Z]
+ ;
+
+SIMPLE_COMMENT
+ : '--' ('\\\n' | ~[\r\n])* '\r'? '\n'? -> channel(HIDDEN)
+ ;
+
+BRACKETED_COMMENT
+ : '/*' {!isHint()}? (BRACKETED_COMMENT|.)*? '*/' -> channel(HIDDEN)
+ ;
+
+WS
+ : [ \r\n\t]+ -> channel(HIDDEN)
+ ;
+
+// Catch-all for anything we can't recognize.
+// We use this to be able to ignore and recover all the text
+// when splitting statements with DelimiterLexer
+UNRECOGNIZED
+ : .
+ ;
diff --git a/hudi-spark-datasource/hudi-spark3/src/main/antlr4/org/apache/hudi/spark/sql/parser/HoodieSqlBase.g4 b/hudi-spark-datasource/hudi-spark3/src/main/antlr4/org/apache/hudi/spark/sql/parser/HoodieSqlBase.g4
new file mode 100644
index 0000000000000..0ee6bee1970c1
--- /dev/null
+++ b/hudi-spark-datasource/hudi-spark3/src/main/antlr4/org/apache/hudi/spark/sql/parser/HoodieSqlBase.g4
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+grammar HoodieSqlBase;
+
+import SqlBase;
+
+singleStatement
+ : statement EOF
+ ;
+
+statement
+ : query #queryStatement
+ | ctes? dmlStatementNoWith #dmlStatement
+ | createTableHeader ('(' colTypeList ')')? tableProvider?
+ createTableClauses
+ (AS? query)? #createTable
+ | replaceTableHeader ('(' colTypeList ')')? tableProvider?
+ createTableClauses
+ (AS? query)? #replaceTable
+ | CREATE (OR REPLACE)? (GLOBAL? TEMPORARY)?
+ VIEW (IF NOT EXISTS)? multipartIdentifier
+ identifierCommentList?
+ (commentSpec |
+ (PARTITIONED ON identifierList) |
+ (TBLPROPERTIES tablePropertyList))*
+ AS query #createView
+ | ALTER VIEW multipartIdentifier AS? query #alterViewQuery
+ | (DESC | DESCRIBE) QUERY? query #describeQuery
+ | CACHE LAZY? TABLE multipartIdentifier
+ (OPTIONS options=tablePropertyList)? (AS? query)? #cacheTable
+ | .*? #passThrough
+ ;
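+
+// Note (illustration, not part of the grammar): only the alternatives labelled above are parsed
+// by this extended grammar; any other statement matches the trailing '.*? #passThrough'
+// alternative and is expected to be handed to the delegate Spark SQL parser unchanged. For
+// example, MERGE INTO is covered by #dmlStatement here, while SHOW TABLES falls into #passThrough.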
diff --git a/hudi-spark-datasource/hudi-spark3/src/main/scala/org/apache/spark/sql/AnalysisException.scala b/hudi-spark-datasource/hudi-spark3/src/main/scala/org/apache/spark/sql/AnalysisException.scala
new file mode 100644
index 0000000000000..1482776fd16e8
--- /dev/null
+++ b/hudi-spark-datasource/hudi-spark3/src/main/scala/org/apache/spark/sql/AnalysisException.scala
@@ -0,0 +1,95 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.annotation.Stable
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.trees.Origin
+
+/**
+ * Thrown when a query fails to analyze, usually because the query itself is invalid.
+ *
+ * @since 1.3.0
+ */
+@Stable
+class AnalysisException protected[sql](
+ val message: String,
+ val line: Option[Int] = None,
+ val startPosition: Option[Int] = None,
+ // Some plans fail to serialize due to bugs in scala collections.
+ @transient val plan: Option[LogicalPlan] = None,
+ val cause: Option[Throwable] = None,
+ val errorClass: Option[String] = None,
+ val messageParameters: Array[String] = Array.empty)
+ extends Exception(message, cause.orNull) with Serializable {
+
+ def this(errorClass: String, messageParameters: Array[String], cause: Option[Throwable]) =
+ this(
+ SparkThrowableHelper.getMessage(errorClass, messageParameters),
+ errorClass = Some(errorClass),
+ messageParameters = messageParameters,
+ cause = cause)
+
+ def this(errorClass: String, messageParameters: Array[String]) =
+ this(errorClass = errorClass, messageParameters = messageParameters, cause = None)
+
+ def this(
+ errorClass: String,
+ messageParameters: Array[String],
+ origin: Origin) =
+ this(
+ SparkThrowableHelper.getMessage(errorClass, messageParameters),
+ line = origin.line,
+ startPosition = origin.startPosition,
+ errorClass = Some(errorClass),
+ messageParameters = messageParameters)
+
+ def copy(
+ message: String = this.message,
+ line: Option[Int] = this.line,
+ startPosition: Option[Int] = this.startPosition,
+ plan: Option[LogicalPlan] = this.plan,
+ cause: Option[Throwable] = this.cause,
+ errorClass: Option[String] = this.errorClass,
+ messageParameters: Array[String] = this.messageParameters): AnalysisException =
+ new AnalysisException(message, line, startPosition, plan, cause, errorClass, messageParameters)
+
+ def withPosition(line: Option[Int], startPosition: Option[Int]): AnalysisException = {
+ val newException = this.copy(line = line, startPosition = startPosition)
+ newException.setStackTrace(getStackTrace)
+ newException
+ }
+
+ override def getMessage: String = {
+ val planAnnotation = Option(plan).flatten.map(p => s";\n$p").getOrElse("")
+ getSimpleMessage + planAnnotation
+ }
+
+ // Outputs an exception without the logical plan.
+ // For testing only
+ def getSimpleMessage: String = if (line.isDefined || startPosition.isDefined) {
+ val lineAnnotation = line.map(l => s" line $l").getOrElse("")
+ val positionAnnotation = startPosition.map(p => s" pos $p").getOrElse("")
+ s"$message;$lineAnnotation$positionAnnotation"
+ } else {
+ message
+ }
+
+ def getErrorClass: String = errorClass.orNull
+ def getSqlState: String = SparkThrowableHelper.getSqlState(errorClass.orNull)
+}
diff --git a/hudi-spark-datasource/hudi-spark3/src/main/scala/org/apache/spark/sql/ErrorInfo.scala b/hudi-spark-datasource/hudi-spark3/src/main/scala/org/apache/spark/sql/ErrorInfo.scala
new file mode 100644
index 0000000000000..6ca2af6669cd5
--- /dev/null
+++ b/hudi-spark-datasource/hudi-spark3/src/main/scala/org/apache/spark/sql/ErrorInfo.scala
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import java.net.URL
+
+import scala.collection.immutable.SortedMap
+
+import com.fasterxml.jackson.annotation.JsonIgnore
+import com.fasterxml.jackson.core.`type`.TypeReference
+import com.fasterxml.jackson.databind.json.JsonMapper
+import com.fasterxml.jackson.module.scala.DefaultScalaModule
+
+import org.apache.spark.util.Utils
+
+/**
+ * Information associated with an error class.
+ *
+ * @param sqlState SQLSTATE associated with this class.
+ * @param message C-style message format compatible with printf.
+ * The error message is constructed by concatenating the lines with newlines.
+ */
+private[spark] case class ErrorInfo(message: Seq[String], sqlState: Option[String]) {
+ // For compatibility with multi-line error messages
+ @JsonIgnore
+ val messageFormat: String = message.mkString("\n")
+}
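+
+// Illustrative sketch only (the error class and values below are hypothetical, not taken from
+// error-classes.json): an entry such as
+//   "EXAMPLE_ERROR" -> ErrorInfo(message = Seq("failed to do X:", "reason %s"), sqlState = Some("42000"))
+// yields messageFormat "failed to do X:\nreason %s", and SparkThrowableHelper.getMessage below
+// substitutes the message parameters into that printf-style format.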
+
+/**
+ * Companion object used by instances of [[SparkThrowable]] to access error class information and
+ * construct error messages.
+ */
+private[spark] object SparkThrowableHelper {
+ val errorClassesUrl: URL =
+ Utils.getSparkClassLoader.getResource("error/error-classes.json")
+ val errorClassToInfoMap: SortedMap[String, ErrorInfo] = {
+ val mapper: JsonMapper = JsonMapper.builder()
+ .addModule(DefaultScalaModule)
+ .build()
+ mapper.readValue(errorClassesUrl, new TypeReference[SortedMap[String, ErrorInfo]]() {})
+ }
+
+ def getMessage(errorClass: String, messageParameters: Array[String]): String = {
+ val errorInfo = errorClassToInfoMap.getOrElse(errorClass,
+ throw new IllegalArgumentException(s"Cannot find error class '$errorClass'"))
+ String.format(errorInfo.messageFormat, messageParameters: _*)
+ }
+
+ def getSqlState(errorClass: String): String = {
+ Option(errorClass).flatMap(errorClassToInfoMap.get).flatMap(_.sqlState).orNull
+ }
+}
diff --git a/hudi-spark-datasource/hudi-spark3/src/main/scala/org/apache/spark/sql/adapter/Spark3_2Adapter.scala b/hudi-spark-datasource/hudi-spark3/src/main/scala/org/apache/spark/sql/adapter/Spark3_2Adapter.scala
new file mode 100644
index 0000000000000..0022f0c7b1efd
--- /dev/null
+++ b/hudi-spark-datasource/hudi-spark3/src/main/scala/org/apache/spark/sql/adapter/Spark3_2Adapter.scala
@@ -0,0 +1,112 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.adapter
+
+import org.apache.hudi.Spark3RowSerDe
+import org.apache.hudi.client.utils.SparkRowSerDe
+import org.apache.hudi.spark3.internal.ReflectUtil
+import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.catalyst.expressions.{Expression, Like}
+import org.apache.spark.sql.catalyst.parser.ParserInterface
+import org.apache.spark.sql.catalyst.plans.JoinType
+import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoStatement, Join, JoinHint, LogicalPlan}
+import org.apache.spark.sql.catalyst.{AliasIdentifier, TableIdentifier}
+import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
+import org.apache.spark.sql.execution.datasources.{FilePartition, PartitionedFile, Spark3ParsePartitionUtil, SparkParsePartitionUtil}
+import org.apache.spark.sql.hudi.SparkAdapter
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.parser.HoodieSpark3_2ExtendedSqlParser
+import org.apache.spark.sql.{Row, SparkSession}
+
+/**
+ * The adapter for Spark 3.2.
+ */
+class Spark3_2Adapter extends SparkAdapter {
+
+ override def createSparkRowSerDe(encoder: ExpressionEncoder[Row]): SparkRowSerDe = {
+ new Spark3RowSerDe(encoder)
+ }
+
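+ /**
+ * Converts Spark 3.2's [[AliasIdentifier]] (whose qualifier is a Seq of name parts) into a
+ * [[TableIdentifier]]: a one-part qualifier is taken as the database, a two-part qualifier
+ * (presumably catalog.database) keeps only the database part, and an empty qualifier means no
+ * database. Any other shape is rejected.
+ */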
+ override def toTableIdentifier(aliasId: AliasIdentifier): TableIdentifier = {
+ aliasId match {
+ case AliasIdentifier(name, Seq(database)) =>
+ TableIdentifier(name, Some(database))
+ case AliasIdentifier(name, Seq(_, database)) =>
+ TableIdentifier(name, Some(database))
+ case AliasIdentifier(name, Seq()) =>
+ TableIdentifier(name, None)
+ case _=> throw new IllegalArgumentException(s"Cannot cast $aliasId to TableIdentifier")
+ }
+ }
+
+ override def toTableIdentifier(relation: UnresolvedRelation): TableIdentifier = {
+ relation.multipartIdentifier.asTableIdentifier
+ }
+
+ override def createJoin(left: LogicalPlan, right: LogicalPlan, joinType: JoinType): Join = {
+ Join(left, right, joinType, None, JoinHint.NONE)
+ }
+
+ override def isInsertInto(plan: LogicalPlan): Boolean = {
+ plan.isInstanceOf[InsertIntoStatement]
+ }
+
+ override def getInsertIntoChildren(plan: LogicalPlan):
+ Option[(LogicalPlan, Map[String, Option[String]], LogicalPlan, Boolean, Boolean)] = {
+ plan match {
+ case insert: InsertIntoStatement =>
+ Some((insert.table, insert.partitionSpec, insert.query, insert.overwrite, insert.ifPartitionNotExists))
+ case _ =>
+ None
+ }
+ }
+
+ override def createInsertInto(table: LogicalPlan, partition: Map[String, Option[String]],
+ query: LogicalPlan, overwrite: Boolean, ifPartitionNotExists: Boolean): LogicalPlan = {
+ ReflectUtil.createInsertInto(table, partition, Seq.empty[String], query, overwrite, ifPartitionNotExists)
+ }
+
+ override def createSparkParsePartitionUtil(conf: SQLConf): SparkParsePartitionUtil = {
+ new Spark3ParsePartitionUtil(conf)
+ }
+
+ override def createLike(left: Expression, right: Expression): Expression = {
+ new Like(left, right)
+ }
+
+ override def parseMultipartIdentifier(parser: ParserInterface, sqlText: String): Seq[String] = {
+ parser.parseMultipartIdentifier(sqlText)
+ }
+
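+ /**
+ * Factory for the extended SQL parser: wraps the session's existing parser (the delegate) with
+ * [[HoodieSpark3_2ExtendedSqlParser]], which handles the Hudi-specific syntax and leaves all
+ * other statements to the delegate.
+ */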
+ override def createExtendedSparkParser: Option[(SparkSession, ParserInterface) => ParserInterface] = {
+ Some(
+ (spark: SparkSession, delegate: ParserInterface) => new HoodieSpark3_2ExtendedSqlParser(spark, delegate)
+ )
+ }
+
+ /**
+ * Combine [[PartitionedFile]] to [[FilePartition]] according to `maxSplitBytes`.
+ */
+ override def getFilePartitions(
+ sparkSession: SparkSession,
+ partitionedFiles: Seq[PartitionedFile],
+ maxSplitBytes: Long): Seq[FilePartition] = {
+ FilePartition.getFilePartitions(sparkSession, partitionedFiles, maxSplitBytes)
+ }
+}
diff --git a/hudi-spark-datasource/hudi-spark3/src/main/scala/org/apache/spark/sql/parser/HoodieSpark3_2ExtendedSqlAstBuilder.scala b/hudi-spark-datasource/hudi-spark3/src/main/scala/org/apache/spark/sql/parser/HoodieSpark3_2ExtendedSqlAstBuilder.scala
new file mode 100644
index 0000000000000..0524b073a1a91
--- /dev/null
+++ b/hudi-spark-datasource/hudi-spark3/src/main/scala/org/apache/spark/sql/parser/HoodieSpark3_2ExtendedSqlAstBuilder.scala
@@ -0,0 +1,3871 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.parser
+
+import org.antlr.v4.runtime.tree.{ParseTree, RuleNode, TerminalNode}
+import org.antlr.v4.runtime.{ParserRuleContext, Token}
+import org.apache.hudi.spark.sql.parser.HoodieSqlBaseParser._
+import org.apache.hudi.spark.sql.parser.{HoodieSqlBaseBaseVisitor, HoodieSqlBaseParser}
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.analysis._
+import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat}
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.{First, Last}
+import org.apache.spark.sql.catalyst.parser.ParserUtils.{EnhancedLogicalPlan, checkDuplicateClauses, checkDuplicateKeys, entry, escapedIdentifier, operationNotAllowed, source, string, stringWithoutUnescape, validate, withOrigin}
+import org.apache.spark.sql.catalyst.parser.{ParseException, ParserInterface}
+import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.trees.Origin
+import org.apache.spark.sql.catalyst.util.DateTimeUtils._
+import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, DateTimeUtils, IntervalUtils, truncatedString}
+import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
+import org.apache.spark.sql.connector.catalog.TableChange.ColumnPosition
+import org.apache.spark.sql.connector.catalog.{SupportsNamespaces, TableCatalog}
+import org.apache.spark.sql.connector.expressions.{ApplyTransform, BucketTransform, DaysTransform, FieldReference, HoursTransform, IdentityTransform, LiteralValue, MonthsTransform, Transform, YearsTransform, Expression => V2Expression}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
+import org.apache.spark.util.Utils.isTesting
+import org.apache.spark.util.random.RandomSampler
+
+import java.util.Locale
+import java.util.concurrent.TimeUnit
+import javax.xml.bind.DatatypeConverter
+import scala.collection.JavaConverters._
+import scala.collection.mutable.ArrayBuffer
+
+/**
+ * The AstBuilder for HoodieSqlParser, which converts the ANTLR parse tree into a logical plan.
+ * Only the extended SQL syntax (e.g. MERGE INTO) is parsed here; all other SQL syntax is
+ * handled by the delegate parser, which is the SparkSqlParser.
+ */
+class HoodieSpark3_2ExtendedSqlAstBuilder(conf: SQLConf, delegate: ParserInterface)
+ extends HoodieSqlBaseBaseVisitor[AnyRef] with Logging {
+
+ protected def typedVisit[T](ctx: ParseTree): T = {
+ ctx.accept(this).asInstanceOf[T]
+ }
+
+ /**
+ * Override the default behavior for all visit methods. This will only return a non-null result
+ * when the context has only one child. This is done because there is no generic method to
+ * combine the results of the context children. In all other cases null is returned.
+ */
+ override def visitChildren(node: RuleNode): AnyRef = {
+ if (node.getChildCount == 1) {
+ node.getChild(0).accept(this)
+ } else {
+ null
+ }
+ }
+
+ /**
+ * Create an aliased table reference. This is typically used in FROM clauses.
+ */
+ override def visitTableName(ctx: TableNameContext): LogicalPlan = withOrigin(ctx) {
+ val tableId = visitMultipartIdentifier(ctx.multipartIdentifier())
+ val relation = UnresolvedRelation(tableId)
+ val table = mayApplyAliasPlan(
+ ctx.tableAlias, relation.optionalMap(ctx.temporalClause)(withTimeTravel))
+ table.optionalMap(ctx.sample)(withSample)
+ }
+
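+ /**
+ * Builds a [[TimeTravelRelation]] from the temporal clause: the version comes from the integer
+ * or string literal after VERSION/SYSTEM_VERSION AS OF, and the timestamp from the expression
+ * after TIMESTAMP/SYSTEM_TIME AS OF, which must not reference columns or contain subqueries.
+ */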
+ private def withTimeTravel(
+ ctx: TemporalClauseContext, plan: LogicalPlan): LogicalPlan = withOrigin(ctx) {
+ val v = ctx.version
+ val version = if (ctx.INTEGER_VALUE != null) {
+ Some(v.getText)
+ } else {
+ Option(v).map(string)
+ }
+
+ val timestamp = Option(ctx.timestamp).map(expression)
+ if (timestamp.exists(_.references.nonEmpty)) {
+ throw new ParseException(
+ "timestamp expression cannot refer to any columns", ctx.timestamp)
+ }
+ if (timestamp.exists(e => SubqueryExpression.hasSubquery(e))) {
+ throw new ParseException(
+ "timestamp expression cannot contain subqueries", ctx.timestamp)
+ }
+
+ TimeTravelRelation(plan, timestamp, version)
+ }
+
+ // ============== The following code is forked from org.apache.spark.sql.catalyst.parser.AstBuilder
+ override def visitSingleStatement(ctx: SingleStatementContext): LogicalPlan = withOrigin(ctx) {
+ visit(ctx.statement).asInstanceOf[LogicalPlan]
+ }
+
+ override def visitSingleExpression(ctx: SingleExpressionContext): Expression = withOrigin(ctx) {
+ visitNamedExpression(ctx.namedExpression)
+ }
+
+ override def visitSingleTableIdentifier(
+ ctx: SingleTableIdentifierContext): TableIdentifier = withOrigin(ctx) {
+ visitTableIdentifier(ctx.tableIdentifier)
+ }
+
+ override def visitSingleFunctionIdentifier(
+ ctx: SingleFunctionIdentifierContext): FunctionIdentifier = withOrigin(ctx) {
+ visitFunctionIdentifier(ctx.functionIdentifier)
+ }
+
+ override def visitSingleMultipartIdentifier(
+ ctx: SingleMultipartIdentifierContext): Seq[String] = withOrigin(ctx) {
+ visitMultipartIdentifier(ctx.multipartIdentifier)
+ }
+
+ override def visitSingleDataType(ctx: SingleDataTypeContext): DataType = withOrigin(ctx) {
+ typedVisit[DataType](ctx.dataType)
+ }
+
+ override def visitSingleTableSchema(ctx: SingleTableSchemaContext): StructType = {
+ val schema = StructType(visitColTypeList(ctx.colTypeList))
+ withOrigin(ctx)(schema)
+ }
+
+ /* ********************************************************************************************
+ * Plan parsing
+ * ******************************************************************************************** */
+ protected def plan(tree: ParserRuleContext): LogicalPlan = typedVisit(tree)
+
+ /**
+ * Create a top-level plan with Common Table Expressions.
+ */
+ override def visitQuery(ctx: QueryContext): LogicalPlan = withOrigin(ctx) {
+ val query = plan(ctx.queryTerm).optionalMap(ctx.queryOrganization)(withQueryResultClauses)
+
+ // Apply CTEs
+ query.optionalMap(ctx.ctes)(withCTE)
+ }
+
+ override def visitDmlStatement(ctx: DmlStatementContext): AnyRef = withOrigin(ctx) {
+ val dmlStmt = plan(ctx.dmlStatementNoWith)
+ // Apply CTEs
+ dmlStmt.optionalMap(ctx.ctes)(withCTE)
+ }
+
+ private def withCTE(ctx: CtesContext, plan: LogicalPlan): LogicalPlan = {
+ val ctes = ctx.namedQuery.asScala.map { nCtx =>
+ val namedQuery = visitNamedQuery(nCtx)
+ (namedQuery.alias, namedQuery)
+ }
+ // Check for duplicate names.
+ val duplicates = ctes.groupBy(_._1).filter(_._2.size > 1).keys
+ if (duplicates.nonEmpty) {
+ throw duplicateCteDefinitionNamesError(
+ duplicates.mkString("'", "', '", "'"), ctx)
+ }
+ UnresolvedWith(plan, ctes.toSeq)
+ }
+
+ /**
+ * Create a logical query plan for a hive-style FROM statement body.
+ */
+ private def withFromStatementBody(
+ ctx: FromStatementBodyContext, plan: LogicalPlan): LogicalPlan = withOrigin(ctx) {
+ // two cases for transforms and selects
+ if (ctx.transformClause != null) {
+ withTransformQuerySpecification(
+ ctx,
+ ctx.transformClause,
+ ctx.lateralView,
+ ctx.whereClause,
+ ctx.aggregationClause,
+ ctx.havingClause,
+ ctx.windowClause,
+ plan
+ )
+ } else {
+ withSelectQuerySpecification(
+ ctx,
+ ctx.selectClause,
+ ctx.lateralView,
+ ctx.whereClause,
+ ctx.aggregationClause,
+ ctx.havingClause,
+ ctx.windowClause,
+ plan
+ )
+ }
+ }
+
+ override def visitFromStatement(ctx: FromStatementContext): LogicalPlan = withOrigin(ctx) {
+ val from = visitFromClause(ctx.fromClause)
+ val selects = ctx.fromStatementBody.asScala.map { body =>
+ withFromStatementBody(body, from).
+ // Add organization statements.
+ optionalMap(body.queryOrganization)(withQueryResultClauses)
+ }
+ // If there are multiple SELECTs, just UNION them together into one query.
+ if (selects.length == 1) {
+ selects.head
+ } else {
+ Union(selects.toSeq)
+ }
+ }
+
+ /**
+ * Create a named logical plan.
+ *
+ * This is only used for Common Table Expressions.
+ */
+ override def visitNamedQuery(ctx: NamedQueryContext): SubqueryAlias = withOrigin(ctx) {
+ val subQuery: LogicalPlan = plan(ctx.query).optionalMap(ctx.columnAliases)(
+ (columnAliases, plan) =>
+ UnresolvedSubqueryColumnAliases(visitIdentifierList(columnAliases), plan)
+ )
+ SubqueryAlias(ctx.name.getText, subQuery)
+ }
+
+ /**
+ * Create a logical plan which allows for multiple inserts using one 'from' statement. These
+ * queries have the following SQL form:
+ * {{{
+ * [WITH cte...]?
+ * FROM src
+ * [INSERT INTO tbl1 SELECT *]+
+ * }}}
+ * For example:
+ * {{{
+ * FROM db.tbl1 A
+ * INSERT INTO dbo.tbl1 SELECT * WHERE A.value = 10 LIMIT 5
+ * INSERT INTO dbo.tbl2 SELECT * WHERE A.value = 12
+ * }}}
+ * This (Hive) feature cannot be combined with set-operators.
+ */
+ override def visitMultiInsertQuery(ctx: MultiInsertQueryContext): LogicalPlan = withOrigin(ctx) {
+ val from = visitFromClause(ctx.fromClause)
+
+ // Build the insert clauses.
+ val inserts = ctx.multiInsertQueryBody.asScala.map { body =>
+ withInsertInto(body.insertInto,
+ withFromStatementBody(body.fromStatementBody, from).
+ optionalMap(body.fromStatementBody.queryOrganization)(withQueryResultClauses))
+ }
+
+ // If there are multiple INSERTs, just UNION them together into one query.
+ if (inserts.length == 1) {
+ inserts.head
+ } else {
+ Union(inserts.toSeq)
+ }
+ }
+
+ /**
+ * Create a logical plan for a regular (single-insert) query.
+ */
+ override def visitSingleInsertQuery(
+ ctx: SingleInsertQueryContext): LogicalPlan = withOrigin(ctx) {
+ withInsertInto(
+ ctx.insertInto(),
+ plan(ctx.queryTerm).optionalMap(ctx.queryOrganization)(withQueryResultClauses))
+ }
+
+ /**
+ * Parameters used for writing query to a table:
+ * (UnresolvedRelation, tableColumnList, partitionKeys, ifPartitionNotExists).
+ */
+ type InsertTableParams = (UnresolvedRelation, Seq[String], Map[String, Option[String]], Boolean)
+
+ /**
+ * Parameters used for writing query to a directory: (isLocal, CatalogStorageFormat, provider).
+ */
+ type InsertDirParams = (Boolean, CatalogStorageFormat, Option[String])
+
+ /**
+ * Add an
+ * {{{
+ * INSERT OVERWRITE TABLE tableIdentifier [partitionSpec [IF NOT EXISTS]]? [identifierList]
+ * INSERT INTO [TABLE] tableIdentifier [partitionSpec] [identifierList]
+ * INSERT OVERWRITE [LOCAL] DIRECTORY STRING [rowFormat] [createFileFormat]
+ * INSERT OVERWRITE [LOCAL] DIRECTORY [STRING] tableProvider [OPTIONS tablePropertyList]
+ * }}}
+ * operation to the logical plan.
+ */
+ private def withInsertInto(
+ ctx: InsertIntoContext,
+ query: LogicalPlan): LogicalPlan = withOrigin(ctx) {
+ ctx match {
+ case table: InsertIntoTableContext =>
+ val (relation, cols, partition, ifPartitionNotExists) = visitInsertIntoTable(table)
+ InsertIntoStatement(
+ relation,
+ partition,
+ cols,
+ query,
+ overwrite = false,
+ ifPartitionNotExists)
+ case table: InsertOverwriteTableContext =>
+ val (relation, cols, partition, ifPartitionNotExists) = visitInsertOverwriteTable(table)
+ InsertIntoStatement(
+ relation,
+ partition,
+ cols,
+ query,
+ overwrite = true,
+ ifPartitionNotExists)
+ case dir: InsertOverwriteDirContext =>
+ val (isLocal, storage, provider) = visitInsertOverwriteDir(dir)
+ InsertIntoDir(isLocal, storage, provider, query, overwrite = true)
+ case hiveDir: InsertOverwriteHiveDirContext =>
+ val (isLocal, storage, provider) = visitInsertOverwriteHiveDir(hiveDir)
+ InsertIntoDir(isLocal, storage, provider, query, overwrite = true)
+ case _ =>
+ throw invalidInsertIntoError(ctx)
+ }
+ }
+
+ /**
+ * Add an INSERT INTO TABLE operation to the logical plan.
+ */
+ override def visitInsertIntoTable(
+ ctx: InsertIntoTableContext): InsertTableParams = withOrigin(ctx) {
+ val cols = Option(ctx.identifierList()).map(visitIdentifierList).getOrElse(Nil)
+ val partitionKeys = Option(ctx.partitionSpec).map(visitPartitionSpec).getOrElse(Map.empty)
+
+ if (ctx.EXISTS != null) {
+ operationNotAllowed("INSERT INTO ... IF NOT EXISTS", ctx)
+ }
+
+ (createUnresolvedRelation(ctx.multipartIdentifier), cols, partitionKeys, false)
+ }
+
+ /**
+ * Add an INSERT OVERWRITE TABLE operation to the logical plan.
+ */
+ override def visitInsertOverwriteTable(
+ ctx: InsertOverwriteTableContext): InsertTableParams = withOrigin(ctx) {
+ assert(ctx.OVERWRITE() != null)
+ val cols = Option(ctx.identifierList()).map(visitIdentifierList).getOrElse(Nil)
+ val partitionKeys = Option(ctx.partitionSpec).map(visitPartitionSpec).getOrElse(Map.empty)
+
+ val dynamicPartitionKeys: Map[String, Option[String]] = partitionKeys.filter(_._2.isEmpty)
+ if (ctx.EXISTS != null && dynamicPartitionKeys.nonEmpty) {
+ operationNotAllowed("IF NOT EXISTS with dynamic partitions: " +
+ dynamicPartitionKeys.keys.mkString(", "), ctx)
+ }
+
+ (createUnresolvedRelation(ctx.multipartIdentifier), cols, partitionKeys, ctx.EXISTS() != null)
+ }
+
+ /**
+ * Write to a directory, returning a [[InsertIntoDir]] logical plan.
+ */
+ override def visitInsertOverwriteDir(
+ ctx: InsertOverwriteDirContext): InsertDirParams = withOrigin(ctx) {
+ throw insertOverwriteDirectoryUnsupportedError(ctx)
+ }
+
+ /**
+ * Write to a directory, returning a [[InsertIntoDir]] logical plan.
+ */
+ override def visitInsertOverwriteHiveDir(
+ ctx: InsertOverwriteHiveDirContext): InsertDirParams = withOrigin(ctx) {
+ throw insertOverwriteDirectoryUnsupportedError(ctx)
+ }
+
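+ /**
+ * Return the plain table alias from a table alias context, if any. Column aliases are rejected
+ * here because DELETE/UPDATE/MERGE only accept a table-level alias.
+ */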
+ private def getTableAliasWithoutColumnAlias(
+ ctx: TableAliasContext, op: String): Option[String] = {
+ if (ctx == null) {
+ None
+ } else {
+ val ident = ctx.strictIdentifier()
+ if (ctx.identifierList() != null) {
+ throw columnAliasInOperationNotAllowedError(op, ctx)
+ }
+ if (ident != null) Some(ident.getText) else None
+ }
+ }
+
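+ /**
+ * Create a [[DeleteFromTable]] logical plan, e.g. (illustrative) `DELETE FROM tbl t WHERE t.id = 1`.
+ */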
+ override def visitDeleteFromTable(
+ ctx: DeleteFromTableContext): LogicalPlan = withOrigin(ctx) {
+ val table = createUnresolvedRelation(ctx.multipartIdentifier())
+ val tableAlias = getTableAliasWithoutColumnAlias(ctx.tableAlias(), "DELETE")
+ val aliasedTable = tableAlias.map(SubqueryAlias(_, table)).getOrElse(table)
+ val predicate = if (ctx.whereClause() != null) {
+ Some(expression(ctx.whereClause().booleanExpression()))
+ } else {
+ None
+ }
+ DeleteFromTable(aliasedTable, predicate)
+ }
+
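+ /**
+ * Create an [[UpdateTable]] logical plan, e.g. (illustrative)
+ * `UPDATE tbl t SET t.price = t.price * 2 WHERE t.id = 1`.
+ */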
+ override def visitUpdateTable(ctx: UpdateTableContext): LogicalPlan = withOrigin(ctx) {
+ val table = createUnresolvedRelation(ctx.multipartIdentifier())
+ val tableAlias = getTableAliasWithoutColumnAlias(ctx.tableAlias(), "UPDATE")
+ val aliasedTable = tableAlias.map(SubqueryAlias(_, table)).getOrElse(table)
+ val assignments = withAssignments(ctx.setClause().assignmentList())
+ val predicate = if (ctx.whereClause() != null) {
+ Some(expression(ctx.whereClause().booleanExpression()))
+ } else {
+ None
+ }
+
+ UpdateTable(aliasedTable, assignments, predicate)
+ }
+
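+ /**
+ * Convert the assignment list of a SET clause (e.g. `SET a.b = 1, c = 'x'`) into [[Assignment]]s
+ * of unresolved attributes to expressions.
+ */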
+ private def withAssignments(assignCtx: AssignmentListContext): Seq[Assignment] =
+ withOrigin(assignCtx) {
+ assignCtx.assignment().asScala.map { assign =>
+ Assignment(UnresolvedAttribute(visitMultipartIdentifier(assign.key)),
+ expression(assign.value))
+ }.toSeq
+ }
+
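+ /**
+ * Create a [[MergeIntoTable]] logical plan. A sketch of the accepted shape (illustrative):
+ * {{{
+ * MERGE INTO target t USING source s ON t.id = s.id
+ * WHEN MATCHED [AND cond] THEN UPDATE SET ... | DELETE
+ * WHEN NOT MATCHED [AND cond] THEN INSERT ...
+ * }}}
+ * At least one WHEN clause is required, and only the last matched/not-matched clause may omit
+ * its condition.
+ */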
+ override def visitMergeIntoTable(ctx: MergeIntoTableContext): LogicalPlan = withOrigin(ctx) {
+ val targetTable = createUnresolvedRelation(ctx.target)
+ val targetTableAlias = getTableAliasWithoutColumnAlias(ctx.targetAlias, "MERGE")
+ val aliasedTarget = targetTableAlias.map(SubqueryAlias(_, targetTable)).getOrElse(targetTable)
+
+ val sourceTableOrQuery = if (ctx.source != null) {
+ createUnresolvedRelation(ctx.source)
+ } else if (ctx.sourceQuery != null) {
+ visitQuery(ctx.sourceQuery)
+ } else {
+ throw emptySourceForMergeError(ctx)
+ }
+ val sourceTableAlias = getTableAliasWithoutColumnAlias(ctx.sourceAlias, "MERGE")
+ val aliasedSource =
+ sourceTableAlias.map(SubqueryAlias(_, sourceTableOrQuery)).getOrElse(sourceTableOrQuery)
+
+ val mergeCondition = expression(ctx.mergeCondition)
+
+ val matchedActions = ctx.matchedClause().asScala.map {
+ clause => {
+ if (clause.matchedAction().DELETE() != null) {
+ DeleteAction(Option(clause.matchedCond).map(expression))
+ } else if (clause.matchedAction().UPDATE() != null) {
+ val condition = Option(clause.matchedCond).map(expression)
+ if (clause.matchedAction().ASTERISK() != null) {
+ UpdateStarAction(condition)
+ } else {
+ UpdateAction(condition, withAssignments(clause.matchedAction().assignmentList()))
+ }
+ } else {
+ // We should never reach here.
+ throw unrecognizedMatchedActionError(clause)
+ }
+ }
+ }
+ val notMatchedActions = ctx.notMatchedClause().asScala.map {
+ clause => {
+ if (clause.notMatchedAction().INSERT() != null) {
+ val condition = Option(clause.notMatchedCond).map(expression)
+ if (clause.notMatchedAction().ASTERISK() != null) {
+ InsertStarAction(condition)
+ } else {
+ val columns = clause.notMatchedAction().columns.multipartIdentifier()
+ .asScala.map(attr => UnresolvedAttribute(visitMultipartIdentifier(attr)))
+ val values = clause.notMatchedAction().expression().asScala.map(expression)
+ if (columns.size != values.size) {
+ throw insertedValueNumberNotMatchFieldNumberError(clause)
+ }
+ InsertAction(condition, columns.zip(values).map(kv => Assignment(kv._1, kv._2)).toSeq)
+ }
+ } else {
+ // We should never reach here.
+ throw unrecognizedNotMatchedActionError(clause)
+ }
+ }
+ }
+ if (matchedActions.isEmpty && notMatchedActions.isEmpty) {
+ throw mergeStatementWithoutWhenClauseError(ctx)
+ }
+ // An action without a condition applies unconditionally, so only the last matched/not-matched clause may omit it.
+ val matchedActionSize = matchedActions.length
+ if (matchedActionSize >= 2 && !matchedActions.init.forall(_.condition.nonEmpty)) {
+ throw nonLastMatchedClauseOmitConditionError(ctx)
+ }
+ val notMatchedActionSize = notMatchedActions.length
+ if (notMatchedActionSize >= 2 && !notMatchedActions.init.forall(_.condition.nonEmpty)) {
+ throw nonLastNotMatchedClauseOmitConditionError(ctx)
+ }
+
+ MergeIntoTable(
+ aliasedTarget,
+ aliasedSource,
+ mergeCondition,
+ matchedActions.toSeq,
+ notMatchedActions.toSeq)
+ }
+
+ /**
+ * Create a partition specification map.
+ */
+ override def visitPartitionSpec(
+ ctx: PartitionSpecContext): Map[String, Option[String]] = withOrigin(ctx) {
+ val legacyNullAsString =
+ conf.getConf(SQLConf.LEGACY_PARSE_NULL_PARTITION_SPEC_AS_STRING_LITERAL)
+ val parts = ctx.partitionVal.asScala.map { pVal =>
+ val name = pVal.identifier.getText
+ val value = Option(pVal.constant).map(v => visitStringConstant(v, legacyNullAsString))
+ name -> value
+ }
+ // Before calling `toMap`, we check for duplicated keys to avoid silently ignoring partition
+ // values in a partition spec like PARTITION(a='1', b='2', a='3'). The real semantic check for
+ // partition columns will be done in the analyzer.
+ if (conf.caseSensitiveAnalysis) {
+ checkDuplicateKeys(parts.toSeq, ctx)
+ } else {
+ checkDuplicateKeys(parts.map(kv => kv._1.toLowerCase(Locale.ROOT) -> kv._2).toSeq, ctx)
+ }
+ parts.toMap
+ }
+
+ /**
+ * Create a partition specification map without optional values.
+ */
+ protected def visitNonOptionalPartitionSpec(
+ ctx: PartitionSpecContext): Map[String, String] = withOrigin(ctx) {
+ visitPartitionSpec(ctx).map {
+ case (key, None) => throw emptyPartitionKeyError(key, ctx)
+ case (key, Some(value)) => key -> value
+ }
+ }
+
+ /**
+ * Convert a constant of any type into a string. This is typically used in DDL commands, and its
+ * main purpose is to prevent slight differences due to back-to-back conversions, i.e.:
+ * String -> Literal -> String.
+ */
+ protected def visitStringConstant(
+ ctx: ConstantContext,
+ legacyNullAsString: Boolean): String = withOrigin(ctx) {
+ expression(ctx) match {
+ case Literal(null, _) if !legacyNullAsString => null
+ case l @ Literal(null, _) => l.toString
+ case l: Literal =>
+ // TODO For v2 commands, we will cast the string back to its actual value,
+ // which is a waste and can be improved in the future.
+ Cast(l, StringType, Some(conf.sessionLocalTimeZone)).eval().toString
+ case other =>
+ throw new IllegalArgumentException(s"Only literals are allowed in the " +
+ s"partition spec, but got ${other.sql}")
+ }
+ }
+
+ /**
+ * Add ORDER BY/SORT BY/CLUSTER BY/DISTRIBUTE BY/LIMIT/WINDOWS clauses to the logical plan. These
+ * clauses determine the shape (ordering/partitioning/rows) of the query result.
+ */
+ private def withQueryResultClauses(
+ ctx: QueryOrganizationContext,
+ query: LogicalPlan): LogicalPlan = withOrigin(ctx) {
+ import ctx._
+
+ // Handle ORDER BY, SORT BY, DISTRIBUTE BY, and CLUSTER BY clauses.
+ val withOrder = if (
+ !order.isEmpty && sort.isEmpty && distributeBy.isEmpty && clusterBy.isEmpty) {
+ // ORDER BY ...
+ Sort(order.asScala.map(visitSortItem).toSeq, global = true, query)
+ } else if (order.isEmpty && !sort.isEmpty && distributeBy.isEmpty && clusterBy.isEmpty) {
+ // SORT BY ...
+ Sort(sort.asScala.map(visitSortItem).toSeq, global = false, query)
+ } else if (order.isEmpty && sort.isEmpty && !distributeBy.isEmpty && clusterBy.isEmpty) {
+ // DISTRIBUTE BY ...
+ withRepartitionByExpression(ctx, expressionList(distributeBy), query)
+ } else if (order.isEmpty && !sort.isEmpty && !distributeBy.isEmpty && clusterBy.isEmpty) {
+ // SORT BY ... DISTRIBUTE BY ...
+ Sort(
+ sort.asScala.map(visitSortItem).toSeq,
+ global = false,
+ withRepartitionByExpression(ctx, expressionList(distributeBy), query))
+ } else if (order.isEmpty && sort.isEmpty && distributeBy.isEmpty && !clusterBy.isEmpty) {
+ // CLUSTER BY ...
+ val expressions = expressionList(clusterBy)
+ Sort(
+ expressions.map(SortOrder(_, Ascending)),
+ global = false,
+ withRepartitionByExpression(ctx, expressions, query))
+ } else if (order.isEmpty && sort.isEmpty && distributeBy.isEmpty && clusterBy.isEmpty) {
+ // [EMPTY]
+ query
+ } else {
+ throw combinationQueryResultClausesUnsupportedError(ctx)
+ }
+
+ // WINDOWS
+ val withWindow = withOrder.optionalMap(windowClause)(withWindowClause)
+
+ // LIMIT
+ // - LIMIT ALL is the same as omitting the LIMIT clause
+ withWindow.optional(limit) {
+ Limit(typedVisit(limit), withWindow)
+ }
+ }
+
+ /**
+ * Create a clause for DISTRIBUTE BY.
+ */
+ protected def withRepartitionByExpression(
+ ctx: QueryOrganizationContext,
+ expressions: Seq[Expression],
+ query: LogicalPlan): LogicalPlan = {
+ throw distributeByUnsupportedError(ctx)
+ }
+
+ override def visitTransformQuerySpecification(
+ ctx: TransformQuerySpecificationContext): LogicalPlan = withOrigin(ctx) {
+ val from = OneRowRelation().optional(ctx.fromClause) {
+ visitFromClause(ctx.fromClause)
+ }
+ withTransformQuerySpecification(
+ ctx,
+ ctx.transformClause,
+ ctx.lateralView,
+ ctx.whereClause,
+ ctx.aggregationClause,
+ ctx.havingClause,
+ ctx.windowClause,
+ from
+ )
+ }
+
+ override def visitRegularQuerySpecification(
+ ctx: RegularQuerySpecificationContext): LogicalPlan = withOrigin(ctx) {
+ val from = OneRowRelation().optional(ctx.fromClause) {
+ visitFromClause(ctx.fromClause)
+ }
+ withSelectQuerySpecification(
+ ctx,
+ ctx.selectClause,
+ ctx.lateralView,
+ ctx.whereClause,
+ ctx.aggregationClause,
+ ctx.havingClause,
+ ctx.windowClause,
+ from
+ )
+ }
+
+ override def visitNamedExpressionSeq(
+ ctx: NamedExpressionSeqContext): Seq[Expression] = {
+ Option(ctx).toSeq
+ .flatMap(_.namedExpression.asScala)
+ .map(typedVisit[Expression])
+ }
+
+ override def visitExpressionSeq(ctx: ExpressionSeqContext): Seq[Expression] = {
+ Option(ctx).toSeq
+ .flatMap(_.expression.asScala)
+ .map(typedVisit[Expression])
+ }
+
+ /**
+ * Create a logical plan using a having clause.
+ */
+ private def withHavingClause(
+ ctx: HavingClauseContext, plan: LogicalPlan): LogicalPlan = {
+ // Note that we add a cast to non-predicate expressions. If the expression itself is
+ // already boolean, the optimizer will get rid of the unnecessary cast.
+ val predicate = expression(ctx.booleanExpression) match {
+ case p: Predicate => p
+ case e => Cast(e, BooleanType)
+ }
+ UnresolvedHaving(predicate, plan)
+ }
+
+ /**
+ * Create a logical plan using a where clause.
+ */
+ private def withWhereClause(ctx: WhereClauseContext, plan: LogicalPlan): LogicalPlan = {
+ Filter(expression(ctx.booleanExpression), plan)
+ }
+
+ /**
+ * Add a hive-style transform (SELECT TRANSFORM/MAP/REDUCE) query specification to a logical plan.
+ */
+ private def withTransformQuerySpecification(
+ ctx: ParserRuleContext,
+ transformClause: TransformClauseContext,
+ lateralView: java.util.List[LateralViewContext],
+ whereClause: WhereClauseContext,
+ aggregationClause: AggregationClauseContext,
+ havingClause: HavingClauseContext,
+ windowClause: WindowClauseContext,
+ relation: LogicalPlan): LogicalPlan = withOrigin(ctx) {
+ if (transformClause.setQuantifier != null) {
+ throw transformNotSupportQuantifierError(transformClause.setQuantifier)
+ }
+ // Create the attributes.
+ val (attributes, schemaLess) = if (transformClause.colTypeList != null) {
+ // Typed return columns.
+ (createSchema(transformClause.colTypeList).toAttributes, false)
+ } else if (transformClause.identifierSeq != null) {
+ // Untyped return columns.
+ val attrs = visitIdentifierSeq(transformClause.identifierSeq).map { name =>
+ AttributeReference(name, StringType, nullable = true)()
+ }
+ (attrs, false)
+ } else {
+ (Seq(AttributeReference("key", StringType)(),
+ AttributeReference("value", StringType)()), true)
+ }
+
+ val plan = visitCommonSelectQueryClausePlan(
+ relation,
+ visitExpressionSeq(transformClause.expressionSeq),
+ lateralView,
+ whereClause,
+ aggregationClause,
+ havingClause,
+ windowClause,
+ isDistinct = false)
+
+ ScriptTransformation(
+ string(transformClause.script),
+ attributes,
+ plan,
+ withScriptIOSchema(
+ ctx,
+ transformClause.inRowFormat,
+ transformClause.recordWriter,
+ transformClause.outRowFormat,
+ transformClause.recordReader,
+ schemaLess
+ )
+ )
+ }
+
+ /**
+ * Add a regular (SELECT) query specification to a logical plan. The query specification
+ * is the core of the logical plan; this is where sourcing (FROM clause), projection (SELECT),
+ * aggregation (GROUP BY ... HAVING ...) and filtering (WHERE) take place.
+ *
+ * Note that query hints are ignored (both by the parser and the builder).
+ */
+ private def withSelectQuerySpecification(
+ ctx: ParserRuleContext,
+ selectClause: SelectClauseContext,
+ lateralView: java.util.List[LateralViewContext],
+ whereClause: WhereClauseContext,
+ aggregationClause: AggregationClauseContext,
+ havingClause: HavingClauseContext,
+ windowClause: WindowClauseContext,
+ relation: LogicalPlan): LogicalPlan = withOrigin(ctx) {
+ val isDistinct = selectClause.setQuantifier() != null &&
+ selectClause.setQuantifier().DISTINCT() != null
+
+ val plan = visitCommonSelectQueryClausePlan(
+ relation,
+ visitNamedExpressionSeq(selectClause.namedExpressionSeq),
+ lateralView,
+ whereClause,
+ aggregationClause,
+ havingClause,
+ windowClause,
+ isDistinct)
+
+ // Hint
+ selectClause.hints.asScala.foldRight(plan)(withHints)
+ }
+
+ def visitCommonSelectQueryClausePlan(
+ relation: LogicalPlan,
+ expressions: Seq[Expression],
+ lateralView: java.util.List[LateralViewContext],
+ whereClause: WhereClauseContext,
+ aggregationClause: AggregationClauseContext,
+ havingClause: HavingClauseContext,
+ windowClause: WindowClauseContext,
+ isDistinct: Boolean): LogicalPlan = {
+ // Add lateral views.
+ val withLateralView = lateralView.asScala.foldLeft(relation)(withGenerate)
+
+ // Add where.
+ val withFilter = withLateralView.optionalMap(whereClause)(withWhereClause)
+
+ // Add aggregation or a project.
+ val namedExpressions = expressions.map {
+ case e: NamedExpression => e
+ case e: Expression => UnresolvedAlias(e)
+ }
+
+ def createProject() = if (namedExpressions.nonEmpty) {
+ Project(namedExpressions, withFilter)
+ } else {
+ withFilter
+ }
+
+ val withProject = if (aggregationClause == null && havingClause != null) {
+ if (conf.getConf(SQLConf.LEGACY_HAVING_WITHOUT_GROUP_BY_AS_WHERE)) {
+ // If the legacy conf is set, treat HAVING without GROUP BY as WHERE.
+ val predicate = expression(havingClause.booleanExpression) match {
+ case p: Predicate => p
+ case e => Cast(e, BooleanType)
+ }
+ Filter(predicate, createProject())
+ } else {
+ // According to SQL standard, HAVING without GROUP BY means global aggregate.
+ withHavingClause(havingClause, Aggregate(Nil, namedExpressions, withFilter))
+ }
+ } else if (aggregationClause != null) {
+ val aggregate = withAggregationClause(aggregationClause, namedExpressions, withFilter)
+ aggregate.optionalMap(havingClause)(withHavingClause)
+ } else {
+ // When hitting this branch, `having` must be null.
+ createProject()
+ }
+
+ // Distinct
+ val withDistinct = if (isDistinct) {
+ Distinct(withProject)
+ } else {
+ withProject
+ }
+
+ // Window
+ val withWindow = withDistinct.optionalMap(windowClause)(withWindowClause)
+
+ withWindow
+ }
+
+ // Script Transform's input/output format.
+ type ScriptIOFormat =
+ (Seq[(String, String)], Option[String], Seq[(String, String)], Option[String])
+
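+ /**
+ * Translate a DELIMITED row format (FIELDS/COLLECTION ITEMS/MAP KEYS/LINES TERMINATED BY and
+ * NULL DEFINED AS) into the legacy token-name/value pairs expected by the script IO schema.
+ */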
+ protected def getRowFormatDelimited(ctx: RowFormatDelimitedContext): ScriptIOFormat = {
+ // TODO we should use the visitRowFormatDelimited function here. However, HiveScriptIOSchema
+ // expects a seq of pairs in which the old parsers' token names are used as keys.
+ // Transforming the result of visitRowFormatDelimited would be quite a bit messier than
+ // retrieving the key value pairs ourselves.
+ val entries = entry("TOK_TABLEROWFORMATFIELD", ctx.fieldsTerminatedBy) ++
+ entry("TOK_TABLEROWFORMATCOLLITEMS", ctx.collectionItemsTerminatedBy) ++
+ entry("TOK_TABLEROWFORMATMAPKEYS", ctx.keysTerminatedBy) ++
+ entry("TOK_TABLEROWFORMATNULL", ctx.nullDefinedAs) ++
+ Option(ctx.linesSeparatedBy).toSeq.map { token =>
+ val value = string(token)
+ validate(
+ value == "\n",
+ s"LINES TERMINATED BY only supports newline '\\n' right now: $value",
+ ctx)
+ "TOK_TABLEROWFORMATLINES" -> value
+ }
+
+ (entries, None, Seq.empty, None)
+ }
+
+ /**
+ * Create a [[ScriptInputOutputSchema]].
+ */
+ protected def withScriptIOSchema(
+ ctx: ParserRuleContext,
+ inRowFormat: RowFormatContext,
+ recordWriter: Token,
+ outRowFormat: RowFormatContext,
+ recordReader: Token,
+ schemaLess: Boolean): ScriptInputOutputSchema = {
+
+ def format(fmt: RowFormatContext): ScriptIOFormat = fmt match {
+ case c: RowFormatDelimitedContext =>
+ getRowFormatDelimited(c)
+
+ case c: RowFormatSerdeContext =>
+ throw transformWithSerdeUnsupportedError(ctx)
+
+ // SPARK-32106: When there is no format definition, we return an empty result
+ // to use a built-in default Serde in SparkScriptTransformationExec.
+ case null =>
+ (Nil, None, Seq.empty, None)
+ }
+
+ val (inFormat, inSerdeClass, inSerdeProps, reader) = format(inRowFormat)
+
+ val (outFormat, outSerdeClass, outSerdeProps, writer) = format(outRowFormat)
+
+ ScriptInputOutputSchema(
+ inFormat, outFormat,
+ inSerdeClass, outSerdeClass,
+ inSerdeProps, outSerdeProps,
+ reader, writer,
+ schemaLess)
+ }
+
+ /**
+ * Create a logical plan for a given 'FROM' clause. Note that we support multiple (comma
+ * separated) relations here; these get converted into a single plan by a condition-less inner join.
+ */
+ override def visitFromClause(ctx: FromClauseContext): LogicalPlan = withOrigin(ctx) {
+ val from = ctx.relation.asScala.foldLeft(null: LogicalPlan) { (left, relation) =>
+ val right = plan(relation.relationPrimary)
+ val join = right.optionalMap(left) { (left, right) =>
+ if (relation.LATERAL != null) {
+ if (!relation.relationPrimary.isInstanceOf[AliasedQueryContext]) {
+ throw invalidLateralJoinRelationError(relation.relationPrimary)
+ }
+ LateralJoin(left, LateralSubquery(right), Inner, None)
+ } else {
+ Join(left, right, Inner, None, JoinHint.NONE)
+ }
+ }
+ withJoinRelations(join, relation)
+ }
+ if (ctx.pivotClause() != null) {
+ if (!ctx.lateralView.isEmpty) {
+ throw lateralWithPivotInFromClauseNotAllowedError(ctx)
+ }
+ withPivot(ctx.pivotClause, from)
+ } else {
+ ctx.lateralView.asScala.foldLeft(from)(withGenerate)
+ }
+ }
+
+ /**
+ * Connect two queries by a Set operator.
+ *
+ * Supported Set operators are:
+ * - UNION [ DISTINCT | ALL ]
+ * - EXCEPT [ DISTINCT | ALL ]
+ * - MINUS [ DISTINCT | ALL ]
+ * - INTERSECT [DISTINCT | ALL]
+ */
+ override def visitSetOperation(ctx: SetOperationContext): LogicalPlan = withOrigin(ctx) {
+ val left = plan(ctx.left)
+ val right = plan(ctx.right)
+ val all = Option(ctx.setQuantifier()).exists(_.ALL != null)
+ ctx.operator.getType match {
+ case HoodieSqlBaseParser.UNION if all =>
+ Union(left, right)
+ case HoodieSqlBaseParser.UNION =>
+ Distinct(Union(left, right))
+ case HoodieSqlBaseParser.INTERSECT if all =>
+ Intersect(left, right, isAll = true)
+ case HoodieSqlBaseParser.INTERSECT =>
+ Intersect(left, right, isAll = false)
+ case HoodieSqlBaseParser.EXCEPT if all =>
+ Except(left, right, isAll = true)
+ case HoodieSqlBaseParser.EXCEPT =>
+ Except(left, right, isAll = false)
+ case HoodieSqlBaseParser.SETMINUS if all =>
+ Except(left, right, isAll = true)
+ case HoodieSqlBaseParser.SETMINUS =>
+ Except(left, right, isAll = false)
+ }
+ }
+
+ /**
+ * Add a [[WithWindowDefinition]] operator to a logical plan.
+ */
+ private def withWindowClause(
+ ctx: WindowClauseContext,
+ query: LogicalPlan): LogicalPlan = withOrigin(ctx) {
+ // Collect all window specifications defined in the WINDOW clause.
+ val baseWindowTuples = ctx.namedWindow.asScala.map {
+ wCtx =>
+ (wCtx.name.getText, typedVisit[WindowSpec](wCtx.windowSpec))
+ }
+ baseWindowTuples.groupBy(_._1).foreach { kv =>
+ if (kv._2.size > 1) {
+ throw repetitiveWindowDefinitionError(kv._1, ctx)
+ }
+ }
+ val baseWindowMap = baseWindowTuples.toMap
+
+ // Handle cases like
+ // window w1 as (partition by p_mfgr order by p_name
+ // range between 2 preceding and 2 following),
+ // w2 as w1
+ val windowMapView = baseWindowMap.mapValues {
+ case WindowSpecReference(name) =>
+ baseWindowMap.get(name) match {
+ case Some(spec: WindowSpecDefinition) =>
+ spec
+ case Some(ref) =>
+ throw invalidWindowReferenceError(name, ctx)
+ case None =>
+ throw cannotResolveWindowReferenceError(name, ctx)
+ }
+ case spec: WindowSpecDefinition => spec
+ }
+
+ // Note that mapValues creates a view instead of a materialized map. We force materialization by
+ // mapping over identity.
+ WithWindowDefinition(windowMapView.map(identity).toMap, query)
+ }
+
+ /**
+ * Add an [[Aggregate]] to a logical plan.
+ */
+ private def withAggregationClause(
+ ctx: AggregationClauseContext,
+ selectExpressions: Seq[NamedExpression],
+ query: LogicalPlan): LogicalPlan = withOrigin(ctx) {
+ if (ctx.groupingExpressionsWithGroupingAnalytics.isEmpty) {
+ val groupByExpressions = expressionList(ctx.groupingExpressions)
+ if (ctx.GROUPING != null) {
+ // GROUP BY ... GROUPING SETS (...)
+ // `groupByExpressions` can be non-empty for Hive compatibility. It may add extra grouping
+ // expressions that do not exist in GROUPING SETS (...), and the value is always null.
+ // For example, in `SELECT a, b, c FROM ... GROUP BY a, b, c GROUPING SETS (a, b)`, the output
+ // of column `c` is always null.
+ val groupingSets =
+ ctx.groupingSet.asScala.map(_.expression.asScala.map(e => expression(e)).toSeq)
+ Aggregate(Seq(GroupingSets(groupingSets.toSeq, groupByExpressions)),
+ selectExpressions, query)
+ } else {
+ // GROUP BY .... (WITH CUBE | WITH ROLLUP)?
+ val mappedGroupByExpressions = if (ctx.CUBE != null) {
+ Seq(Cube(groupByExpressions.map(Seq(_))))
+ } else if (ctx.ROLLUP != null) {
+ Seq(Rollup(groupByExpressions.map(Seq(_))))
+ } else {
+ groupByExpressions
+ }
+ Aggregate(mappedGroupByExpressions, selectExpressions, query)
+ }
+ } else {
+ val groupByExpressions =
+ ctx.groupingExpressionsWithGroupingAnalytics.asScala
+ .map(groupByExpr => {
+ val groupingAnalytics = groupByExpr.groupingAnalytics
+ if (groupingAnalytics != null) {
+ visitGroupingAnalytics(groupingAnalytics)
+ } else {
+ expression(groupByExpr.expression)
+ }
+ })
+ Aggregate(groupByExpressions.toSeq, selectExpressions, query)
+ }
+ }
+
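+ /**
+ * Build the [[BaseGroupingSets]] for CUBE / ROLLUP / GROUPING SETS, e.g. (illustrative)
+ * `GROUP BY CUBE(a, b)` or `GROUP BY GROUPING SETS ((a), (a, b))`. Empty grouping sets are
+ * rejected inside CUBE and ROLLUP.
+ */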
+ override def visitGroupingAnalytics(
+ groupingAnalytics: GroupingAnalyticsContext): BaseGroupingSets = {
+ val groupingSets = groupingAnalytics.groupingSet.asScala
+ .map(_.expression.asScala.map(e => expression(e)).toSeq)
+ if (groupingAnalytics.CUBE != null) {
+ // CUBE(A, B, (A, B), ()) is not supported.
+ if (groupingSets.exists(_.isEmpty)) {
+ throw invalidGroupingSetError("CUBE", groupingAnalytics)
+ }
+ Cube(groupingSets.toSeq)
+ } else if (groupingAnalytics.ROLLUP != null) {
+ // ROLLUP(A, B, (A, B), ()) is not supported.
+ if (groupingSets.exists(_.isEmpty)) {
+ throw invalidGroupingSetError("ROLLUP", groupingAnalytics)
+ }
+ Rollup(groupingSets.toSeq)
+ } else {
+ assert(groupingAnalytics.GROUPING != null && groupingAnalytics.SETS != null)
+ val groupingSets = groupingAnalytics.groupingElement.asScala.flatMap { expr =>
+ val groupingAnalytics = expr.groupingAnalytics()
+ if (groupingAnalytics != null) {
+ visitGroupingAnalytics(groupingAnalytics).selectedGroupByExprs
+ } else {
+ Seq(expr.groupingSet().expression().asScala.map(e => expression(e)).toSeq)
+ }
+ }
+ GroupingSets(groupingSets.toSeq)
+ }
+ }
+
+ /**
+ * Add [[UnresolvedHint]]s to a logical plan.
+ */
+ private def withHints(
+ ctx: HintContext,
+ query: LogicalPlan): LogicalPlan = withOrigin(ctx) {
+ var plan = query
+ ctx.hintStatements.asScala.reverse.foreach { stmt =>
+ plan = UnresolvedHint(stmt.hintName.getText,
+ stmt.parameters.asScala.map(expression).toSeq, plan)
+ }
+ plan
+ }
+
+ /**
+ * Add a [[Pivot]] to a logical plan.
+ */
+ private def withPivot(
+ ctx: PivotClauseContext,
+ query: LogicalPlan): LogicalPlan = withOrigin(ctx) {
+ val aggregates = Option(ctx.aggregates).toSeq
+ .flatMap(_.namedExpression.asScala)
+ .map(typedVisit[Expression])
+ val pivotColumn = if (ctx.pivotColumn.identifiers.size == 1) {
+ UnresolvedAttribute.quoted(ctx.pivotColumn.identifier.getText)
+ } else {
+ CreateStruct(
+ ctx.pivotColumn.identifiers.asScala.map(
+ identifier => UnresolvedAttribute.quoted(identifier.getText)).toSeq)
+ }
+ val pivotValues = ctx.pivotValues.asScala.map(visitPivotValue)
+ Pivot(None, pivotColumn, pivotValues.toSeq, aggregates, query)
+ }
+
+ /**
+ * Create a Pivot column value with or without an alias.
+ */
+ override def visitPivotValue(ctx: PivotValueContext): Expression = withOrigin(ctx) {
+ val e = expression(ctx.expression)
+ if (ctx.identifier != null) {
+ Alias(e, ctx.identifier.getText)()
+ } else {
+ e
+ }
+ }
+
+ /**
+ * Add a [[Generate]] (Lateral View) to a logical plan.
+ */
+ private def withGenerate(
+ query: LogicalPlan,
+ ctx: LateralViewContext): LogicalPlan = withOrigin(ctx) {
+ val expressions = expressionList(ctx.expression)
+ Generate(
+ UnresolvedGenerator(visitFunctionName(ctx.qualifiedName), expressions),
+ unrequiredChildIndex = Nil,
+ outer = ctx.OUTER != null,
+ // scalastyle:off caselocale
+ Some(ctx.tblName.getText.toLowerCase),
+ // scalastyle:on caselocale
+ ctx.colName.asScala.map(_.getText).map(UnresolvedAttribute.quoted).toSeq,
+ query)
+ }
+
+ /**
+ * Create a single relation referenced in a FROM clause. This method is used when a part of the
+ * join condition is nested, for example:
+ * {{{
+ * select * from t1 join (t2 cross join t3) on col1 = col2
+ * }}}
+ */
+ override def visitRelation(ctx: RelationContext): LogicalPlan = withOrigin(ctx) {
+ withJoinRelations(plan(ctx.relationPrimary), ctx)
+ }
+
+ /**
+ * Join one or more [[LogicalPlan]]s to the current logical plan.
+ */
+ private def withJoinRelations(base: LogicalPlan, ctx: RelationContext): LogicalPlan = {
+ ctx.joinRelation.asScala.foldLeft(base) { (left, join) =>
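+ /**
+ * Create a logical plan for a DML statement (INSERT/UPDATE/DELETE/MERGE), applying any leading
+ * WITH (CTE) clause.
+ */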
+ withOrigin(join) {
+ val baseJoinType = join.joinType match {
+ case null => Inner
+ case jt if jt.CROSS != null => Cross
+ case jt if jt.FULL != null => FullOuter
+ case jt if jt.SEMI != null => LeftSemi
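+ /**
+ * Wrap the plan in an [[UnresolvedWith]] carrying the named queries of a WITH clause, failing
+ * on duplicate CTE names (e.g. `WITH q AS (...), q AS (...)`).
+ */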
+ case jt if jt.ANTI != null => LeftAnti
+ case jt if jt.LEFT != null => LeftOuter
+ case jt if jt.RIGHT != null => RightOuter
+ case _ => Inner
+ }
+
+ if (join.LATERAL != null && !join.right.isInstanceOf[AliasedQueryContext]) {
+ throw invalidLateralJoinRelationError(join.right)
+ }
+
+ // Resolve the join type and join condition
+ val (joinType, condition) = Option(join.joinCriteria) match {
+ case Some(c) if c.USING != null =>
+ if (join.LATERAL != null) {
+ throw lateralJoinWithUsingJoinUnsupportedError(ctx)
+ }
+ (UsingJoin(baseJoinType, visitIdentifierList(c.identifierList)), None)
+ case Some(c) if c.booleanExpression != null =>
+ (baseJoinType, Option(expression(c.booleanExpression)))
+ case Some(c) =>
+ throw joinCriteriaUnimplementedError(c, ctx)
+ case None if join.NATURAL != null =>
+ if (join.LATERAL != null) {
+ throw lateralJoinWithNaturalJoinUnsupportedError(ctx)
+ }
+ if (baseJoinType == Cross) {
+ throw naturalCrossJoinUnsupportedError(ctx)
+ }
+ (NaturalJoin(baseJoinType), None)
+ case None =>
+ (baseJoinType, None)
+ }
+ if (join.LATERAL != null) {
+ if (!Seq(Inner, Cross, LeftOuter).contains(joinType)) {
+ throw unsupportedLateralJoinTypeError(ctx, joinType.toString)
+ }
+ LateralJoin(left, LateralSubquery(plan(join.right)), joinType, condition)
+ } else {
+ Join(left, plan(join.right), joinType, condition, JoinHint.NONE)
+ }
+ }
+ }
+ }
+
+ /**
+ * Add a [[Sample]] to a logical plan.
+ *
+ * This currently supports the following sampling methods:
+ * - TABLESAMPLE(x ROWS): Sample the table down to the given number of rows.
+ * - TABLESAMPLE(x PERCENT): Sample the table down to the given percentage. Note that percentages
+ * are defined as a number between 0 and 100.
+ * - TABLESAMPLE(BUCKET x OUT OF y): Sample the table down to an 'x' divided by 'y' fraction.
+ */
+ private def withSample(ctx: SampleContext, query: LogicalPlan): LogicalPlan = withOrigin(ctx) {
+ // Create a sampled plan if we need one.
+ def sample(fraction: Double): Sample = {
+ // The range of fraction accepted by Sample is [0, 1]. Because Hive's block sampling
+ // function takes X PERCENT as the input and the range of X is [0, 100], we need to
+ // adjust the fraction.
+ val eps = RandomSampler.roundingEpsilon
+ validate(fraction >= 0.0 - eps && fraction <= 1.0 + eps,
+ s"Sampling fraction ($fraction) must be on interval [0, 1]",
+ ctx)
+ Sample(0.0, fraction, withReplacement = false, (math.random * 1000).toInt, query)
+ }
+
+ if (ctx.sampleMethod() == null) {
+ throw emptyInputForTableSampleError(ctx)
+ }
+
+ ctx.sampleMethod() match {
+ case ctx: SampleByRowsContext =>
+ Limit(expression(ctx.expression), query)
+
+ case ctx: SampleByPercentileContext =>
+ val fraction = ctx.percentage.getText.toDouble
+ val sign = if (ctx.negativeSign == null) 1 else -1
+ sample(sign * fraction / 100.0d)
+
+ case ctx: SampleByBytesContext =>
+ val bytesStr = ctx.bytes.getText
+ if (bytesStr.matches("[0-9]+[bBkKmMgG]")) {
+ throw tableSampleByBytesUnsupportedError("byteLengthLiteral", ctx)
+ } else {
+ throw invalidByteLengthLiteralError(bytesStr, ctx)
+ }
+
+ case ctx: SampleByBucketContext if ctx.ON() != null =>
+ if (ctx.identifier != null) {
+ throw tableSampleByBytesUnsupportedError(
+ "BUCKET x OUT OF y ON colname", ctx)
+ } else {
+ throw tableSampleByBytesUnsupportedError(
+ "BUCKET x OUT OF y ON function", ctx)
+ }
+
+ case ctx: SampleByBucketContext =>
+ sample(ctx.numerator.getText.toDouble / ctx.denominator.getText.toDouble)
+ }
+ }
+
+ /**
+ * Create a logical plan for a sub-query.
+ */
+ override def visitSubquery(ctx: SubqueryContext): LogicalPlan = withOrigin(ctx) {
+ plan(ctx.query)
+ }
+
+ /**
+ * Create an un-aliased table reference. This is typically used for top-level table references,
+ * for example:
+ * {{{
+ * INSERT INTO db.tbl2
+ * TABLE db.tbl1
+ * }}}
+ */
+ override def visitTable(ctx: TableContext): LogicalPlan = withOrigin(ctx) {
+ UnresolvedRelation(visitMultipartIdentifier(ctx.multipartIdentifier))
+ }
+
+ /**
+ * Create a table-valued function call with arguments, e.g. range(1000)
+ */
+ override def visitTableValuedFunction(ctx: TableValuedFunctionContext)
+ : LogicalPlan = withOrigin(ctx) {
+ val func = ctx.functionTable
+ val aliases = if (func.tableAlias.identifierList != null) {
+ visitIdentifierList(func.tableAlias.identifierList)
+ } else {
+ Seq.empty
+ }
+ val name = getFunctionIdentifier(func.functionName)
+ if (name.database.nonEmpty) {
+ operationNotAllowed(s"table valued function cannot specify database name: $name", ctx)
+ }
+
+ val tvf = UnresolvedTableValuedFunction(
+ name, func.expression.asScala.map(expression).toSeq, aliases)
+ tvf.optionalMap(func.tableAlias.strictIdentifier)(aliasPlan)
+ }
+
+ /**
+ * Create an inline table (a virtual table in Hive parlance).
+ */
+ override def visitInlineTable(ctx: InlineTableContext): LogicalPlan = withOrigin(ctx) {
+ // Get the backing expressions.
+ val rows = ctx.expression.asScala.map { e =>
+ expression(e) match {
+ // inline table comes in two styles:
+ // style 1: values (1), (2), (3) -- multiple columns are supported
+ // style 2: values 1, 2, 3 -- only a single column is supported here
+ case struct: CreateNamedStruct => struct.valExprs // style 1
+ case child => Seq(child) // style 2
+ }
+ }
+
+ val aliases = if (ctx.tableAlias.identifierList != null) {
+ visitIdentifierList(ctx.tableAlias.identifierList)
+ } else {
+ Seq.tabulate(rows.head.size)(i => s"col${i + 1}")
+ }
+
+ val table = UnresolvedInlineTable(aliases, rows.toSeq)
+ table.optionalMap(ctx.tableAlias.strictIdentifier)(aliasPlan)
+ }
+
+ /**
+ * Create an alias (SubqueryAlias) for a join relation. This is practically the same as
+ * visitAliasedQuery and visitNamedExpression; ANTLR4, however, requires us to use 3 different
+ * hooks. We could add alias names for output columns, for example:
+ * {{{
+ * SELECT a, b, c, d FROM (src1 s1 INNER JOIN src2 s2 ON s1.id = s2.id) dst(a, b, c, d)
+ * }}}
+ */
+ override def visitAliasedRelation(ctx: AliasedRelationContext): LogicalPlan = withOrigin(ctx) {
+ val relation = plan(ctx.relation).optionalMap(ctx.sample)(withSample)
+ mayApplyAliasPlan(ctx.tableAlias, relation)
+ }
+
+ /**
+ * Create an alias (SubqueryAlias) for a sub-query. This is practically the same as
+ * visitAliasedRelation and visitNamedExpression; ANTLR4, however, requires us to use 3 different
+ * hooks. We could add alias names for output columns, for example:
+ * {{{
+ * SELECT col1, col2 FROM testData AS t(col1, col2)
+ * }}}
+ */
+ override def visitAliasedQuery(ctx: AliasedQueryContext): LogicalPlan = withOrigin(ctx) {
+ val relation = plan(ctx.query).optionalMap(ctx.sample)(withSample)
+ if (ctx.tableAlias.strictIdentifier == null) {
+ // For un-aliased subqueries, use a default alias name that is not likely to conflict with
+ // normal subquery names, so that parent operators can only access the columns in the subquery by
+ // unqualified names. Users can still use this special qualifier to access columns if they
+ // know it, but that's not recommended.
+ SubqueryAlias("__auto_generated_subquery_name", relation)
+ } else {
+ mayApplyAliasPlan(ctx.tableAlias, relation)
+ }
+ }
+
+ /**
+ * Create an alias ([[SubqueryAlias]]) for a [[LogicalPlan]].
+ */
+ private def aliasPlan(alias: ParserRuleContext, plan: LogicalPlan): LogicalPlan = {
+ SubqueryAlias(alias.getText, plan)
+ }
+
+ /**
+ * If aliases specified in a FROM clause, create a subquery alias ([[SubqueryAlias]]) and
+ * column aliases for a [[LogicalPlan]].
+ */
+ private def mayApplyAliasPlan(tableAlias: TableAliasContext, plan: LogicalPlan): LogicalPlan = {
+ if (tableAlias.strictIdentifier != null) {
+ val alias = tableAlias.strictIdentifier.getText
+ if (tableAlias.identifierList != null) {
+ val columnNames = visitIdentifierList(tableAlias.identifierList)
+ SubqueryAlias(alias, UnresolvedSubqueryColumnAliases(columnNames, plan))
+ } else {
+ SubqueryAlias(alias, plan)
+ }
+ } else {
+ plan
+ }
+ }
+
+ /**
+ * Create a Sequence of Strings for a parenthesis-enclosed alias list.
+ */
+ override def visitIdentifierList(ctx: IdentifierListContext): Seq[String] = withOrigin(ctx) {
+ visitIdentifierSeq(ctx.identifierSeq)
+ }
+
+ /**
+ * Create a Sequence of Strings for an identifier list.
+ */
+ override def visitIdentifierSeq(ctx: IdentifierSeqContext): Seq[String] = withOrigin(ctx) {
+ ctx.ident.asScala.map(_.getText).toSeq
+ }
+
+ /* ********************************************************************************************
+ * Table Identifier parsing
+ * ******************************************************************************************** */
+ /**
+ * Create a [[TableIdentifier]] from a 'tableName' or 'databaseName'.'tableName' pattern.
+ */
+ override def visitTableIdentifier(
+ ctx: TableIdentifierContext): TableIdentifier = withOrigin(ctx) {
+ TableIdentifier(ctx.table.getText, Option(ctx.db).map(_.getText))
+ }
+
+ /**
+ * Create a [[FunctionIdentifier]] from a 'functionName' or 'databaseName'.'functionName' pattern.
+ */
+ override def visitFunctionIdentifier(
+ ctx: FunctionIdentifierContext): FunctionIdentifier = withOrigin(ctx) {
+ FunctionIdentifier(ctx.function.getText, Option(ctx.db).map(_.getText))
+ }
+
+ /**
+ * Create a multi-part identifier.
+ */
+ override def visitMultipartIdentifier(ctx: MultipartIdentifierContext): Seq[String] =
+ withOrigin(ctx) {
+ ctx.parts.asScala.map(_.getText).toSeq
+ }
+
+ /* ********************************************************************************************
+ * Expression parsing
+ * ******************************************************************************************** */
+ /**
+ * Create an expression from the given context. This method just passes the context on to the
+ * visitor and only takes care of typing (We assume that the visitor returns an Expression here).
+ */
+ protected def expression(ctx: ParserRuleContext): Expression = typedVisit(ctx)
+
+ /**
+ * Create sequence of expressions from the given sequence of contexts.
+ */
+ private def expressionList(trees: java.util.List[ExpressionContext]): Seq[Expression] = {
+ trees.asScala.map(expression).toSeq
+ }
+
+ /**
+ * Create a star (i.e. all) expression; this selects all elements (in the specified object).
+ * Both un-targeted (global) and targeted aliases are supported.
+ */
+ override def visitStar(ctx: StarContext): Expression = withOrigin(ctx) {
+ UnresolvedStar(Option(ctx.qualifiedName()).map(_.identifier.asScala.map(_.getText).toSeq))
+ }
+
+ /**
+ * Create an aliased expression if an alias is specified. Both single and multi-aliases are
+ * supported.
+ */
+ override def visitNamedExpression(ctx: NamedExpressionContext): Expression = withOrigin(ctx) {
+ val e = expression(ctx.expression)
+ if (ctx.name != null) {
+ Alias(e, ctx.name.getText)()
+ } else if (ctx.identifierList != null) {
+ MultiAlias(e, visitIdentifierList(ctx.identifierList))
+ } else {
+ e
+ }
+ }
+
+ /**
+ * Combine a number of boolean expressions into a balanced expression tree. These expressions are
+ * either combined by a logical [[And]] or a logical [[Or]].
+ *
+ * A balanced binary tree is created because regular left-recursive trees cause considerable
+ * performance degradation and can cause stack overflows; e.g. `a AND b AND c AND d` is combined
+ * as `And(And(a, b), And(c, d))` instead of a deeply nested left spine.
+ */
+ override def visitLogicalBinary(ctx: LogicalBinaryContext): Expression = withOrigin(ctx) {
+ val expressionType = ctx.operator.getType
+ val expressionCombiner = expressionType match {
+ case HoodieSqlBaseParser.AND => And.apply _
+ case HoodieSqlBaseParser.OR => Or.apply _
+ }
+
+ // Collect all similar left hand contexts.
+ val contexts = ArrayBuffer(ctx.right)
+ var current = ctx.left
+ def collectContexts: Boolean = current match {
+ case lbc: LogicalBinaryContext if lbc.operator.getType == expressionType =>
+ contexts += lbc.right
+ current = lbc.left
+ true
+ case _ =>
+ contexts += current
+ false
+ }
+ while (collectContexts) {
+ // No body - all updates take place in collectContexts.
+ }
+
+ // Reverse the contexts to have them in the same sequence as in the SQL statement & turn them
+ // into expressions.
+ val expressions = contexts.reverseMap(expression)
+
+ // Create a balanced tree.
+ def reduceToExpressionTree(low: Int, high: Int): Expression = high - low match {
+ case 0 =>
+ expressions(low)
+ case 1 =>
+ expressionCombiner(expressions(low), expressions(high))
+ case x =>
+ val mid = low + x / 2
+ expressionCombiner(
+ reduceToExpressionTree(low, mid),
+ reduceToExpressionTree(mid + 1, high))
+ }
+ reduceToExpressionTree(0, expressions.size - 1)
+ }
+
+ /**
+ * Invert a boolean expression.
+ */
+ override def visitLogicalNot(ctx: LogicalNotContext): Expression = withOrigin(ctx) {
+ Not(expression(ctx.booleanExpression()))
+ }
+
+ /**
+ * Create a filtering correlated sub-query (EXISTS).
+ */
+ override def visitExists(ctx: ExistsContext): Expression = {
+ Exists(plan(ctx.query))
+ }
+
+ /**
+ * Create a comparison expression. This compares two expressions. The following comparison
+ * operators are supported:
+ * - Equal: '=' or '=='
+ * - Null-safe Equal: '<=>'
+ * - Not Equal: '<>' or '!='
+ * - Less than: '<'
+ * - Less than or Equal: '<='
+ * - Greater than: '>'
+ * - Greater than or Equal: '>='
+ */
+ override def visitComparison(ctx: ComparisonContext): Expression = withOrigin(ctx) {
+ val left = expression(ctx.left)
+ val right = expression(ctx.right)
+ val operator = ctx.comparisonOperator().getChild(0).asInstanceOf[TerminalNode]
+ operator.getSymbol.getType match {
+ case HoodieSqlBaseParser.EQ =>
+ EqualTo(left, right)
+ case HoodieSqlBaseParser.NSEQ =>
+ EqualNullSafe(left, right)
+ case HoodieSqlBaseParser.NEQ | HoodieSqlBaseParser.NEQJ =>
+ Not(EqualTo(left, right))
+ case HoodieSqlBaseParser.LT =>
+ LessThan(left, right)
+ case HoodieSqlBaseParser.LTE =>
+ LessThanOrEqual(left, right)
+ case HoodieSqlBaseParser.GT =>
+ GreaterThan(left, right)
+ case HoodieSqlBaseParser.GTE =>
+ GreaterThanOrEqual(left, right)
+ }
+ }
+
+ /**
+ * Create a predicated expression. A predicated expression is a normal expression with a
+ * predicate attached to it, for example:
+ * {{{
+ * a + 1 IS NULL
+ * }}}
+ */
+ override def visitPredicated(ctx: PredicatedContext): Expression = withOrigin(ctx) {
+ val e = expression(ctx.valueExpression)
+ if (ctx.predicate != null) {
+ withPredicate(e, ctx.predicate)
+ } else {
+ e
+ }
+ }
+
+ /**
+ * Add a predicate to the given expression. Supported expressions are:
+ * - (NOT) BETWEEN
+ * - (NOT) IN
+ * - (NOT) LIKE (ANY | SOME | ALL)
+ * - (NOT) RLIKE
+ * - IS (NOT) NULL.
+ * - IS (NOT) (TRUE | FALSE | UNKNOWN)
+ * - IS (NOT) DISTINCT FROM
+ */
+ private def withPredicate(e: Expression, ctx: PredicateContext): Expression = withOrigin(ctx) {
+ // Invert a predicate if it has a valid NOT clause.
+ def invertIfNotDefined(e: Expression): Expression = ctx.NOT match {
+ case null => e
+ case not => Not(e)
+ }
+
+ def getValueExpressions(e: Expression): Seq[Expression] = e match {
+ case c: CreateNamedStruct => c.valExprs
+ case other => Seq(other)
+ }
+
+ // Create the predicate.
+ ctx.kind.getType match {
+ case HoodieSqlBaseParser.BETWEEN =>
+ // BETWEEN is translated to lower <= e && e <= upper
+ invertIfNotDefined(And(
+ GreaterThanOrEqual(e, expression(ctx.lower)),
+ LessThanOrEqual(e, expression(ctx.upper))))
+ case HoodieSqlBaseParser.IN if ctx.query != null =>
+ invertIfNotDefined(InSubquery(getValueExpressions(e), ListQuery(plan(ctx.query))))
+ case HoodieSqlBaseParser.IN =>
+ invertIfNotDefined(In(e, ctx.expression.asScala.map(expression).toSeq))
+ case HoodieSqlBaseParser.LIKE =>
+ Option(ctx.quantifier).map(_.getType) match {
+ case Some(HoodieSqlBaseParser.ANY) | Some(HoodieSqlBaseParser.SOME) =>
+ validate(!ctx.expression.isEmpty, "Expected something between '(' and ')'.", ctx)
+ val expressions = expressionList(ctx.expression)
+ if (expressions.forall(_.foldable) && expressions.forall(_.dataType == StringType)) {
+ // With many pattern expressions, building a chain of Or would throw a StackOverflowError,
+ // so we use LikeAny or NotLikeAny instead.
+ val patterns = expressions.map(_.eval(EmptyRow).asInstanceOf[UTF8String])
+ ctx.NOT match {
+ case null => LikeAny(e, patterns)
+ case _ => NotLikeAny(e, patterns)
+ }
+ } else {
+ ctx.expression.asScala.map(expression)
+ .map(p => invertIfNotDefined(new Like(e, p))).toSeq.reduceLeft(Or)
+ }
+ case Some(HoodieSqlBaseParser.ALL) =>
+ validate(!ctx.expression.isEmpty, "Expected something between '(' and ')'.", ctx)
+ val expressions = expressionList(ctx.expression)
+ if (expressions.forall(_.foldable) && expressions.forall(_.dataType == StringType)) {
+ // With many pattern expressions, building a chain of And would throw a StackOverflowError,
+ // so we use LikeAll or NotLikeAll instead.
+ val patterns = expressions.map(_.eval(EmptyRow).asInstanceOf[UTF8String])
+ ctx.NOT match {
+ case null => LikeAll(e, patterns)
+ case _ => NotLikeAll(e, patterns)
+ }
+ } else {
+ ctx.expression.asScala.map(expression)
+ .map(p => invertIfNotDefined(new Like(e, p))).toSeq.reduceLeft(And)
+ }
+ case _ =>
+ val escapeChar = Option(ctx.escapeChar).map(string).map { str =>
+ if (str.length != 1) {
+ throw invalidEscapeStringError(ctx)
+ }
+ str.charAt(0)
+ }.getOrElse('\\')
+ invertIfNotDefined(Like(e, expression(ctx.pattern), escapeChar))
+ }
+ case HoodieSqlBaseParser.RLIKE =>
+ invertIfNotDefined(RLike(e, expression(ctx.pattern)))
+ case HoodieSqlBaseParser.NULL if ctx.NOT != null =>
+ IsNotNull(e)
+ case HoodieSqlBaseParser.NULL =>
+ IsNull(e)
+ case HoodieSqlBaseParser.TRUE => ctx.NOT match {
+ case null => EqualNullSafe(e, Literal(true))
+ case _ => Not(EqualNullSafe(e, Literal(true)))
+ }
+ case HoodieSqlBaseParser.FALSE => ctx.NOT match {
+ case null => EqualNullSafe(e, Literal(false))
+ case _ => Not(EqualNullSafe(e, Literal(false)))
+ }
+ case HoodieSqlBaseParser.UNKNOWN => ctx.NOT match {
+ case null => IsUnknown(e)
+ case _ => IsNotUnknown(e)
+ }
+ case HoodieSqlBaseParser.DISTINCT if ctx.NOT != null =>
+ EqualNullSafe(e, expression(ctx.right))
+ case HoodieSqlBaseParser.DISTINCT =>
+ Not(EqualNullSafe(e, expression(ctx.right)))
+ }
+ }
+
+ /**
+ * Create a binary arithmetic expression. The following arithmetic operators are supported:
+ * - Multiplication: '*'
+ * - Division: '/'
+ * - Hive Long Division: 'DIV'
+ * - Modulo: '%'
+ * - Addition: '+'
+ * - Subtraction: '-'
+ * - Binary AND: '&'
+ * - Binary XOR: '^'
+ * - Binary OR: '|'
+ */
+ override def visitArithmeticBinary(ctx: ArithmeticBinaryContext): Expression = withOrigin(ctx) {
+ val left = expression(ctx.left)
+ val right = expression(ctx.right)
+ ctx.operator.getType match {
+ case HoodieSqlBaseParser.ASTERISK =>
+ Multiply(left, right)
+ case HoodieSqlBaseParser.SLASH =>
+ Divide(left, right)
+ case HoodieSqlBaseParser.PERCENT =>
+ Remainder(left, right)
+ case HoodieSqlBaseParser.DIV =>
+ IntegralDivide(left, right)
+ case HoodieSqlBaseParser.PLUS =>
+ Add(left, right)
+ case HoodieSqlBaseParser.MINUS =>
+ Subtract(left, right)
+ case HoodieSqlBaseParser.CONCAT_PIPE =>
+ Concat(left :: right :: Nil)
+ case HoodieSqlBaseParser.AMPERSAND =>
+ BitwiseAnd(left, right)
+ case HoodieSqlBaseParser.HAT =>
+ BitwiseXor(left, right)
+ case HoodieSqlBaseParser.PIPE =>
+ BitwiseOr(left, right)
+ }
+ }
+
+ /**
+ * Create a unary arithmetic expression. The following arithmetic operators are supported:
+ * - Plus: '+'
+ * - Minus: '-'
+ * - Bitwise Not: '~'
+ */
+ override def visitArithmeticUnary(ctx: ArithmeticUnaryContext): Expression = withOrigin(ctx) {
+ val value = expression(ctx.valueExpression)
+ ctx.operator.getType match {
+ case HoodieSqlBaseParser.PLUS =>
+ UnaryPositive(value)
+ case HoodieSqlBaseParser.MINUS =>
+ UnaryMinus(value)
+ case HoodieSqlBaseParser.TILDE =>
+ BitwiseNot(value)
+ }
+ }
+
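+ /**
+ * Resolve CURRENT_DATE / CURRENT_TIMESTAMP / CURRENT_USER. Under ANSI mode they become the
+ * corresponding expressions; otherwise they are left as unresolved attributes so that columns
+ * with these names can still be referenced.
+ */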
+ override def visitCurrentLike(ctx: CurrentLikeContext): Expression = withOrigin(ctx) {
+ if (conf.ansiEnabled) {
+ ctx.name.getType match {
+ case HoodieSqlBaseParser.CURRENT_DATE =>
+ CurrentDate()
+ case HoodieSqlBaseParser.CURRENT_TIMESTAMP =>
+ CurrentTimestamp()
+ case HoodieSqlBaseParser.CURRENT_USER =>
+ CurrentUser()
+ }
+ } else {
+ // If the parser is not in ansi mode, we should return `UnresolvedAttribute`, in case there
+ // are columns named `CURRENT_DATE` or `CURRENT_TIMESTAMP`.
+ UnresolvedAttribute.quoted(ctx.name.getText)
+ }
+ }
+
+ /**
+ * Create a [[Cast]] expression.
+ */
+ override def visitCast(ctx: CastContext): Expression = withOrigin(ctx) {
+ val rawDataType = typedVisit[DataType](ctx.dataType())
+ val dataType = CharVarcharUtils.replaceCharVarcharWithStringForCast(rawDataType)
+ val cast = ctx.name.getType match {
+ case HoodieSqlBaseParser.CAST =>
+ Cast(expression(ctx.expression), dataType)
+
+ case HoodieSqlBaseParser.TRY_CAST =>
+ TryCast(expression(ctx.expression), dataType)
+ }
+ cast.setTagValue(Cast.USER_SPECIFIED_CAST, true)
+ cast
+ }
+
+ /**
+ * Create a [[CreateStruct]] expression.
+ */
+ override def visitStruct(ctx: StructContext): Expression = withOrigin(ctx) {
+ CreateStruct.create(ctx.argument.asScala.map(expression).toSeq)
+ }
+
+ /**
+ * Create a [[First]] expression.
+ */
+ override def visitFirst(ctx: FirstContext): Expression = withOrigin(ctx) {
+ val ignoreNullsExpr = ctx.IGNORE != null
+ First(expression(ctx.expression), ignoreNullsExpr).toAggregateExpression()
+ }
+
+ /**
+ * Create a [[Last]] expression.
+ */
+ override def visitLast(ctx: LastContext): Expression = withOrigin(ctx) {
+ val ignoreNullsExpr = ctx.IGNORE != null
+ Last(expression(ctx.expression), ignoreNullsExpr).toAggregateExpression()
+ }
+
+ /**
+ * Create a Position expression.
+ */
+ override def visitPosition(ctx: PositionContext): Expression = withOrigin(ctx) {
+ new StringLocate(expression(ctx.substr), expression(ctx.str))
+ }
+
+ /**
+ * Create an Extract expression.
+ */
+ override def visitExtract(ctx: ExtractContext): Expression = withOrigin(ctx) {
+ val arguments = Seq(Literal(ctx.field.getText), expression(ctx.source))
+ UnresolvedFunction("extract", arguments, isDistinct = false)
+ }
+
+ /**
+ * Create a Substring/Substr expression.
+ */
+ override def visitSubstring(ctx: SubstringContext): Expression = withOrigin(ctx) {
+ if (ctx.len != null) {
+ Substring(expression(ctx.str), expression(ctx.pos), expression(ctx.len))
+ } else {
+ new Substring(expression(ctx.str), expression(ctx.pos))
+ }
+ }
+
+ /**
+ * Create a Trim expression.
+ */
+ override def visitTrim(ctx: TrimContext): Expression = withOrigin(ctx) {
+ val srcStr = expression(ctx.srcStr)
+ val trimStr = Option(ctx.trimStr).map(expression)
+ Option(ctx.trimOption).map(_.getType).getOrElse(HoodieSqlBaseParser.BOTH) match {
+ case HoodieSqlBaseParser.BOTH =>
+ StringTrim(srcStr, trimStr)
+ case HoodieSqlBaseParser.LEADING =>
+ StringTrimLeft(srcStr, trimStr)
+ case HoodieSqlBaseParser.TRAILING =>
+ StringTrimRight(srcStr, trimStr)
+ case other =>
+ throw trimOptionUnsupportedError(other, ctx)
+ }
+ }
+
+ /**
+ * Create an Overlay expression.
+ */
+ override def visitOverlay(ctx: OverlayContext): Expression = withOrigin(ctx) {
+ val input = expression(ctx.input)
+ val replace = expression(ctx.replace)
+ val position = expression(ctx.position)
+ val lengthOpt = Option(ctx.length).map(expression)
+ lengthOpt match {
+ case Some(length) => Overlay(input, replace, position, length)
+ case None => new Overlay(input, replace, position)
+ }
+ }
+
+ /**
+ * Create a (windowed) Function expression.
+ */
+ override def visitFunctionCall(ctx: FunctionCallContext): Expression = withOrigin(ctx) {
+ // Create the function call.
+ val name = ctx.functionName.getText
+ val isDistinct = Option(ctx.setQuantifier()).exists(_.DISTINCT != null)
+ // Call `toSeq`, otherwise `ctx.argument.asScala.map(expression)` is `Buffer` in Scala 2.13
+ val arguments = ctx.argument.asScala.map(expression).toSeq match {
+ case Seq(UnresolvedStar(None))
+ if name.toLowerCase(Locale.ROOT) == "count" && !isDistinct =>
+ // Transform COUNT(*) into COUNT(1).
+ Seq(Literal(1))
+ case expressions =>
+ expressions
+ }
+ val filter = Option(ctx.where).map(expression(_))
+ val ignoreNulls =
+ Option(ctx.nullsOption).map(_.getType == HoodieSqlBaseParser.IGNORE).getOrElse(false)
+ val function = UnresolvedFunction(
+ getFunctionMultiparts(ctx.functionName), arguments, isDistinct, filter, ignoreNulls)
+
+ // Check if the function is evaluated in a windowed context.
+ ctx.windowSpec match {
+ case spec: WindowRefContext =>
+ UnresolvedWindowExpression(function, visitWindowRef(spec))
+ case spec: WindowDefContext =>
+ WindowExpression(function, visitWindowDef(spec))
+ case _ => function
+ }
+ }
+
+ /**
+ * Create a (database, function name) pair; the database part is optional.
+ */
+ protected def visitFunctionName(ctx: QualifiedNameContext): FunctionIdentifier = {
+ visitFunctionName(ctx, ctx.identifier().asScala.map(_.getText).toSeq)
+ }
+
+ /**
+ * Create a (database, function name) pair; the database part is optional.
+ */
+ private def visitFunctionName(ctx: ParserRuleContext, texts: Seq[String]): FunctionIdentifier = {
+ texts match {
+ case Seq(db, fn) => FunctionIdentifier(fn, Option(db))
+ case Seq(fn) => FunctionIdentifier(fn, None)
+ case other =>
+ throw functionNameUnsupportedError(texts.mkString("."), ctx)
+ }
+ }
+
+ /**
+ * Get a function identifier consisting of an optional database and a name.
+ */
+ protected def getFunctionIdentifier(ctx: FunctionNameContext): FunctionIdentifier = {
+ if (ctx.qualifiedName != null) {
+ visitFunctionName(ctx.qualifiedName)
+ } else {
+ FunctionIdentifier(ctx.getText, None)
+ }
+ }
+
+ protected def getFunctionMultiparts(ctx: FunctionNameContext): Seq[String] = {
+ if (ctx.qualifiedName != null) {
+ ctx.qualifiedName().identifier().asScala.map(_.getText).toSeq
+ } else {
+ Seq(ctx.getText)
+ }
+ }
+
+ /**
+ * Create a [[LambdaFunction]].
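+ * e.g. `x -> x + 1` or `(acc, x) -> acc + x`.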
+ */
+ override def visitLambda(ctx: LambdaContext): Expression = withOrigin(ctx) {
+ val arguments = ctx.identifier().asScala.map { name =>
+ UnresolvedNamedLambdaVariable(UnresolvedAttribute.quoted(name.getText).nameParts)
+ }
+ val function = expression(ctx.expression).transformUp {
+ case a: UnresolvedAttribute => UnresolvedNamedLambdaVariable(a.nameParts)
+ }
+ LambdaFunction(function, arguments.toSeq)
+ }
+
+ /**
+ * Create a reference to a window frame, i.e. [[WindowSpecReference]].
+ */
+ override def visitWindowRef(ctx: WindowRefContext): WindowSpecReference = withOrigin(ctx) {
+ WindowSpecReference(ctx.name.getText)
+ }
+
+ /**
+ * Create a window definition, i.e. [[WindowSpecDefinition]].
+ */
+ override def visitWindowDef(ctx: WindowDefContext): WindowSpecDefinition = withOrigin(ctx) {
+ // CLUSTER BY ... | PARTITION BY ... ORDER BY ...
+ val partition = ctx.partition.asScala.map(expression)
+ val order = ctx.sortItem.asScala.map(visitSortItem)
+
+ // RANGE/ROWS BETWEEN ...
+ val frameSpecOption = Option(ctx.windowFrame).map { frame =>
+ val frameType = frame.frameType.getType match {
+ case HoodieSqlBaseParser.RANGE => RangeFrame
+ case HoodieSqlBaseParser.ROWS => RowFrame
+ }
+
+ SpecifiedWindowFrame(
+ frameType,
+ visitFrameBound(frame.start),
+ Option(frame.end).map(visitFrameBound).getOrElse(CurrentRow))
+ }
+
+ WindowSpecDefinition(
+ partition.toSeq,
+ order.toSeq,
+ frameSpecOption.getOrElse(UnspecifiedFrame))
+ }
+
+ /**
+ * Create or resolve a frame boundary expression.
+ */
+ override def visitFrameBound(ctx: FrameBoundContext): Expression = withOrigin(ctx) {
+ def value: Expression = {
+ val e = expression(ctx.expression)
+ validate(e.resolved && e.foldable, "Frame bound value must be a literal.", ctx)
+ e
+ }
+
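+ // A literal bound of 'n PRECEDING' is represented as a negative offset (-n), 'n FOLLOWING' as n.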
+ ctx.boundType.getType match {
+ case HoodieSqlBaseParser.PRECEDING if ctx.UNBOUNDED != null =>
+ UnboundedPreceding
+ case HoodieSqlBaseParser.PRECEDING =>
+ UnaryMinus(value)
+ case HoodieSqlBaseParser.CURRENT =>
+ CurrentRow
+ case HoodieSqlBaseParser.FOLLOWING if ctx.UNBOUNDED != null =>
+ UnboundedFollowing
+ case HoodieSqlBaseParser.FOLLOWING =>
+ value
+ }
+ }
+
+ /**
+ * Create a [[CreateStruct]] expression.
+ */
+ override def visitRowConstructor(ctx: RowConstructorContext): Expression = withOrigin(ctx) {
+ CreateStruct(ctx.namedExpression().asScala.map(expression).toSeq)
+ }
+
+ /**
+ * Create a [[ScalarSubquery]] expression.
+ */
+ override def visitSubqueryExpression(
+ ctx: SubqueryExpressionContext): Expression = withOrigin(ctx) {
+ ScalarSubquery(plan(ctx.query))
+ }
+
+ /**
+ * Create a value based [[CaseWhen]] expression. This has the following SQL form:
+ * {{{
+ * CASE [expression]
+ * WHEN [value] THEN [expression]
+ * ...
+ * ELSE [expression]
+ * END
+ * }}}
+ */
+ override def visitSimpleCase(ctx: SimpleCaseContext): Expression = withOrigin(ctx) {
+ val e = expression(ctx.value)
+ val branches = ctx.whenClause.asScala.map { wCtx =>
+ (EqualTo(e, expression(wCtx.condition)), expression(wCtx.result))
+ }
+ CaseWhen(branches.toSeq, Option(ctx.elseExpression).map(expression))
+ }
+
+ /**
+ * Create a condition based [[CaseWhen]] expression. This has the following SQL syntax:
+ * {{{
+ * CASE
+ * WHEN [predicate] THEN [expression]
+ * ...
+ * ELSE [expression]
+ * END
+ * }}}
+ *
+ * @param ctx the parse tree
+ */
+ override def visitSearchedCase(ctx: SearchedCaseContext): Expression = withOrigin(ctx) {
+ val branches = ctx.whenClause.asScala.map { wCtx =>
+ (expression(wCtx.condition), expression(wCtx.result))
+ }
+ CaseWhen(branches.toSeq, Option(ctx.elseExpression).map(expression))
+ }
+
+ /**
+ * Currently, regexes are only supported in expressions of SELECT statements; in other
+ * places, e.g. WHERE `(a)?+.+` = 2, regexes are not meaningful.
+ */
+ private def canApplyRegex(ctx: ParserRuleContext): Boolean = withOrigin(ctx) {
+ var parent = ctx.getParent
+ var rtn = false
+ while (parent != null) {
+ if (parent.isInstanceOf[NamedExpressionContext]) {
+ rtn = true
+ }
+ parent = parent.getParent
+ }
+ rtn
+ }
+
+ /**
+ * Create a dereference expression. The return type depends on the type of the parent.
+ * If the parent is an [[UnresolvedAttribute]], it can be an [[UnresolvedAttribute]] or
+ * an [[UnresolvedRegex]] for a regex quoted in backticks; if the parent is some other
+ * expression, it can be an [[UnresolvedExtractValue]].
+ */
+ override def visitDereference(ctx: DereferenceContext): Expression = withOrigin(ctx) {
+ val attr = ctx.fieldName.getText
+ expression(ctx.base) match {
+ case unresolved_attr @ UnresolvedAttribute(nameParts) =>
+ ctx.fieldName.getStart.getText match {
+ case escapedIdentifier(columnNameRegex)
+ if conf.supportQuotedRegexColumnName && canApplyRegex(ctx) =>
+ UnresolvedRegex(columnNameRegex, Some(unresolved_attr.name),
+ conf.caseSensitiveAnalysis)
+ case _ =>
+ UnresolvedAttribute(nameParts :+ attr)
+ }
+ case e =>
+ UnresolvedExtractValue(e, Literal(attr))
+ }
+ }
+
+ /**
+ * Create an [[UnresolvedAttribute]] expression, or an [[UnresolvedRegex]] if it is a regex
+ * quoted in backticks.
+ */
+ override def visitColumnReference(ctx: ColumnReferenceContext): Expression = withOrigin(ctx) {
+ ctx.getStart.getText match {
+ case escapedIdentifier(columnNameRegex)
+ if conf.supportQuotedRegexColumnName && canApplyRegex(ctx) =>
+ UnresolvedRegex(columnNameRegex, None, conf.caseSensitiveAnalysis)
+ case _ =>
+ UnresolvedAttribute.quoted(ctx.getText)
+ }
+ }
+
+ /**
+ * Create an [[UnresolvedExtractValue]] expression; this is used for subscript access to an array.
+ */
+ override def visitSubscript(ctx: SubscriptContext): Expression = withOrigin(ctx) {
+ UnresolvedExtractValue(expression(ctx.value), expression(ctx.index))
+ }
+
+ /**
+ * Create an expression for an expression between parentheses. This is needed because the ANTLR
+ * visitor cannot automatically convert the nested context into an expression.
+ */
+ override def visitParenthesizedExpression(
+ ctx: ParenthesizedExpressionContext): Expression = withOrigin(ctx) {
+ expression(ctx.expression)
+ }
+
+ /**
+ * Create a [[SortOrder]] expression.
+ */
+ override def visitSortItem(ctx: SortItemContext): SortOrder = withOrigin(ctx) {
+ val direction = if (ctx.DESC != null) {
+ Descending
+ } else {
+ Ascending
+ }
+ val nullOrdering = if (ctx.FIRST != null) {
+ NullsFirst
+ } else if (ctx.LAST != null) {
+ NullsLast
+ } else {
+ direction.defaultNullOrdering
+ }
+ SortOrder(expression(ctx.expression), direction, nullOrdering, Seq.empty)
+ }
+
+ /**
+ * Create a typed Literal expression. A typed literal has the following SQL syntax:
+ * {{{
+ * [TYPE] '[VALUE]'
+ * }}}
+ * Currently Date, Timestamp, Interval and Binary typed literals are supported.
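+ * For example: DATE '2020-01-01', TIMESTAMP '2020-01-01 00:00:00', INTERVAL '1 day', X'1A'.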
+ */
+ override def visitTypeConstructor(ctx: TypeConstructorContext): Literal = withOrigin(ctx) {
+ val value = string(ctx.STRING)
+ val valueType = ctx.identifier.getText.toUpperCase(Locale.ROOT)
+
+ def toLiteral[T](f: UTF8String => Option[T], t: DataType): Literal = {
+ f(UTF8String.fromString(value)).map(Literal(_, t)).getOrElse {
+ throw cannotParseValueTypeError(valueType, value, ctx)
+ }
+ }
+
+ def constructTimestampLTZLiteral(value: String): Literal = {
+ val zoneId = getZoneId(conf.sessionLocalTimeZone)
+ val specialTs = convertSpecialTimestamp(value, zoneId).map(Literal(_, TimestampType))
+ specialTs.getOrElse(toLiteral(stringToTimestamp(_, zoneId), TimestampType))
+ }
+
+ try {
+ valueType match {
+ case "DATE" =>
+ val zoneId = getZoneId(conf.sessionLocalTimeZone)
+ val specialDate = convertSpecialDate(value, zoneId).map(Literal(_, DateType))
+ specialDate.getOrElse(toLiteral(stringToDate, DateType))
+ // SPARK-36227: Remove TimestampNTZ type support in Spark 3.2 with minimal code changes.
+ case "TIMESTAMP_NTZ" if isTesting =>
+ convertSpecialTimestampNTZ(value, getZoneId(conf.sessionLocalTimeZone))
+ .map(Literal(_, TimestampNTZType))
+ .getOrElse(toLiteral(stringToTimestampWithoutTimeZone, TimestampNTZType))
+ case "TIMESTAMP_LTZ" if isTesting =>
+ constructTimestampLTZLiteral(value)
+ case "TIMESTAMP" =>
+ SQLConf.get.timestampType match {
+ case TimestampNTZType =>
+ convertSpecialTimestampNTZ(value, getZoneId(conf.sessionLocalTimeZone))
+ .map(Literal(_, TimestampNTZType))
+ .getOrElse {
+ val containsTimeZonePart =
+ DateTimeUtils.parseTimestampString(UTF8String.fromString(value))._2.isDefined
+ // If the input string contains time zone part, return a timestamp with local time
+ // zone literal.
+ if (containsTimeZonePart) {
+ constructTimestampLTZLiteral(value)
+ } else {
+ toLiteral(stringToTimestampWithoutTimeZone, TimestampNTZType)
+ }
+ }
+
+ case TimestampType =>
+ constructTimestampLTZLiteral(value)
+ }
+
+ case "INTERVAL" =>
+ val interval = try {
+ IntervalUtils.stringToInterval(UTF8String.fromString(value))
+ } catch {
+ case e: IllegalArgumentException =>
+ val ex = cannotParseIntervalValueError(value, ctx)
+ ex.setStackTrace(e.getStackTrace)
+ throw ex
+ }
+ if (!conf.legacyIntervalEnabled) {
+ val units = value
+ .split("\\s")
+ .map(_.toLowerCase(Locale.ROOT).stripSuffix("s"))
+ .filter(s => s != "interval" && s.matches("[a-z]+"))
+ constructMultiUnitsIntervalLiteral(ctx, interval, units)
+ } else {
+ Literal(interval, CalendarIntervalType)
+ }
+ case "X" =>
+ val padding = if (value.length % 2 != 0) "0" else ""
+ Literal(DatatypeConverter.parseHexBinary(padding + value))
+ case other =>
+ throw literalValueTypeUnsupportedError(other, ctx)
+ }
+ } catch {
+ case e: IllegalArgumentException =>
+ throw parsingValueTypeError(e, valueType, ctx)
+ }
+ }
+
+ /**
+ * Create a NULL literal expression.
+ */
+ override def visitNullLiteral(ctx: NullLiteralContext): Literal = withOrigin(ctx) {
+ Literal(null)
+ }
+
+ /**
+ * Create a Boolean literal expression.
+ */
+ override def visitBooleanLiteral(ctx: BooleanLiteralContext): Literal = withOrigin(ctx) {
+ if (ctx.getText.toBoolean) {
+ Literal.TrueLiteral
+ } else {
+ Literal.FalseLiteral
+ }
+ }
+
+ /**
+ * Create an integral literal expression. The code selects the narrowest integral type
+ * possible; either a BigDecimal, a Long or an Integer is returned.
+ */
+ override def visitIntegerLiteral(ctx: IntegerLiteralContext): Literal = withOrigin(ctx) {
+ BigDecimal(ctx.getText) match {
+ case v if v.isValidInt =>
+ Literal(v.intValue)
+ case v if v.isValidLong =>
+ Literal(v.longValue)
+ case v => Literal(v.underlying())
+ }
+ }
+
+ /**
+ * Create a decimal literal for a regular decimal number.
+ */
+ override def visitDecimalLiteral(ctx: DecimalLiteralContext): Literal = withOrigin(ctx) {
+ Literal(BigDecimal(ctx.getText).underlying())
+ }
+
+ /**
+ * Create a decimal literal for a regular decimal number or a scientific decimal number.
+ */
+ override def visitLegacyDecimalLiteral(
+ ctx: LegacyDecimalLiteralContext): Literal = withOrigin(ctx) {
+ Literal(BigDecimal(ctx.getText).underlying())
+ }
+
+ /**
+ * Create a double literal for a number with an exponent, e.g. 1E-30.
+ */
+ override def visitExponentLiteral(ctx: ExponentLiteralContext): Literal = {
+ numericLiteral(ctx, ctx.getText, /* exponent values don't have a suffix */
+ Double.MinValue, Double.MaxValue, DoubleType.simpleString)(_.toDouble)
+ }
+
+ /** Create a numeric literal expression. */
+ private def numericLiteral(
+ ctx: NumberContext,
+ rawStrippedQualifier: String,
+ minValue: BigDecimal,
+ maxValue: BigDecimal,
+ typeName: String)(converter: String => Any): Literal = withOrigin(ctx) {
+ try {
+ val rawBigDecimal = BigDecimal(rawStrippedQualifier)
+ if (rawBigDecimal < minValue || rawBigDecimal > maxValue) {
+ throw invalidNumericLiteralRangeError(
+ rawStrippedQualifier, minValue, maxValue, typeName, ctx)
+ }
+ Literal(converter(rawStrippedQualifier))
+ } catch {
+ case e: NumberFormatException =>
+ throw new ParseException(e.getMessage, ctx)
+ }
+ }
+
+ /**
+ * Create a Byte Literal expression.
+ */
+ override def visitTinyIntLiteral(ctx: TinyIntLiteralContext): Literal = {
+ val rawStrippedQualifier = ctx.getText.substring(0, ctx.getText.length - 1)
+ numericLiteral(ctx, rawStrippedQualifier,
+ Byte.MinValue, Byte.MaxValue, ByteType.simpleString)(_.toByte)
+ }
+
+ /**
+ * Create a Short Literal expression.
+ */
+ override def visitSmallIntLiteral(ctx: SmallIntLiteralContext): Literal = {
+ val rawStrippedQualifier = ctx.getText.substring(0, ctx.getText.length - 1)
+ numericLiteral(ctx, rawStrippedQualifier,
+ Short.MinValue, Short.MaxValue, ShortType.simpleString)(_.toShort)
+ }
+
+ /**
+ * Create a Long Literal expression.
+ */
+ override def visitBigIntLiteral(ctx: BigIntLiteralContext): Literal = {
+ val rawStrippedQualifier = ctx.getText.substring(0, ctx.getText.length - 1)
+ numericLiteral(ctx, rawStrippedQualifier,
+ Long.MinValue, Long.MaxValue, LongType.simpleString)(_.toLong)
+ }
+
+ /**
+ * Create a Float Literal expression.
+ */
+ override def visitFloatLiteral(ctx: FloatLiteralContext): Literal = {
+ val rawStrippedQualifier = ctx.getText.substring(0, ctx.getText.length - 1)
+ numericLiteral(ctx, rawStrippedQualifier,
+ Float.MinValue, Float.MaxValue, FloatType.simpleString)(_.toFloat)
+ }
+
+ /**
+ * Create a Double Literal expression.
+ */
+ override def visitDoubleLiteral(ctx: DoubleLiteralContext): Literal = {
+ val rawStrippedQualifier = ctx.getText.substring(0, ctx.getText.length - 1)
+ numericLiteral(ctx, rawStrippedQualifier,
+ Double.MinValue, Double.MaxValue, DoubleType.simpleString)(_.toDouble)
+ }
+
+ /**
+ * Create a BigDecimal Literal expression.
+ */
+ override def visitBigDecimalLiteral(ctx: BigDecimalLiteralContext): Literal = {
+ val raw = ctx.getText.substring(0, ctx.getText.length - 2)
+ try {
+ Literal(BigDecimal(raw).underlying())
+ } catch {
+ case e: AnalysisException =>
+ throw new ParseException(e.message, ctx)
+ }
+ }
+
+ /**
+ * Create a String literal expression.
+ */
+ override def visitStringLiteral(ctx: StringLiteralContext): Literal = withOrigin(ctx) {
+ Literal(createString(ctx))
+ }
+
+ /**
+ * Create a String from a string literal context. This supports multiple consecutive string
+ * literals, which are concatenated; for example, the expression "'hello' 'world'" will be
+ * converted into "helloworld".
+ *
+ * Special characters can be escaped by using Hive/C-style escaping.
+ */
+ private def createString(ctx: StringLiteralContext): String = {
+ if (conf.escapedStringLiterals) {
+ ctx.STRING().asScala.map(stringWithoutUnescape).mkString
+ } else {
+ ctx.STRING().asScala.map(string).mkString
+ }
+ }
+
+ /**
+ * Create an [[UnresolvedRelation]] from a multi-part identifier context.
+ */
+ private def createUnresolvedRelation(
+ ctx: MultipartIdentifierContext): UnresolvedRelation = withOrigin(ctx) {
+ UnresolvedRelation(visitMultipartIdentifier(ctx))
+ }
+
+ /**
+ * Create an [[UnresolvedTable]] from a multi-part identifier context.
+ */
+ private def createUnresolvedTable(
+ ctx: MultipartIdentifierContext,
+ commandName: String,
+ relationTypeMismatchHint: Option[String] = None): UnresolvedTable = withOrigin(ctx) {
+ UnresolvedTable(visitMultipartIdentifier(ctx), commandName, relationTypeMismatchHint)
+ }
+
+ /**
+ * Create an [[UnresolvedView]] from a multi-part identifier context.
+ */
+ private def createUnresolvedView(
+ ctx: MultipartIdentifierContext,
+ commandName: String,
+ allowTemp: Boolean = true,
+ relationTypeMismatchHint: Option[String] = None): UnresolvedView = withOrigin(ctx) {
+ UnresolvedView(visitMultipartIdentifier(ctx), commandName, allowTemp, relationTypeMismatchHint)
+ }
+
+ /**
+ * Create an [[UnresolvedTableOrView]] from a multi-part identifier context.
+ */
+ private def createUnresolvedTableOrView(
+ ctx: MultipartIdentifierContext,
+ commandName: String,
+ allowTempView: Boolean = true): UnresolvedTableOrView = withOrigin(ctx) {
+ UnresolvedTableOrView(visitMultipartIdentifier(ctx), commandName, allowTempView)
+ }
+
+ /**
+ * Construct a [[Literal]] from a [[CalendarInterval]] and
+ * units represented as a [[Seq]] of [[String]].
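+ * For example, INTERVAL 1 YEAR 2 MONTHS yields a literal of YearMonthIntervalType(YEAR, MONTH).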
+ */
+ private def constructMultiUnitsIntervalLiteral(
+ ctx: ParserRuleContext,
+ calendarInterval: CalendarInterval,
+ units: Seq[String]): Literal = {
+ var yearMonthFields = Set.empty[Byte]
+ var dayTimeFields = Set.empty[Byte]
+ for (unit <- units) {
+ if (YearMonthIntervalType.stringToField.contains(unit)) {
+ yearMonthFields += YearMonthIntervalType.stringToField(unit)
+ } else if (DayTimeIntervalType.stringToField.contains(unit)) {
+ dayTimeFields += DayTimeIntervalType.stringToField(unit)
+ } else if (unit == "week") {
+ dayTimeFields += DayTimeIntervalType.DAY
+ } else {
+ assert(unit == "millisecond" || unit == "microsecond")
+ dayTimeFields += DayTimeIntervalType.SECOND
+ }
+ }
+ if (yearMonthFields.nonEmpty) {
+ if (dayTimeFields.nonEmpty) {
+ val literalStr = source(ctx)
+ throw mixedIntervalUnitsError(literalStr, ctx)
+ }
+ Literal(
+ calendarInterval.months,
+ YearMonthIntervalType(yearMonthFields.min, yearMonthFields.max)
+ )
+ } else {
+ Literal(
+ IntervalUtils.getDuration(calendarInterval, TimeUnit.MICROSECONDS),
+ DayTimeIntervalType(dayTimeFields.min, dayTimeFields.max))
+ }
+ }
+
+ /**
+ * Create a [[CalendarInterval]] or ANSI interval literal expression.
+ * Two syntaxes are supported:
+ * - multiple unit value pairs, for instance: interval 2 months 2 days.
+ * - from-to unit, for instance: interval '1-2' year to month.
+ */
+ override def visitInterval(ctx: IntervalContext): Literal = withOrigin(ctx) {
+ val calendarInterval = parseIntervalLiteral(ctx)
+ if (ctx.errorCapturingUnitToUnitInterval != null && !conf.legacyIntervalEnabled) {
+ // Check the `to` unit to distinguish year-month and day-time intervals because
+ // `CalendarInterval` doesn't have enough info. For instance, new CalendarInterval(0, 0, 0)
+ // can be derived from INTERVAL '0-0' YEAR TO MONTH as well as from
+ // INTERVAL '0 00:00:00' DAY TO SECOND.
+ val fromUnit =
+ ctx.errorCapturingUnitToUnitInterval.body.from.getText.toLowerCase(Locale.ROOT)
+ val toUnit = ctx.errorCapturingUnitToUnitInterval.body.to.getText.toLowerCase(Locale.ROOT)
+ if (toUnit == "month") {
+ assert(calendarInterval.days == 0 && calendarInterval.microseconds == 0)
+ val start = YearMonthIntervalType.stringToField(fromUnit)
+ Literal(calendarInterval.months, YearMonthIntervalType(start, YearMonthIntervalType.MONTH))
+ } else {
+ assert(calendarInterval.months == 0)
+ val micros = IntervalUtils.getDuration(calendarInterval, TimeUnit.MICROSECONDS)
+ val start = DayTimeIntervalType.stringToField(fromUnit)
+ val end = DayTimeIntervalType.stringToField(toUnit)
+ Literal(micros, DayTimeIntervalType(start, end))
+ }
+ } else if (ctx.errorCapturingMultiUnitsInterval != null && !conf.legacyIntervalEnabled) {
+ val units =
+ ctx.errorCapturingMultiUnitsInterval.body.unit.asScala.map(
+ _.getText.toLowerCase(Locale.ROOT).stripSuffix("s")).toSeq
+ constructMultiUnitsIntervalLiteral(ctx, calendarInterval, units)
+ } else {
+ Literal(calendarInterval, CalendarIntervalType)
+ }
+ }
+
+ /**
+ * Create a [[CalendarInterval]] object
+ */
+ protected def parseIntervalLiteral(ctx: IntervalContext): CalendarInterval = withOrigin(ctx) {
+ if (ctx.errorCapturingMultiUnitsInterval != null) {
+ val innerCtx = ctx.errorCapturingMultiUnitsInterval
+ if (innerCtx.unitToUnitInterval != null) {
+ throw moreThanOneFromToUnitInIntervalLiteralError(
+ innerCtx.unitToUnitInterval)
+ }
+ visitMultiUnitsInterval(innerCtx.multiUnitsInterval)
+ } else if (ctx.errorCapturingUnitToUnitInterval != null) {
+ val innerCtx = ctx.errorCapturingUnitToUnitInterval
+ if (innerCtx.error1 != null || innerCtx.error2 != null) {
+ val errorCtx = if (innerCtx.error1 != null) innerCtx.error1 else innerCtx.error2
+ throw moreThanOneFromToUnitInIntervalLiteralError(errorCtx)
+ }
+ visitUnitToUnitInterval(innerCtx.body)
+ } else {
+ throw invalidIntervalLiteralError(ctx)
+ }
+ }
+
+ /**
+ * Creates a [[CalendarInterval]] with multiple unit value pairs, e.g. 1 YEAR 2 DAYS.
+ */
+ override def visitMultiUnitsInterval(ctx: MultiUnitsIntervalContext): CalendarInterval = {
+ withOrigin(ctx) {
+ val units = ctx.unit.asScala
+ val values = ctx.intervalValue().asScala
+ try {
+ assert(units.length == values.length)
+ val kvs = units.indices.map { i =>
+ val u = units(i).getText
+ val v = if (values(i).STRING() != null) {
+ val value = string(values(i).STRING())
+ // SPARK-32840: For invalid cases, e.g. INTERVAL '1 day 2' hour,
+ // INTERVAL 'interval 1' day, we need to check ahead before they are concatenated with
+ // units and become valid ones, e.g. '1 day 2 hour'.
+ // Ideally, we only ensure the value parts don't contain any units here.
+ if (value.exists(Character.isLetter)) {
+ throw invalidIntervalFormError(value, ctx)
+ }
+ if (values(i).MINUS() == null) {
+ value
+ } else {
+ value.startsWith("-") match {
+ case true => value.replaceFirst("-", "")
+ case false => s"-$value"
+ }
+ }
+ } else {
+ values(i).getText
+ }
+ UTF8String.fromString(" " + v + " " + u)
+ }
+ IntervalUtils.stringToInterval(UTF8String.concat(kvs: _*))
+ } catch {
+ case i: IllegalArgumentException =>
+ val e = new ParseException(i.getMessage, ctx)
+ e.setStackTrace(i.getStackTrace)
+ throw e
+ }
+ }
+ }
+
+ /**
+ * Creates a [[CalendarInterval]] with from-to unit, e.g. '2-1' YEAR TO MONTH.
+ */
+ override def visitUnitToUnitInterval(ctx: UnitToUnitIntervalContext): CalendarInterval = {
+ withOrigin(ctx) {
+ val value = Option(ctx.intervalValue.STRING).map(string).map { interval =>
+ if (ctx.intervalValue().MINUS() == null) {
+ interval
+ } else {
+ interval.startsWith("-") match {
+ case true => interval.replaceFirst("-", "")
+ case false => s"-$interval"
+ }
+ }
+ }.getOrElse {
+ throw invalidFromToUnitValueError(ctx.intervalValue)
+ }
+ try {
+ val from = ctx.from.getText.toLowerCase(Locale.ROOT)
+ val to = ctx.to.getText.toLowerCase(Locale.ROOT)
+ (from, to) match {
+ case ("year", "month") =>
+ IntervalUtils.fromYearMonthString(value)
+ case ("day", "hour") | ("day", "minute") | ("day", "second") | ("hour", "minute") |
+ ("hour", "second") | ("minute", "second") =>
+ IntervalUtils.fromDayTimeString(value,
+ DayTimeIntervalType.stringToField(from), DayTimeIntervalType.stringToField(to))
+ case _ =>
+ throw fromToIntervalUnsupportedError(from, to, ctx)
+ }
+ } catch {
+ // Handle Exceptions thrown by CalendarInterval
+ case e: IllegalArgumentException =>
+ val pe = new ParseException(e.getMessage, ctx)
+ pe.setStackTrace(e.getStackTrace)
+ throw pe
+ }
+ }
+ }
+
+ /* ********************************************************************************************
+ * DataType parsing
+ * ******************************************************************************************** */
+
+ /**
+ * Resolve/create a primitive type.
+ */
+ override def visitPrimitiveDataType(ctx: PrimitiveDataTypeContext): DataType = withOrigin(ctx) {
+ val dataType = ctx.identifier.getText.toLowerCase(Locale.ROOT)
+ (dataType, ctx.INTEGER_VALUE().asScala.toList) match {
+ case ("boolean", Nil) => BooleanType
+ case ("tinyint" | "byte", Nil) => ByteType
+ case ("smallint" | "short", Nil) => ShortType
+ case ("int" | "integer", Nil) => IntegerType
+ case ("bigint" | "long", Nil) => LongType
+ case ("float" | "real", Nil) => FloatType
+ case ("double", Nil) => DoubleType
+ case ("date", Nil) => DateType
+ case ("timestamp", Nil) => SQLConf.get.timestampType
+ // SPARK-36227: Remove TimestampNTZ type support in Spark 3.2 with minimal code changes.
+ case ("timestamp_ntz", Nil) if isTesting => TimestampNTZType
+ case ("timestamp_ltz", Nil) if isTesting => TimestampType
+ case ("string", Nil) => StringType
+ case ("character" | "char", length :: Nil) => CharType(length.getText.toInt)
+ case ("varchar", length :: Nil) => VarcharType(length.getText.toInt)
+ case ("binary", Nil) => BinaryType
+ case ("decimal" | "dec" | "numeric", Nil) => DecimalType.USER_DEFAULT
+ case ("decimal" | "dec" | "numeric", precision :: Nil) =>
+ DecimalType(precision.getText.toInt, 0)
+ case ("decimal" | "dec" | "numeric", precision :: scale :: Nil) =>
+ DecimalType(precision.getText.toInt, scale.getText.toInt)
+ case ("void", Nil) => NullType
+ case ("interval", Nil) => CalendarIntervalType
+ case (dt, params) =>
+ val dtStr = if (params.nonEmpty) s"$dt(${params.mkString(",")})" else dt
+ throw dataTypeUnsupportedError(dtStr, ctx)
+ }
+ }
+
+ override def visitYearMonthIntervalDataType(ctx: YearMonthIntervalDataTypeContext): DataType = {
+ val startStr = ctx.from.getText.toLowerCase(Locale.ROOT)
+ val start = YearMonthIntervalType.stringToField(startStr)
+ if (ctx.to != null) {
+ val endStr = ctx.to.getText.toLowerCase(Locale.ROOT)
+ val end = YearMonthIntervalType.stringToField(endStr)
+ if (end <= start) {
+ throw fromToIntervalUnsupportedError(startStr, endStr, ctx)
+ }
+ YearMonthIntervalType(start, end)
+ } else {
+ YearMonthIntervalType(start)
+ }
+ }
+
+ override def visitDayTimeIntervalDataType(ctx: DayTimeIntervalDataTypeContext): DataType = {
+ val startStr = ctx.from.getText.toLowerCase(Locale.ROOT)
+ val start = DayTimeIntervalType.stringToField(startStr)
+ if (ctx.to != null) {
+ val endStr = ctx.to.getText.toLowerCase(Locale.ROOT)
+ val end = DayTimeIntervalType.stringToField(endStr)
+ if (end <= start) {
+ throw fromToIntervalUnsupportedError(startStr, endStr, ctx)
+ }
+ DayTimeIntervalType(start, end)
+ } else {
+ DayTimeIntervalType(start)
+ }
+ }
+
+ /**
+ * Create a complex DataType. Arrays, Maps and Structures are supported.
+ */
+ override def visitComplexDataType(ctx: ComplexDataTypeContext): DataType = withOrigin(ctx) {
+ ctx.complex.getType match {
+ case HoodieSqlBaseParser.ARRAY =>
+ ArrayType(typedVisit(ctx.dataType(0)))
+ case HoodieSqlBaseParser.MAP =>
+ MapType(typedVisit(ctx.dataType(0)), typedVisit(ctx.dataType(1)))
+ case HoodieSqlBaseParser.STRUCT =>
+ StructType(Option(ctx.complexColTypeList).toSeq.flatMap(visitComplexColTypeList))
+ }
+ }
+
+ /**
+ * Create top level table schema.
+ */
+ protected def createSchema(ctx: ColTypeListContext): StructType = {
+ StructType(Option(ctx).toSeq.flatMap(visitColTypeList))
+ }
+
+ /**
+ * Create a [[StructType]] from a number of column definitions.
+ */
+ override def visitColTypeList(ctx: ColTypeListContext): Seq[StructField] = withOrigin(ctx) {
+ ctx.colType().asScala.map(visitColType).toSeq
+ }
+
+ /**
+ * Create a top level [[StructField]] from a column definition.
+ */
+ override def visitColType(ctx: ColTypeContext): StructField = withOrigin(ctx) {
+ import ctx._
+
+ val builder = new MetadataBuilder
+ // Add comment to metadata
+ Option(commentSpec()).map(visitCommentSpec).foreach {
+ builder.putString("comment", _)
+ }
+
+ StructField(
+ name = colName.getText,
+ dataType = typedVisit[DataType](ctx.dataType),
+ nullable = NULL == null,
+ metadata = builder.build())
+ }
+
+ /**
+ * Create a [[StructType]] from a sequence of [[StructField]]s.
+ */
+ protected def createStructType(ctx: ComplexColTypeListContext): StructType = {
+ StructType(Option(ctx).toSeq.flatMap(visitComplexColTypeList))
+ }
+
+ /**
+ * Create a [[StructType]] from a number of column definitions.
+ */
+ override def visitComplexColTypeList(
+ ctx: ComplexColTypeListContext): Seq[StructField] = withOrigin(ctx) {
+ ctx.complexColType().asScala.map(visitComplexColType).toSeq
+ }
+
+ /**
+ * Create a [[StructField]] from a column definition.
+ */
+ override def visitComplexColType(ctx: ComplexColTypeContext): StructField = withOrigin(ctx) {
+ import ctx._
+ val structField = StructField(
+ name = identifier.getText,
+ dataType = typedVisit(dataType()),
+ nullable = NULL == null)
+ Option(commentSpec).map(visitCommentSpec).map(structField.withComment).getOrElse(structField)
+ }
+
+ /**
+ * Create a location string.
+ */
+ override def visitLocationSpec(ctx: LocationSpecContext): String = withOrigin(ctx) {
+ string(ctx.STRING)
+ }
+
+ /**
+ * Create an optional location string.
+ */
+ protected def visitLocationSpecList(ctx: java.util.List[LocationSpecContext]): Option[String] = {
+ ctx.asScala.headOption.map(visitLocationSpec)
+ }
+
+ /**
+ * Create a comment string.
+ */
+ override def visitCommentSpec(ctx: CommentSpecContext): String = withOrigin(ctx) {
+ string(ctx.STRING)
+ }
+
+ /**
+ * Create an optional comment string.
+ */
+ protected def visitCommentSpecList(ctx: java.util.List[CommentSpecContext]): Option[String] = {
+ ctx.asScala.headOption.map(visitCommentSpec)
+ }
+
+ /**
+ * Create a [[BucketSpec]].
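+ * e.g. CLUSTERED BY (id) SORTED BY (ts ASC) INTO 8 BUCKETS.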
+ */
+ override def visitBucketSpec(ctx: BucketSpecContext): BucketSpec = withOrigin(ctx) {
+ BucketSpec(
+ ctx.INTEGER_VALUE.getText.toInt,
+ visitIdentifierList(ctx.identifierList),
+ Option(ctx.orderedIdentifierList)
+ .toSeq
+ .flatMap(_.orderedIdentifier.asScala)
+ .map { orderedIdCtx =>
+ Option(orderedIdCtx.ordering).map(_.getText).foreach { dir =>
+ if (dir.toLowerCase(Locale.ROOT) != "asc") {
+ operationNotAllowed(s"Column ordering must be ASC, was '$dir'", ctx)
+ }
+ }
+
+ orderedIdCtx.ident.getText
+ })
+ }
+
+ /**
+ * Convert a table property list into a key-value map.
+ * This should be called through [[visitPropertyKeyValues]] or [[visitPropertyKeys]].
+ */
+ override def visitTablePropertyList(
+ ctx: TablePropertyListContext): Map[String, String] = withOrigin(ctx) {
+ val properties = ctx.tableProperty.asScala.map { property =>
+ val key = visitTablePropertyKey(property.key)
+ val value = visitTablePropertyValue(property.value)
+ key -> value
+ }
+ // Check for duplicate property names.
+ checkDuplicateKeys(properties.toSeq, ctx)
+ properties.toMap
+ }
+
+ /**
+ * Parse a key-value map from a [[TablePropertyListContext]], assuming all values are specified.
+ */
+ def visitPropertyKeyValues(ctx: TablePropertyListContext): Map[String, String] = {
+ val props = visitTablePropertyList(ctx)
+ val badKeys = props.collect { case (key, null) => key }
+ if (badKeys.nonEmpty) {
+ operationNotAllowed(
+ s"Values must be specified for key(s): ${badKeys.mkString("[", ",", "]")}", ctx)
+ }
+ props
+ }
+
+ /**
+ * Parse a list of keys from a [[TablePropertyListContext]], assuming no values are specified.
+ */
+ def visitPropertyKeys(ctx: TablePropertyListContext): Seq[String] = {
+ val props = visitTablePropertyList(ctx)
+ val badKeys = props.filter { case (_, v) => v != null }.keys
+ if (badKeys.nonEmpty) {
+ operationNotAllowed(
+ s"Values should not be specified for key(s): ${badKeys.mkString("[", ",", "]")}", ctx)
+ }
+ props.keys.toSeq
+ }
+
+ /**
+ * A table property key can either be a String or a collection of dot-separated elements. This
+ * function extracts the property key based on whether it is a string literal or a table property
+ * identifier.
+ */
+ override def visitTablePropertyKey(key: TablePropertyKeyContext): String = {
+ if (key.STRING != null) {
+ string(key.STRING)
+ } else {
+ key.getText
+ }
+ }
+
+ /**
+ * A table property value can be a String, Integer, Boolean or Decimal. This function extracts
+ * the property value based on whether it is a string, integer, boolean or decimal literal.
+ */
+ override def visitTablePropertyValue(value: TablePropertyValueContext): String = {
+ if (value == null) {
+ null
+ } else if (value.STRING != null) {
+ string(value.STRING)
+ } else if (value.booleanValue != null) {
+ value.getText.toLowerCase(Locale.ROOT)
+ } else {
+ value.getText
+ }
+ }
+
+ /**
+ * Type to keep track of a table header: (identifier, isTemporary, ifNotExists, isExternal).
+ */
+ type TableHeader = (Seq[String], Boolean, Boolean, Boolean)
+
+ /**
+ * Type to keep track of table clauses:
+ * - partition transforms
+ * - partition columns
+ * - bucketSpec
+ * - properties
+ * - options
+ * - location
+ * - comment
+ * - serde
+ *
+ * Note: Partition transforms are based on the existing table schema definition. They can be simple
+ * column names, or functions like `year(date_col)`. Partition columns are column names with data
+ * types like `i INT`, which should be appended to the existing table schema.
+ */
+ type TableClauses = (
+ Seq[Transform], Seq[StructField], Option[BucketSpec], Map[String, String],
+ Map[String, String], Option[String], Option[String], Option[SerdeInfo])
+
+ /**
+ * Validate a create table statement and return the [[TableIdentifier]].
+ */
+ override def visitCreateTableHeader(
+ ctx: CreateTableHeaderContext): TableHeader = withOrigin(ctx) {
+ val temporary = ctx.TEMPORARY != null
+ val ifNotExists = ctx.EXISTS != null
+ if (temporary && ifNotExists) {
+ operationNotAllowed("CREATE TEMPORARY TABLE ... IF NOT EXISTS", ctx)
+ }
+ val multipartIdentifier = ctx.multipartIdentifier.parts.asScala.map(_.getText).toSeq
+ (multipartIdentifier, temporary, ifNotExists, ctx.EXTERNAL != null)
+ }
+
+ /**
+ * Validate a replace table statement and return the [[TableIdentifier]].
+ */
+ override def visitReplaceTableHeader(
+ ctx: ReplaceTableHeaderContext): TableHeader = withOrigin(ctx) {
+ val multipartIdentifier = ctx.multipartIdentifier.parts.asScala.map(_.getText).toSeq
+ (multipartIdentifier, false, false, false)
+ }
+
+ /**
+ * Parse a qualified name to a multipart name.
+ */
+ override def visitQualifiedName(ctx: QualifiedNameContext): Seq[String] = withOrigin(ctx) {
+ ctx.identifier.asScala.map(_.getText).toSeq
+ }
+
+ /**
+ * Parse a list of transforms or columns.
+ */
+ override def visitPartitionFieldList(
+ ctx: PartitionFieldListContext): (Seq[Transform], Seq[StructField]) = withOrigin(ctx) {
+ val (transforms, columns) = ctx.fields.asScala.map {
+ case transform: PartitionTransformContext =>
+ (Some(visitPartitionTransform(transform)), None)
+ case field: PartitionColumnContext =>
+ (None, Some(visitColType(field.colType)))
+ }.unzip
+
+ (transforms.flatten.toSeq, columns.flatten.toSeq)
+ }
+
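+ /**
+ * Parse a single partition transform: either an identity reference to a column (a qualified
+ * name), or an applied transform such as bucket(4, col), years(ts), months(ts), days(ts) or
+ * hours(ts); any other function name becomes a generic [[ApplyTransform]].
+ */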
+ override def visitPartitionTransform(
+ ctx: PartitionTransformContext): Transform = withOrigin(ctx) {
+ def getFieldReference(
+ ctx: ApplyTransformContext,
+ arg: V2Expression): FieldReference = {
+ lazy val name: String = ctx.identifier.getText
+ arg match {
+ case ref: FieldReference =>
+ ref
+ case nonRef =>
+ throw partitionTransformNotExpectedError(name, nonRef.describe, ctx)
+ }
+ }
+
+ def getSingleFieldReference(
+ ctx: ApplyTransformContext,
+ arguments: Seq[V2Expression]): FieldReference = {
+ lazy val name: String = ctx.identifier.getText
+ if (arguments.size > 1) {
+ throw tooManyArgumentsForTransformError(name, ctx)
+ } else if (arguments.isEmpty) {
+ throw notEnoughArgumentsForTransformError(name, ctx)
+ } else {
+ getFieldReference(ctx, arguments.head)
+ }
+ }
+
+ ctx.transform match {
+ case identityCtx: IdentityTransformContext =>
+ IdentityTransform(FieldReference(typedVisit[Seq[String]](identityCtx.qualifiedName)))
+
+ case applyCtx: ApplyTransformContext =>
+ val arguments = applyCtx.argument.asScala.map(visitTransformArgument).toSeq
+
+ applyCtx.identifier.getText match {
+ case "bucket" =>
+ val numBuckets: Int = arguments.head match {
+ case LiteralValue(shortValue, ShortType) =>
+ shortValue.asInstanceOf[Short].toInt
+ case LiteralValue(intValue, IntegerType) =>
+ intValue.asInstanceOf[Int]
+ case LiteralValue(longValue, LongType) =>
+ longValue.asInstanceOf[Long].toInt
+ case lit =>
+ throw invalidBucketsNumberError(lit.describe, applyCtx)
+ }
+
+ val fields = arguments.tail.map(arg => getFieldReference(applyCtx, arg))
+
+ BucketTransform(LiteralValue(numBuckets, IntegerType), fields)
+
+ case "years" =>
+ YearsTransform(getSingleFieldReference(applyCtx, arguments))
+
+ case "months" =>
+ MonthsTransform(getSingleFieldReference(applyCtx, arguments))
+
+ case "days" =>
+ DaysTransform(getSingleFieldReference(applyCtx, arguments))
+
+ case "hours" =>
+ HoursTransform(getSingleFieldReference(applyCtx, arguments))
+
+ case name =>
+ ApplyTransform(name, arguments)
+ }
+ }
+ }
+
+ /**
+ * Parse an argument to a transform. An argument may be a field reference (qualified name) or
+ * a value literal.
+ */
+ override def visitTransformArgument(ctx: TransformArgumentContext): V2Expression = {
+ withOrigin(ctx) {
+ val reference = Option(ctx.qualifiedName)
+ .map(typedVisit[Seq[String]])
+ .map(FieldReference(_))
+ val literal = Option(ctx.constant)
+ .map(typedVisit[Literal])
+ .map(lit => LiteralValue(lit.value, lit.dataType))
+ reference.orElse(literal)
+ .getOrElse(throw invalidTransformArgumentError(ctx))
+ }
+ }
+
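+ /**
+ * Reject the reserved namespace properties (location, owner) unless the legacy
+ * non-reserved-property flag is enabled, in which case they are silently dropped.
+ */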
+ private def cleanNamespaceProperties(
+ properties: Map[String, String],
+ ctx: ParserRuleContext): Map[String, String] = withOrigin(ctx) {
+ import SupportsNamespaces._
+ val legacyOn = conf.getConf(SQLConf.LEGACY_PROPERTY_NON_RESERVED)
+ properties.filter {
+ case (PROP_LOCATION, _) if !legacyOn =>
+ throw cannotCleanReservedNamespacePropertyError(
+ PROP_LOCATION, ctx, "please use the LOCATION clause to specify it")
+ case (PROP_LOCATION, _) => false
+ case (PROP_OWNER, _) if !legacyOn =>
+ throw cannotCleanReservedNamespacePropertyError(
+ PROP_OWNER, ctx, "it will be set to the current user")
+ case (PROP_OWNER, _) => false
+ case _ => true
+ }
+ }
+
+ def cleanTableProperties(
+ ctx: ParserRuleContext, properties: Map[String, String]): Map[String, String] = {
+ import TableCatalog._
+ val legacyOn = conf.getConf(SQLConf.LEGACY_PROPERTY_NON_RESERVED)
+ properties.filter {
+ case (PROP_PROVIDER, _) if !legacyOn =>
+ throw cannotCleanReservedTablePropertyError(
+ PROP_PROVIDER, ctx, "please use the USING clause to specify it")
+ case (PROP_PROVIDER, _) => false
+ case (PROP_LOCATION, _) if !legacyOn =>
+ throw cannotCleanReservedTablePropertyError(
+ PROP_LOCATION, ctx, "please use the LOCATION clause to specify it")
+ case (PROP_LOCATION, _) => false
+ case (PROP_OWNER, _) if !legacyOn =>
+ throw cannotCleanReservedTablePropertyError(
+ PROP_OWNER, ctx, "it will be set to the current user")
+ case (PROP_OWNER, _) => false
+ case _ => true
+ }
+ }
+
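+ /**
+ * Remove reserved table properties from the OPTIONS map and pull the `path` option out as the
+ * table location, failing if both a `path` option and a LOCATION clause are specified.
+ */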
+ def cleanTableOptions(
+ ctx: ParserRuleContext,
+ options: Map[String, String],
+ location: Option[String]): (Map[String, String], Option[String]) = {
+ var path = location
+ val filtered = cleanTableProperties(ctx, options).filter {
+ case (k, v) if k.equalsIgnoreCase("path") && path.nonEmpty =>
+ throw duplicatedTablePathsFoundError(path.get, v, ctx)
+ case (k, v) if k.equalsIgnoreCase("path") =>
+ path = Some(v)
+ false
+ case _ => true
+ }
+ (filtered, path)
+ }
+
+ /**
+ * Create a [[SerdeInfo]] for creating tables.
+ *
+ * Format: STORED AS (name | INPUTFORMAT input_format OUTPUTFORMAT output_format)
+ */
+ override def visitCreateFileFormat(ctx: CreateFileFormatContext): SerdeInfo = withOrigin(ctx) {
+ (ctx.fileFormat, ctx.storageHandler) match {
+ // Expected format: INPUTFORMAT input_format OUTPUTFORMAT output_format
+ case (c: TableFileFormatContext, null) =>
+ SerdeInfo(formatClasses = Some(FormatClasses(string(c.inFmt), string(c.outFmt))))
+ // Expected format: SEQUENCEFILE | TEXTFILE | RCFILE | ORC | PARQUET | AVRO
+ case (c: GenericFileFormatContext, null) =>
+ SerdeInfo(storedAs = Some(c.identifier.getText))
+ case (null, storageHandler) =>
+ operationNotAllowed("STORED BY", ctx)
+ case _ =>
+ throw storedAsAndStoredByBothSpecifiedError(ctx)
+ }
+ }
+
+ /**
+ * Create a [[SerdeInfo]] used for creating tables.
+ *
+ * Example format:
+ * {{{
+ * SERDE serde_name [WITH SERDEPROPERTIES (k1=v1, k2=v2, ...)]
+ * }}}
+ *
+ * OR
+ *
+ * {{{
+ * DELIMITED [FIELDS TERMINATED BY char [ESCAPED BY char]]
+ * [COLLECTION ITEMS TERMINATED BY char]
+ * [MAP KEYS TERMINATED BY char]
+ * [LINES TERMINATED BY char]
+ * [NULL DEFINED AS char]
+ * }}}
+ */
+ def visitRowFormat(ctx: RowFormatContext): SerdeInfo = withOrigin(ctx) {
+ ctx match {
+ case serde: RowFormatSerdeContext => visitRowFormatSerde(serde)
+ case delimited: RowFormatDelimitedContext => visitRowFormatDelimited(delimited)
+ }
+ }
+
+ /**
+ * Create SERDE row format name and properties pair.
+ */
+ override def visitRowFormatSerde(ctx: RowFormatSerdeContext): SerdeInfo = withOrigin(ctx) {
+ import ctx._
+ SerdeInfo(
+ serde = Some(string(name)),
+ serdeProperties = Option(tablePropertyList).map(visitPropertyKeyValues).getOrElse(Map.empty))
+ }
+
+ /**
+ * Create a delimited row format properties object.
+ */
+ override def visitRowFormatDelimited(
+ ctx: RowFormatDelimitedContext): SerdeInfo = withOrigin(ctx) {
+ // Collect the entries if any.
+ def entry(key: String, value: Token): Seq[(String, String)] = {
+ Option(value).toSeq.map(x => key -> string(x))
+ }
+ // TODO we need proper support for the NULL format.
+ val entries =
+ entry("field.delim", ctx.fieldsTerminatedBy) ++
+ entry("serialization.format", ctx.fieldsTerminatedBy) ++
+ entry("escape.delim", ctx.escapedBy) ++
+ // The following typo is inherited from Hive...
+ entry("colelction.delim", ctx.collectionItemsTerminatedBy) ++
+ entry("mapkey.delim", ctx.keysTerminatedBy) ++
+ Option(ctx.linesSeparatedBy).toSeq.map { token =>
+ val value = string(token)
+ validate(
+ value == "\n",
+ s"LINES TERMINATED BY only supports newline '\\n' right now: $value",
+ ctx)
+ "line.delim" -> value
+ }
+ SerdeInfo(serdeProperties = entries.toMap)
+ }
+
+ /**
+ * Throw a [[ParseException]] if the user specified incompatible SerDes through ROW FORMAT
+ * and STORED AS.
+ *
+ * The following are allowed. Anything else is not:
+ * ROW FORMAT SERDE ... STORED AS [SEQUENCEFILE | RCFILE | TEXTFILE]
+ * ROW FORMAT DELIMITED ... STORED AS TEXTFILE
+ * ROW FORMAT ... STORED AS INPUTFORMAT ... OUTPUTFORMAT ...
+ */
+ protected def validateRowFormatFileFormat(
+ rowFormatCtx: RowFormatContext,
+ createFileFormatCtx: CreateFileFormatContext,
+ parentCtx: ParserRuleContext): Unit = {
+ if (!(rowFormatCtx == null || createFileFormatCtx == null)) {
+ (rowFormatCtx, createFileFormatCtx.fileFormat) match {
+ case (_, ffTable: TableFileFormatContext) => // OK
+ case (rfSerde: RowFormatSerdeContext, ffGeneric: GenericFileFormatContext) =>
+ ffGeneric.identifier.getText.toLowerCase(Locale.ROOT) match {
+ case ("sequencefile" | "textfile" | "rcfile") => // OK
+ case fmt =>
+ operationNotAllowed(
+ s"ROW FORMAT SERDE is incompatible with format '$fmt', which also specifies a serde",
+ parentCtx)
+ }
+ case (rfDelimited: RowFormatDelimitedContext, ffGeneric: GenericFileFormatContext) =>
+ ffGeneric.identifier.getText.toLowerCase(Locale.ROOT) match {
+ case "textfile" => // OK
+ case fmt => operationNotAllowed(
+ s"ROW FORMAT DELIMITED is only compatible with 'textfile', not '$fmt'", parentCtx)
+ }
+ case _ =>
+ // should never happen
+ def str(ctx: ParserRuleContext): String = {
+ (0 until ctx.getChildCount).map { i => ctx.getChild(i).getText }.mkString(" ")
+ }
+
+ operationNotAllowed(
+ s"Unexpected combination of ${str(rowFormatCtx)} and ${str(createFileFormatCtx)}",
+ parentCtx)
+ }
+ }
+ }
+
+ protected def validateRowFormatFileFormat(
+ rowFormatCtx: Seq[RowFormatContext],
+ createFileFormatCtx: Seq[CreateFileFormatContext],
+ parentCtx: ParserRuleContext): Unit = {
+ if (rowFormatCtx.size == 1 && createFileFormatCtx.size == 1) {
+ validateRowFormatFileFormat(rowFormatCtx.head, createFileFormatCtx.head, parentCtx)
+ }
+ }
+
+ override def visitCreateTableClauses(ctx: CreateTableClausesContext): TableClauses = {
+ checkDuplicateClauses(ctx.TBLPROPERTIES, "TBLPROPERTIES", ctx)
+ checkDuplicateClauses(ctx.OPTIONS, "OPTIONS", ctx)
+ checkDuplicateClauses(ctx.PARTITIONED, "PARTITIONED BY", ctx)
+ checkDuplicateClauses(ctx.createFileFormat, "STORED AS/BY", ctx)
+ checkDuplicateClauses(ctx.rowFormat, "ROW FORMAT", ctx)
+ checkDuplicateClauses(ctx.commentSpec(), "COMMENT", ctx)
+ checkDuplicateClauses(ctx.bucketSpec(), "CLUSTERED BY", ctx)
+ checkDuplicateClauses(ctx.locationSpec, "LOCATION", ctx)
+
+ if (ctx.skewSpec.size > 0) {
+ operationNotAllowed("CREATE TABLE ... SKEWED BY", ctx)
+ }
+
+ val (partTransforms, partCols) =
+ Option(ctx.partitioning).map(visitPartitionFieldList).getOrElse((Nil, Nil))
+ val bucketSpec = ctx.bucketSpec().asScala.headOption.map(visitBucketSpec)
+ val properties = Option(ctx.tableProps).map(visitPropertyKeyValues).getOrElse(Map.empty)
+ val cleanedProperties = cleanTableProperties(ctx, properties)
+ val options = Option(ctx.options).map(visitPropertyKeyValues).getOrElse(Map.empty)
+ val location = visitLocationSpecList(ctx.locationSpec())
+ val (cleanedOptions, newLocation) = cleanTableOptions(ctx, options, location)
+ val comment = visitCommentSpecList(ctx.commentSpec())
+ val serdeInfo =
+ getSerdeInfo(ctx.rowFormat.asScala.toSeq, ctx.createFileFormat.asScala.toSeq, ctx)
+ (partTransforms, partCols, bucketSpec, cleanedProperties, cleanedOptions, newLocation, comment,
+ serdeInfo)
+ }
+
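+ /**
+ * Merge the (at most one) ROW FORMAT clause and (at most one) STORED AS/BY clause into a single
+ * [[SerdeInfo]], after validating that the combination is allowed.
+ */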
+ protected def getSerdeInfo(
+ rowFormatCtx: Seq[RowFormatContext],
+ createFileFormatCtx: Seq[CreateFileFormatContext],
+ ctx: ParserRuleContext): Option[SerdeInfo] = {
+ validateRowFormatFileFormat(rowFormatCtx, createFileFormatCtx, ctx)
+ val rowFormatSerdeInfo = rowFormatCtx.map(visitRowFormat)
+ val fileFormatSerdeInfo = createFileFormatCtx.map(visitCreateFileFormat)
+ (fileFormatSerdeInfo ++ rowFormatSerdeInfo).reduceLeftOption((l, r) => l.merge(r))
+ }
+
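+ /**
+ * Turn partition transforms and partition column definitions into a single list of
+ * [[Transform]]s; mixing the two forms in one PARTITIONED BY clause is not allowed.
+ */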
+ private def partitionExpressions(
+ partTransforms: Seq[Transform],
+ partCols: Seq[StructField],
+ ctx: ParserRuleContext): Seq[Transform] = {
+ if (partTransforms.nonEmpty) {
+ if (partCols.nonEmpty) {
+ val references = partTransforms.map(_.describe()).mkString(", ")
+ val columns = partCols
+ .map(field => s"${field.name} ${field.dataType.simpleString}")
+ .mkString(", ")
+ operationNotAllowed(
+ s"""PARTITION BY: Cannot mix partition expressions and partition columns:
+ |Expressions: $references
+ |Columns: $columns""".stripMargin, ctx)
+
+ }
+ partTransforms
+ } else {
+ // Columns were added to create the schema; convert them to column references.
+ partCols.map { column =>
+ IdentityTransform(FieldReference(Seq(column.name)))
+ }
+ }
+ }
+
+ /**
+ * Create a table, returning a [[CreateTableStatement]] logical plan.
+ *
+ * Expected format:
+ * {{{
+ * CREATE [TEMPORARY] TABLE [IF NOT EXISTS] [db_name.]table_name
+ * [USING table_provider]
+ * create_table_clauses
+ * [[AS] select_statement];
+ *
+ * create_table_clauses (order insensitive):
+ * [PARTITIONED BY (partition_fields)]
+ * [OPTIONS table_property_list]
+ * [ROW FORMAT row_format]
+ * [STORED AS file_format]
+ * [CLUSTERED BY (col_name, col_name, ...)
+ * [SORTED BY (col_name [ASC|DESC], ...)]
+ * INTO num_buckets BUCKETS
+ * ]
+ * [LOCATION path]
+ * [COMMENT table_comment]
+ * [TBLPROPERTIES (property_name=property_value, ...)]
+ *
+ * partition_fields:
+ * col_name, transform(col_name), transform(constant, col_name), ... |
+ * col_name data_type [NOT NULL] [COMMENT col_comment], ...
+ * }}}
+ */
+ override def visitCreateTable(ctx: CreateTableContext): LogicalPlan = withOrigin(ctx) {
+ val (table, temp, ifNotExists, external) = visitCreateTableHeader(ctx.createTableHeader)
+
+ val columns = Option(ctx.colTypeList()).map(visitColTypeList).getOrElse(Nil)
+ val provider = Option(ctx.tableProvider).map(_.multipartIdentifier.getText)
+ val (partTransforms, partCols, bucketSpec, properties, options, location, comment, serdeInfo) =
+ visitCreateTableClauses(ctx.createTableClauses())
+
+ if (provider.isDefined && serdeInfo.isDefined) {
+ operationNotAllowed(s"CREATE TABLE ... USING ... ${serdeInfo.get.describe}", ctx)
+ }
+
+ if (temp) {
+ val asSelect = if (ctx.query == null) "" else " AS ..."
+ operationNotAllowed(
+ s"CREATE TEMPORARY TABLE ...$asSelect, use CREATE TEMPORARY VIEW instead", ctx)
+ }
+
+ val partitioning = partitionExpressions(partTransforms, partCols, ctx)
+
+ Option(ctx.query).map(plan) match {
+ case Some(_) if columns.nonEmpty =>
+ operationNotAllowed(
+ "Schema may not be specified in a Create Table As Select (CTAS) statement",
+ ctx)
+
+ case Some(_) if partCols.nonEmpty =>
+ // non-reference partition columns are not allowed because schema can't be specified
+ operationNotAllowed(
+ "Partition column types may not be specified in Create Table As Select (CTAS)",
+ ctx)
+
+ case Some(query) =>
+ CreateTableAsSelectStatement(
+ table, query, partitioning, bucketSpec, properties, provider, options, location, comment,
+ writeOptions = Map.empty, serdeInfo, external = external, ifNotExists = ifNotExists)
+
+ case _ =>
+ // Note: table schema includes both the table columns list and the partition columns
+ // with data type.
+ val schema = StructType(columns ++ partCols)
+ CreateTableStatement(table, schema, partitioning, bucketSpec, properties, provider,
+ options, location, comment, serdeInfo, external = external, ifNotExists = ifNotExists)
+ }
+ }
+
+ /**
+ * Replace a table, returning a [[ReplaceTableStatement]] logical plan.
+ *
+ * Expected format:
+ * {{{
+ * [CREATE OR] REPLACE TABLE [db_name.]table_name
+ * [USING table_provider]
+ * replace_table_clauses
+ * [[AS] select_statement];
+ *
+ * replace_table_clauses (order insensitive):
+ * [OPTIONS table_property_list]
+ * [PARTITIONED BY (partition_fields)]
+ * [CLUSTERED BY (col_name, col_name, ...)
+ * [SORTED BY (col_name [ASC|DESC], ...)]
+ * INTO num_buckets BUCKETS
+ * ]
+ * [LOCATION path]
+ * [COMMENT table_comment]
+ * [TBLPROPERTIES (property_name=property_value, ...)]
+ *
+ * partition_fields:
+ * col_name, transform(col_name), transform(constant, col_name), ... |
+ * col_name data_type [NOT NULL] [COMMENT col_comment], ...
+ * }}}
+ */
+ override def visitReplaceTable(ctx: ReplaceTableContext): LogicalPlan = withOrigin(ctx) {
+ val (table, temp, ifNotExists, external) = visitReplaceTableHeader(ctx.replaceTableHeader)
+ val orCreate = ctx.replaceTableHeader().CREATE() != null
+
+ if (temp) {
+ val action = if (orCreate) "CREATE OR REPLACE" else "REPLACE"
+ operationNotAllowed(s"$action TEMPORARY TABLE ..., use $action TEMPORARY VIEW instead.", ctx)
+ }
+
+ if (external) {
+ operationNotAllowed("REPLACE EXTERNAL TABLE ...", ctx)
+ }
+
+ if (ifNotExists) {
+ operationNotAllowed("REPLACE ... IF NOT EXISTS, use CREATE IF NOT EXISTS instead", ctx)
+ }
+
+ val (partTransforms, partCols, bucketSpec, properties, options, location, comment, serdeInfo) =
+ visitCreateTableClauses(ctx.createTableClauses())
+ val columns = Option(ctx.colTypeList()).map(visitColTypeList).getOrElse(Nil)
+ val provider = Option(ctx.tableProvider).map(_.multipartIdentifier.getText)
+
+ if (provider.isDefined && serdeInfo.isDefined) {
+ operationNotAllowed(s"CREATE TABLE ... USING ... ${serdeInfo.get.describe}", ctx)
+ }
+
+ val partitioning = partitionExpressions(partTransforms, partCols, ctx)
+
+ Option(ctx.query).map(plan) match {
+ case Some(_) if columns.nonEmpty =>
+ operationNotAllowed(
+ "Schema may not be specified in a Replace Table As Select (RTAS) statement",
+ ctx)
+
+ case Some(_) if partCols.nonEmpty =>
+ // non-reference partition columns are not allowed because schema can't be specified
+ operationNotAllowed(
+ "Partition column types may not be specified in Replace Table As Select (RTAS)",
+ ctx)
+
+ case Some(query) =>
+ ReplaceTableAsSelectStatement(table, query, partitioning, bucketSpec, properties,
+ provider, options, location, comment, writeOptions = Map.empty, serdeInfo,
+ orCreate = orCreate)
+
+ case _ =>
+ // Note: table schema includes both the table columns list and the partition columns
+ // with data type.
+ val schema = StructType(columns ++ partCols)
+ ReplaceTableStatement(table, schema, partitioning, bucketSpec, properties, provider,
+ options, location, comment, serdeInfo, orCreate = orCreate)
+ }
+ }
+
+ /**
+ * Parse new column info from ADD COLUMN into a QualifiedColType.
+ */
+ override def visitQualifiedColTypeWithPosition(
+ ctx: QualifiedColTypeWithPositionContext): QualifiedColType = withOrigin(ctx) {
+ val name = typedVisit[Seq[String]](ctx.name)
+ QualifiedColType(
+ path = if (name.length > 1) Some(UnresolvedFieldName(name.init)) else None,
+ colName = name.last,
+ dataType = typedVisit[DataType](ctx.dataType),
+ nullable = ctx.NULL == null,
+ comment = Option(ctx.commentSpec()).map(visitCommentSpec),
+ position = Option(ctx.colPosition).map( pos =>
+ UnresolvedFieldPosition(typedVisit[ColumnPosition](pos))))
+ }
+
+ /**
+ * Create a [[CacheTable]] or [[CacheTableAsSelect]].
+ *
+ * For example:
+ * {{{
+ * CACHE [LAZY] TABLE multi_part_name
+ * [OPTIONS tablePropertyList] [[AS] query]
+ * }}}
+ */
+ override def visitCacheTable(ctx: CacheTableContext): LogicalPlan = withOrigin(ctx) {
+ import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
+
+ val query = Option(ctx.query).map(plan)
+ val relation = createUnresolvedRelation(ctx.multipartIdentifier)
+ val tableName = relation.multipartIdentifier
+ if (query.isDefined && tableName.length > 1) {
+ val catalogAndNamespace = tableName.init
+ throw addCatalogInCacheTableAsSelectNotAllowedError(
+ catalogAndNamespace.quoted, ctx)
+ }
+ val options = Option(ctx.options).map(visitPropertyKeyValues).getOrElse(Map.empty)
+ val isLazy = ctx.LAZY != null
+ if (query.isDefined) {
+ CacheTableAsSelect(tableName.head, query.get, source(ctx.query()), isLazy, options)
+ } else {
+ CacheTable(relation, tableName, isLazy, options)
+ }
+ }
+
+ /**
+ * Create or replace a view. This creates a [[CreateViewStatement]]
+ *
+ * For example:
+ * {{{
+ * CREATE [OR REPLACE] [[GLOBAL] TEMPORARY] VIEW [IF NOT EXISTS] multi_part_name
+ * [(column_name [COMMENT column_comment], ...) ]
+ * create_view_clauses
+ *
+ * AS SELECT ...;
+ *
+ * create_view_clauses (order insensitive):
+ * [COMMENT view_comment]
+ * [TBLPROPERTIES (property_name = property_value, ...)]
+ * }}}
+ */
+ override def visitCreateView(ctx: CreateViewContext): LogicalPlan = withOrigin(ctx) {
+ if (!ctx.identifierList.isEmpty) {
+ operationNotAllowed("CREATE VIEW ... PARTITIONED ON", ctx)
+ }
+
+ checkDuplicateClauses(ctx.commentSpec(), "COMMENT", ctx)
+ checkDuplicateClauses(ctx.PARTITIONED, "PARTITIONED ON", ctx)
+ checkDuplicateClauses(ctx.TBLPROPERTIES, "TBLPROPERTIES", ctx)
+
+ val userSpecifiedColumns = Option(ctx.identifierCommentList).toSeq.flatMap { icl =>
+ icl.identifierComment.asScala.map { ic =>
+ ic.identifier.getText -> Option(ic.commentSpec()).map(visitCommentSpec)
+ }
+ }
+
+ val properties = ctx.tablePropertyList.asScala.headOption.map(visitPropertyKeyValues)
+ .getOrElse(Map.empty)
+ if (ctx.TEMPORARY != null && !properties.isEmpty) {
+ operationNotAllowed("TBLPROPERTIES can't coexist with CREATE TEMPORARY VIEW", ctx)
+ }
+
+ val viewType = if (ctx.TEMPORARY == null) {
+ PersistedView
+ } else if (ctx.GLOBAL != null) {
+ GlobalTempView
+ } else {
+ LocalTempView
+ }
+ CreateViewStatement(
+ visitMultipartIdentifier(ctx.multipartIdentifier),
+ userSpecifiedColumns,
+ visitCommentSpecList(ctx.commentSpec()),
+ properties,
+ Option(source(ctx.query)),
+ plan(ctx.query),
+ ctx.EXISTS != null,
+ ctx.REPLACE != null,
+ viewType)
+ }
+
+ /**
+ * Alter the query of a view. This creates an [[AlterViewAs]].
+ *
+ * For example:
+ * {{{
+ * ALTER VIEW multi_part_name AS SELECT ...;
+ * }}}
+ */
+ override def visitAlterViewQuery(ctx: AlterViewQueryContext): LogicalPlan = withOrigin(ctx) {
+ AlterViewAs(
+ createUnresolvedView(ctx.multipartIdentifier, "ALTER VIEW ... AS"),
+ originalText = source(ctx.query),
+ query = plan(ctx.query))
+ }
+
+ private def alterViewTypeMismatchHint: Option[String] = Some("Please use ALTER TABLE instead.")
+
+ private def alterTableTypeMismatchHint: Option[String] = Some("Please use ALTER VIEW instead.")
+
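+ // Parse-error helpers used by the visitor methods above; they mirror the messages of
+ // Spark 3.2's `org.apache.spark.sql.errors.QueryParsingErrors`.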
+ def invalidInsertIntoError(ctx: InsertIntoContext): Throwable = {
+ new ParseException("Invalid InsertIntoContext", ctx)
+ }
+
+ def insertOverwriteDirectoryUnsupportedError(ctx: InsertIntoContext): Throwable = {
+ new ParseException("INSERT OVERWRITE DIRECTORY is not supported", ctx)
+ }
+
+ def columnAliasInOperationNotAllowedError(op: String, ctx: TableAliasContext): Throwable = {
+ new ParseException(s"Columns aliases are not allowed in $op.", ctx.identifierList())
+ }
+
+ def emptySourceForMergeError(ctx: MergeIntoTableContext): Throwable = {
+ new ParseException("Empty source for merge: you should specify a source" +
+ " table/subquery in merge.", ctx.source)
+ }
+
+ def unrecognizedMatchedActionError(ctx: MatchedClauseContext): Throwable = {
+ new ParseException(s"Unrecognized matched action: ${ctx.matchedAction().getText}",
+ ctx.matchedAction())
+ }
+
+ def insertedValueNumberNotMatchFieldNumberError(ctx: NotMatchedClauseContext): Throwable = {
+ new ParseException("The number of inserted values cannot match the fields.",
+ ctx.notMatchedAction())
+ }
+
+ def unrecognizedNotMatchedActionError(ctx: NotMatchedClauseContext): Throwable = {
+ new ParseException(s"Unrecognized not matched action: ${ctx.notMatchedAction().getText}",
+ ctx.notMatchedAction())
+ }
+
+ def mergeStatementWithoutWhenClauseError(ctx: MergeIntoTableContext): Throwable = {
+ new ParseException("There must be at least one WHEN clause in a MERGE statement", ctx)
+ }
+
+ def nonLastMatchedClauseOmitConditionError(ctx: MergeIntoTableContext): Throwable = {
+ new ParseException("When there are more than one MATCHED clauses in a MERGE " +
+ "statement, only the last MATCHED clause can omit the condition.", ctx)
+ }
+
+ def nonLastNotMatchedClauseOmitConditionError(ctx: MergeIntoTableContext): Throwable = {
+ new ParseException("When there are more than one NOT MATCHED clauses in a MERGE " +
+ "statement, only the last NOT MATCHED clause can omit the condition.", ctx)
+ }
+
+ def emptyPartitionKeyError(key: String, ctx: PartitionSpecContext): Throwable = {
+ new ParseException(s"Found an empty partition key '$key'.", ctx)
+ }
+
+ def combinationQueryResultClausesUnsupportedError(ctx: QueryOrganizationContext): Throwable = {
+ new ParseException(
+ "Combination of ORDER BY/SORT BY/DISTRIBUTE BY/CLUSTER BY is not supported", ctx)
+ }
+
+ def distributeByUnsupportedError(ctx: QueryOrganizationContext): Throwable = {
+ new ParseException("DISTRIBUTE BY is not supported", ctx)
+ }
+
+ def transformNotSupportQuantifierError(ctx: ParserRuleContext): Throwable = {
+ new ParseException("TRANSFORM does not support DISTINCT/ALL in inputs", ctx)
+ }
+
+ def transformWithSerdeUnsupportedError(ctx: ParserRuleContext): Throwable = {
+ new ParseException("TRANSFORM with serde is only supported in hive mode", ctx)
+ }
+
+ def lateralWithPivotInFromClauseNotAllowedError(ctx: FromClauseContext): Throwable = {
+ new ParseException("LATERAL cannot be used together with PIVOT in FROM clause", ctx)
+ }
+
+ def lateralJoinWithNaturalJoinUnsupportedError(ctx: ParserRuleContext): Throwable = {
+ new ParseException("LATERAL join with NATURAL join is not supported", ctx)
+ }
+
+ def lateralJoinWithUsingJoinUnsupportedError(ctx: ParserRuleContext): Throwable = {
+ new ParseException("LATERAL join with USING join is not supported", ctx)
+ }
+
+ def unsupportedLateralJoinTypeError(ctx: ParserRuleContext, joinType: String): Throwable = {
+ new ParseException(s"Unsupported LATERAL join type $joinType", ctx)
+ }
+
+ def invalidLateralJoinRelationError(ctx: RelationPrimaryContext): Throwable = {
+ new ParseException(s"LATERAL can only be used with subquery", ctx)
+ }
+
+ def repetitiveWindowDefinitionError(name: String, ctx: WindowClauseContext): Throwable = {
+ new ParseException(s"The definition of window '$name' is repetitive", ctx)
+ }
+
+ def invalidWindowReferenceError(name: String, ctx: WindowClauseContext): Throwable = {
+ new ParseException(s"Window reference '$name' is not a window specification", ctx)
+ }
+
+ def cannotResolveWindowReferenceError(name: String, ctx: WindowClauseContext): Throwable = {
+ new ParseException(s"Cannot resolve window reference '$name'", ctx)
+ }
+
+ def joinCriteriaUnimplementedError(join: JoinCriteriaContext, ctx: RelationContext): Throwable = {
+ new ParseException(s"Unimplemented joinCriteria: $join", ctx)
+ }
+
+ def naturalCrossJoinUnsupportedError(ctx: RelationContext): Throwable = {
+ new ParseException("NATURAL CROSS JOIN is not supported", ctx)
+ }
+
+ def emptyInputForTableSampleError(ctx: ParserRuleContext): Throwable = {
+ new ParseException("TABLESAMPLE does not accept empty inputs.", ctx)
+ }
+
+ def tableSampleByBytesUnsupportedError(msg: String, ctx: SampleMethodContext): Throwable = {
+ new ParseException(s"TABLESAMPLE($msg) is not supported", ctx)
+ }
+
+ def invalidByteLengthLiteralError(bytesStr: String, ctx: SampleByBytesContext): Throwable = {
+ new ParseException(s"$bytesStr is not a valid byte length literal, " +
+ "expected syntax: DIGIT+ ('B' | 'K' | 'M' | 'G')", ctx)
+ }
+
+ def invalidEscapeStringError(ctx: PredicateContext): Throwable = {
+ new ParseException("Invalid escape string. Escape string must contain only one character.", ctx)
+ }
+
+ def trimOptionUnsupportedError(trimOption: Int, ctx: TrimContext): Throwable = {
+ new ParseException("Function trim doesn't support with " +
+ s"type $trimOption. Please use BOTH, LEADING or TRAILING as trim type", ctx)
+ }
+
+ def functionNameUnsupportedError(functionName: String, ctx: ParserRuleContext): Throwable = {
+ new ParseException(s"Unsupported function name '$functionName'", ctx)
+ }
+
+ def cannotParseValueTypeError(
+ valueType: String, value: String, ctx: TypeConstructorContext): Throwable = {
+ new ParseException(s"Cannot parse the $valueType value: $value", ctx)
+ }
+
+ def cannotParseIntervalValueError(value: String, ctx: TypeConstructorContext): Throwable = {
+ new ParseException(s"Cannot parse the INTERVAL value: $value", ctx)
+ }
+
+ def literalValueTypeUnsupportedError(
+ valueType: String, ctx: TypeConstructorContext): Throwable = {
+ new ParseException(s"Literals of type '$valueType' are currently not supported.", ctx)
+ }
+
+ def parsingValueTypeError(
+ e: IllegalArgumentException, valueType: String, ctx: TypeConstructorContext): Throwable = {
+ val message = Option(e.getMessage).getOrElse(s"Exception parsing $valueType")
+ new ParseException(message, ctx)
+ }
+
+ def invalidNumericLiteralRangeError(rawStrippedQualifier: String, minValue: BigDecimal,
+ maxValue: BigDecimal, typeName: String, ctx: NumberContext): Throwable = {
+ new ParseException(s"Numeric literal $rawStrippedQualifier does not " +
+ s"fit in range [$minValue, $maxValue] for type $typeName", ctx)
+ }
+
+ def moreThanOneFromToUnitInIntervalLiteralError(ctx: ParserRuleContext): Throwable = {
+ new ParseException("Can only have a single from-to unit in the interval literal syntax", ctx)
+ }
+
+ def invalidIntervalLiteralError(ctx: IntervalContext): Throwable = {
+ new ParseException("at least one time unit should be given for interval literal", ctx)
+ }
+
+ def invalidIntervalFormError(value: String, ctx: MultiUnitsIntervalContext): Throwable = {
+ new ParseException("Can only use numbers in the interval value part for" +
+ s" multiple unit value pairs interval form, but got invalid value: $value", ctx)
+ }
+
+ def invalidFromToUnitValueError(ctx: IntervalValueContext): Throwable = {
+ new ParseException("The value of from-to unit must be a string", ctx)
+ }
+
+ def fromToIntervalUnsupportedError(
+ from: String, to: String, ctx: ParserRuleContext): Throwable = {
+ new ParseException(s"Intervals FROM $from TO $to are not supported.", ctx)
+ }
+
+ def mixedIntervalUnitsError(literal: String, ctx: ParserRuleContext): Throwable = {
+ new ParseException(s"Cannot mix year-month and day-time fields: $literal", ctx)
+ }
+
+ def dataTypeUnsupportedError(dataType: String, ctx: PrimitiveDataTypeContext): Throwable = {
+ new ParseException(s"DataType $dataType is not supported.", ctx)
+ }
+
+ def partitionTransformNotExpectedError(
+ name: String, describe: String, ctx: ApplyTransformContext): Throwable = {
+ new ParseException(s"Expected a column reference for transform $name: $describe", ctx)
+ }
+
+ def tooManyArgumentsForTransformError(name: String, ctx: ApplyTransformContext): Throwable = {
+ new ParseException(s"Too many arguments for transform $name", ctx)
+ }
+
+ def notEnoughArgumentsForTransformError(name: String, ctx: ApplyTransformContext): Throwable = {
+ new ParseException(s"Not enough arguments for transform $name", ctx)
+ }
+
+ def invalidBucketsNumberError(describe: String, ctx: ApplyTransformContext): Throwable = {
+ new ParseException(s"Invalid number of buckets: $describe", ctx)
+ }
+
+ def invalidTransformArgumentError(ctx: TransformArgumentContext): Throwable = {
+ new ParseException("Invalid transform argument", ctx)
+ }
+
+ def cannotCleanReservedNamespacePropertyError(
+ property: String, ctx: ParserRuleContext, msg: String): Throwable = {
+ new ParseException(s"$property is a reserved namespace property, $msg.", ctx)
+ }
+
+ def cannotCleanReservedTablePropertyError(
+ property: String, ctx: ParserRuleContext, msg: String): Throwable = {
+ new ParseException(s"$property is a reserved table property, $msg.", ctx)
+ }
+
+ def duplicatedTablePathsFoundError(
+ pathOne: String, pathTwo: String, ctx: ParserRuleContext): Throwable = {
+ new ParseException(s"Duplicated table paths found: '$pathOne' and '$pathTwo'. LOCATION" +
+ s" and the case insensitive key 'path' in OPTIONS are all used to indicate the custom" +
+ s" table path, you can only specify one of them.", ctx)
+ }
+
+ def storedAsAndStoredByBothSpecifiedError(ctx: CreateFileFormatContext): Throwable = {
+ new ParseException("Expected either STORED AS or STORED BY, not both", ctx)
+ }
+
+ def operationInHiveStyleCommandUnsupportedError(operation: String,
+ command: String, ctx: StatementContext, msgOpt: Option[String] = None): Throwable = {
+ val basicError = s"$operation is not supported in Hive-style $command"
+ val msg = if (msgOpt.isDefined) {
+ s"$basicError, ${msgOpt.get}."
+ } else {
+ basicError
+ }
+ new ParseException(msg, ctx)
+ }
+
+ def operationNotAllowedError(message: String, ctx: ParserRuleContext): Throwable = {
+ new ParseException(s"Operation not allowed: $message", ctx)
+ }
+
+ def computeStatisticsNotExpectedError(ctx: IdentifierContext): Throwable = {
+ new ParseException(s"Expected `NOSCAN` instead of `${ctx.getText}`", ctx)
+ }
+
+ def addCatalogInCacheTableAsSelectNotAllowedError(
+ quoted: String, ctx: CacheTableContext): Throwable = {
+ new ParseException(s"It is not allowed to add catalog/namespace prefix $quoted to " +
+ "the table name in CACHE TABLE AS SELECT", ctx)
+ }
+
+ def showFunctionsUnsupportedError(identifier: String, ctx: IdentifierContext): Throwable = {
+ new ParseException(s"SHOW $identifier FUNCTIONS not supported", ctx)
+ }
+
+ def duplicateCteDefinitionNamesError(duplicateNames: String, ctx: CtesContext): Throwable = {
+ new ParseException(s"CTE definition can't have duplicate names: $duplicateNames.", ctx)
+ }
+
+ def sqlStatementUnsupportedError(sqlText: String, position: Origin): Throwable = {
+ new ParseException(Option(sqlText), "Unsupported SQL statement", position, position)
+ }
+
+ def unquotedIdentifierError(ident: String, ctx: ErrorIdentContext): Throwable = {
+ new ParseException(s"Possibly unquoted identifier $ident detected. " +
+ s"Please consider quoting it with back-quotes as `$ident`", ctx)
+ }
+
+ def duplicateClausesError(clauseName: String, ctx: ParserRuleContext): Throwable = {
+ new ParseException(s"Found duplicate clauses: $clauseName", ctx)
+ }
+
+ def duplicateKeysError(key: String, ctx: ParserRuleContext): Throwable = {
+ // Found duplicate keys '$key'
+ new ParseException(errorClass = "DUPLICATE_KEY", messageParameters = Array(key), ctx)
+ }
+
+ def intervalValueOutOfRangeError(ctx: IntervalContext): Throwable = {
+ new ParseException("The interval value must be in the range of [-18, +18] hours" +
+ " with second precision", ctx)
+ }
+
+ def createTempTableNotSpecifyProviderError(ctx: CreateTableContext): Throwable = {
+ new ParseException("CREATE TEMPORARY TABLE without a provider is not allowed.", ctx)
+ }
+
+ def useDefinedRecordReaderOrWriterClassesError(ctx: ParserRuleContext): Throwable = {
+ new ParseException(
+ "Unsupported operation: Used defined record reader/writer classes.", ctx)
+ }
+
+ def directoryPathAndOptionsPathBothSpecifiedError(ctx: InsertOverwriteDirContext): Throwable = {
+ new ParseException(
+ "Directory path and 'path' in OPTIONS should be specified one, but not both", ctx)
+ }
+
+ def unsupportedLocalFileSchemeError(ctx: InsertOverwriteDirContext): Throwable = {
+ new ParseException("LOCAL is supported only with file: scheme", ctx)
+ }
+
+ def invalidGroupingSetError(element: String, ctx: GroupingAnalyticsContext): Throwable = {
+ new ParseException(s"Empty set in $element grouping sets is not supported.", ctx)
+ }
+}
+
+/**
+ * A container for holding named common table expressions (CTEs) and a query plan.
+ * This operator will be removed during analysis and the relations will be substituted into child.
+ *
+ * @param child The final query of this CTE.
+ * @param cteRelations A sequence of pair (alias, the CTE definition) that this CTE defined
+ * Each CTE can see the base tables and the previously defined CTEs only.
+ */
+case class With(child: LogicalPlan, cteRelations: Seq[(String, SubqueryAlias)]) extends UnaryNode {
+ override def output: Seq[Attribute] = child.output
+
+ override def simpleString(maxFields: Int): String = {
+ val cteAliases = truncatedString(cteRelations.map(_._1), "[", ", ", "]", maxFields)
+ s"CTE $cteAliases"
+ }
+
+ override def innerChildren: Seq[LogicalPlan] = cteRelations.map(_._2)
+
+ def withNewChildInternal(newChild: LogicalPlan): LogicalPlan = copy(child = newChild)
+}
diff --git a/hudi-spark-datasource/hudi-spark3/src/main/scala/org/apache/spark/sql/parser/HoodieSpark3_2ExtendedSqlParser.scala b/hudi-spark-datasource/hudi-spark3/src/main/scala/org/apache/spark/sql/parser/HoodieSpark3_2ExtendedSqlParser.scala
new file mode 100644
index 0000000000000..59ef8dfe0969b
--- /dev/null
+++ b/hudi-spark-datasource/hudi-spark3/src/main/scala/org/apache/spark/sql/parser/HoodieSpark3_2ExtendedSqlParser.scala
@@ -0,0 +1,176 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.parser
+
+import org.antlr.v4.runtime._
+import org.antlr.v4.runtime.atn.PredictionMode
+import org.antlr.v4.runtime.misc.{Interval, ParseCancellationException}
+import org.antlr.v4.runtime.tree.TerminalNodeImpl
+import org.apache.hudi.spark.sql.parser.HoodieSqlBaseParser.{NonReservedContext, QuotedIdentifierContext}
+import org.apache.hudi.spark.sql.parser.{HoodieSqlBaseBaseListener, HoodieSqlBaseLexer, HoodieSqlBaseParser}
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.parser.{ParseErrorListener, ParseException, ParserInterface}
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.trees.Origin
+import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
+import org.apache.spark.sql.types._
+import org.apache.spark.sql.{AnalysisException, SparkSession}
+
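+/**
+ * Hudi's extended SQL parser for Spark 3.2: SQL text is first parsed with the generated
+ * HoodieSqlBaseParser, and anything the extended grammar does not handle is delegated to
+ * Spark's own parser.
+ */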
+class HoodieSpark3_2ExtendedSqlParser(session: SparkSession, delegate: ParserInterface)
+ extends ParserInterface with Logging {
+
+ private lazy val conf = session.sqlContext.conf
+ private lazy val builder = new HoodieSpark3_2ExtendedSqlAstBuilder(conf, delegate)
+
+ override def parsePlan(sqlText: String): LogicalPlan = parse(sqlText) { parser =>
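+ // Visit the statement with the extended AST builder first; if it does not produce a
+ // logical plan, fall back to the delegate (Spark's built-in parser) for the same text.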
+ builder.visit(parser.singleStatement()) match {
+ case plan: LogicalPlan => plan
+ case _ => delegate.parsePlan(sqlText)
+ }
+ }
+
+ override def parseExpression(sqlText: String): Expression = delegate.parseExpression(sqlText)
+
+ override def parseTableIdentifier(sqlText: String): TableIdentifier =
+ delegate.parseTableIdentifier(sqlText)
+
+ override def parseFunctionIdentifier(sqlText: String): FunctionIdentifier =
+ delegate.parseFunctionIdentifier(sqlText)
+
+ override def parseTableSchema(sqlText: String): StructType = delegate.parseTableSchema(sqlText)
+
+ override def parseDataType(sqlText: String): DataType = delegate.parseDataType(sqlText)
+
+ protected def parse[T](command: String)(toResult: HoodieSqlBaseParser => T): T = {
+ logDebug(s"Parsing command: $command")
+
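+ // Wrap the input in UpperCaseCharStream so keyword matching is case-insensitive while
+ // getText() still returns the original text.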
+ val lexer = new HoodieSqlBaseLexer(new UpperCaseCharStream(CharStreams.fromString(command)))
+ lexer.removeErrorListeners()
+ lexer.addErrorListener(ParseErrorListener)
+
+ val tokenStream = new CommonTokenStream(lexer)
+ val parser = new HoodieSqlBaseParser(tokenStream)
+ parser.addParseListener(PostProcessor)
+ parser.removeErrorListeners()
+ parser.addErrorListener(ParseErrorListener)
+// parser.legacy_setops_precedence_enabled = conf.setOpsPrecedenceEnforced
+ parser.legacy_exponent_literal_as_decimal_enabled = conf.exponentLiteralAsDecimalEnabled
+ parser.SQL_standard_keyword_behavior = conf.ansiEnabled
+
+ try {
+ try {
+ // first, try parsing with potentially faster SLL mode
+ parser.getInterpreter.setPredictionMode(PredictionMode.SLL)
+ toResult(parser)
+ }
+ catch {
+ case e: ParseCancellationException =>
+ // if we fail, parse with LL mode
+ tokenStream.seek(0) // rewind input stream
+ parser.reset()
+
+ // Try Again.
+ parser.getInterpreter.setPredictionMode(PredictionMode.LL)
+ toResult(parser)
+ }
+ }
+ catch {
+ case e: ParseException if e.command.isDefined =>
+ throw e
+ case e: ParseException =>
+ throw e.withCommand(command)
+ case e: AnalysisException =>
+ val position = Origin(e.line, e.startPosition)
+ throw new ParseException(Option(command), e.message, position, position)
+ }
+ }
+
+ override def parseMultipartIdentifier(sqlText: String): Seq[String] = {
+ delegate.parseMultipartIdentifier(sqlText)
+ }
+}
+
+/**
+ * Fork from `org.apache.spark.sql.catalyst.parser.UpperCaseCharStream`.
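+ * Upper-cases each character returned from LA() so the generated lexer matches keywords
+ * case-insensitively, while getText() preserves the original input.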
+ */
+class UpperCaseCharStream(wrapped: CodePointCharStream) extends CharStream {
+ override def consume(): Unit = wrapped.consume
+ override def getSourceName(): String = wrapped.getSourceName
+ override def index(): Int = wrapped.index
+ override def mark(): Int = wrapped.mark
+ override def release(marker: Int): Unit = wrapped.release(marker)
+ override def seek(where: Int): Unit = wrapped.seek(where)
+ override def size(): Int = wrapped.size
+
+ override def getText(interval: Interval): String = {
+ // ANTLR 4.7's CodePointCharStream implementations have bugs when
+ // getText() is called with an empty stream, or intervals where
+ // the start > end. See
+ // https://github.com/antlr/antlr4/commit/ac9f7530 for one fix
+ // that is not yet in a released ANTLR artifact.
+ if (size() > 0 && (interval.b - interval.a >= 0)) {
+ wrapped.getText(interval)
+ } else {
+ ""
+ }
+ }
+ // scalastyle:off
+ override def LA(i: Int): Int = {
+ // scalastyle:on
+ val la = wrapped.LA(i)
+ if (la == 0 || la == IntStream.EOF) la
+ else Character.toUpperCase(la)
+ }
+}
+
+/**
+ * Fork from `org.apache.spark.sql.catalyst.parser.PostProcessor`.
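+ * Rewrites quoted identifiers and non-reserved keywords into plain IDENTIFIER tokens as
+ * each rule is exited.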
+ */
+case object PostProcessor extends HoodieSqlBaseBaseListener {
+
+ /** Remove the back ticks from an Identifier. */
+ override def exitQuotedIdentifier(ctx: QuotedIdentifierContext): Unit = {
+ replaceTokenByIdentifier(ctx, 1) { token =>
+ // Remove the double back ticks in the string.
+ token.setText(token.getText.replace("``", "`"))
+ token
+ }
+ }
+
+ /** Treat non-reserved keywords as Identifiers. */
+ override def exitNonReserved(ctx: NonReservedContext): Unit = {
+ replaceTokenByIdentifier(ctx, 0)(identity)
+ }
+
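+ /**
+ * Replaces the last child of the parent rule with an IDENTIFIER token, dropping
+ * `stripMargins` characters from each end (e.g. the surrounding back ticks).
+ */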
+ private def replaceTokenByIdentifier(
+ ctx: ParserRuleContext,
+ stripMargins: Int)(
+ f: CommonToken => CommonToken = identity): Unit = {
+ val parent = ctx.getParent
+ parent.removeLastChild()
+ val token = ctx.getChild(0).getPayload.asInstanceOf[Token]
+ val newToken = new CommonToken(
+ new org.antlr.v4.runtime.misc.Pair(token.getTokenSource, token.getInputStream),
+ HoodieSqlBaseParser.IDENTIFIER,
+ token.getChannel,
+ token.getStartIndex + stripMargins,
+ token.getStopIndex - stripMargins)
+ parent.addChild(new TerminalNodeImpl(f(newToken)))
+ }
+}
diff --git a/pom.xml b/pom.xml
index a0f7c2c6e303e..c277c3037b0d2 100644
--- a/pom.xml
+++ b/pom.xml
@@ -120,6 +120,8 @@
1.14.3
2.4.4
3.2.1
+ <spark3.1.version>3.1.2</spark3.1.version>
+ <spark3.2.version>3.2.1</spark3.2.version>
hudi-spark2
hudi-spark2-common
1.8.2
@@ -1584,6 +1586,7 @@
spark3
+ <spark3.version>${spark3.2.version}</spark3.version>
${spark3.version}
${spark3.version}
${scala12.version}
@@ -1612,7 +1615,7 @@
spark3.1.x
- <spark3.version>3.1.2</spark3.version>
+ <spark3.version>${spark3.1.version}</spark3.version>
${spark3.version}
${spark3.version}
${scala12.version}
diff --git a/style/scalastyle.xml b/style/scalastyle.xml
index 74d7b9d73a203..a1b4cdbb6dafa 100644
--- a/style/scalastyle.xml
+++ b/style/scalastyle.xml
@@ -27,7 +27,7 @@
-
+
@@ -113,7 +113,7 @@
-
+