Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.spark.sql.connector.catalog;

import java.util.Map;
import java.util.Set;

import org.apache.spark.annotation.Evolving;
import org.apache.spark.sql.catalyst.analysis.*;
Expand Down Expand Up @@ -52,6 +53,11 @@ public String name() {
@Override
public final void initialize(String name, CaseInsensitiveStringMap options) {}

@Override
public Set<TableCatalogCapability> capabilities() {
return asTableCatalog().capabilities();
}

@Override
public String[] defaultNamespace() {
return delegate.defaultNamespace();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,31 @@
public enum TableCatalogCapability {

/**
* Signals that the TableCatalog supports defining generated columns upon table creation in SQL.
* <p>
* Without this capability, any create/replace table statements with a generated column defined
* in the table schema will throw an exception during analysis.
* <p>
* A generated column is defined with syntax: {@code colName colType GENERATED ALWAYS AS (expr)}
* <p>
* Generation expression are included in the column definition for APIs like
* {@link TableCatalog#createTable}.
* See {@link Column#generationExpression()}.
*/
SUPPORTS_CREATE_TABLE_WITH_GENERATED_COLUMNS
* Signals that the TableCatalog supports defining generated columns upon table creation in SQL.
* <p>
* Without this capability, any create/replace table statements with a generated column defined
* in the table schema will throw an exception during analysis.
* <p>
* A generated column is defined with syntax: {@code colName colType GENERATED ALWAYS AS (expr)}
* <p>
* Generation expression are included in the column definition for APIs like
* {@link TableCatalog#createTable}.
* See {@link Column#generationExpression()}.
*/
SUPPORTS_CREATE_TABLE_WITH_GENERATED_COLUMNS,

/**
* Signals that the TableCatalog supports defining column default value as expression in
* CREATE/REPLACE/ALTER TABLE.
* <p>
* Without this capability, any CREATE/REPLACE/ALTER TABLE statement with a column default value
* defined in the table schema will throw an exception during analysis.
* <p>
* A column default value is defined with syntax: {@code colName colType DEFAULT expr}
* <p>
* Column default value expression is included in the column definition for APIs like
* {@link TableCatalog#createTable}.
* See {@link Column#defaultValue()}.
*/
SUPPORT_COLUMN_DEFAULT_VALUE
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException}
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
import org.apache.spark.sql.catalyst.trees.TreePattern.PLAN_EXPRESSION
import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns.BuiltInFunctionCatalog
import org.apache.spark.sql.connector.catalog.{CatalogManager, TableCatalog, TableCatalogCapability}
import org.apache.spark.sql.connector.catalog.{CatalogManager, Identifier, TableCatalog, TableCatalogCapability}
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{DataType, StructField, StructType}
Expand Down Expand Up @@ -182,12 +182,13 @@ object GeneratedColumn {
def validateGeneratedColumns(
schema: StructType,
catalog: TableCatalog,
ident: Seq[String],
ident: Identifier,
statementType: String): Unit = {
if (hasGeneratedColumns(schema)) {
if (!catalog.capabilities().contains(
TableCatalogCapability.SUPPORTS_CREATE_TABLE_WITH_GENERATED_COLUMNS)) {
throw QueryCompilationErrors.generatedColumnsUnsupported(ident)
throw QueryCompilationErrors.unsupportedTableOperationError(
catalog, ident, "generated columns")
}
GeneratedColumn.verifyGeneratedColumns(schema, statementType)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.optimizer.ConstantFolding
import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.trees.TreePattern.PLAN_EXPRESSION
import org.apache.spark.sql.connector.catalog.{CatalogManager, FunctionCatalog, Identifier}
import org.apache.spark.sql.connector.catalog.{CatalogManager, FunctionCatalog, Identifier, TableCatalog, TableCatalogCapability}
import org.apache.spark.sql.connector.catalog.functions.UnboundFunction
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.internal.SQLConf
Expand Down Expand Up @@ -90,39 +90,15 @@ object ResolveDefaultColumns {
* EXISTS_DEFAULT metadata for such columns where the value is not present in storage.
*
* @param tableSchema represents the names and types of the columns of the statement to process.
* @param tableProvider provider of the target table to store default values for, if any.
* @param statementType name of the statement being processed, such as INSERT; useful for errors.
* @param addNewColumnToExistingTable true if the statement being processed adds a new column to
* a table that already exists.
* @return a copy of `tableSchema` with field metadata updated with the constant-folded values.
*/
def constantFoldCurrentDefaultsToExistDefaults(
tableSchema: StructType,
tableProvider: Option[String],
statementType: String,
addNewColumnToExistingTable: Boolean): StructType = {
statementType: String): StructType = {
if (SQLConf.get.enableDefaultColumns) {
val keywords: Array[String] =
SQLConf.get.getConf(SQLConf.DEFAULT_COLUMN_ALLOWED_PROVIDERS)
.toLowerCase().split(",").map(_.trim)
val allowedTableProviders: Array[String] =
keywords.map(_.stripSuffix("*"))
val addColumnExistingTableBannedProviders: Array[String] =
keywords.filter(_.endsWith("*")).map(_.stripSuffix("*"))
val givenTableProvider: String = tableProvider.getOrElse("").toLowerCase()
val newFields: Seq[StructField] = tableSchema.fields.map { field =>
if (field.metadata.contains(CURRENT_DEFAULT_COLUMN_METADATA_KEY)) {
// Make sure that the target table has a provider that supports default column values.
if (!allowedTableProviders.contains(givenTableProvider)) {
throw QueryCompilationErrors
.defaultReferencesNotAllowedInDataSource(statementType, givenTableProvider)
}
if (addNewColumnToExistingTable &&
givenTableProvider.nonEmpty &&
addColumnExistingTableBannedProviders.contains(givenTableProvider)) {
throw QueryCompilationErrors
.addNewDefaultColumnToExistingTableNotAllowed(statementType, givenTableProvider)
}
val analyzed: Expression = analyze(field, statementType)
val newMetadata: Metadata = new MetadataBuilder().withMetadata(field.metadata)
.putString(EXISTS_DEFAULT_COLUMN_METADATA_KEY, analyzed.sql).build()
Expand All @@ -137,6 +113,47 @@ object ResolveDefaultColumns {
}
}

// Fails if the given catalog does not support column default value.
def validateCatalogForDefaultValue(
schema: StructType,
catalog: TableCatalog,
ident: Identifier): Unit = {
if (SQLConf.get.enableDefaultColumns &&
schema.exists(_.metadata.contains(CURRENT_DEFAULT_COLUMN_METADATA_KEY)) &&
!catalog.capabilities().contains(TableCatalogCapability.SUPPORT_COLUMN_DEFAULT_VALUE)) {
throw QueryCompilationErrors.unsupportedTableOperationError(
catalog, ident, "column default value")
}
}

// Fails if the given table provider of the session catalog does not support column default value.
def validateTableProviderForDefaultValue(
schema: StructType,
tableProvider: Option[String],
statementType: String,
addNewColumnToExistingTable: Boolean): Unit = {
if (SQLConf.get.enableDefaultColumns &&
schema.exists(_.metadata.contains(CURRENT_DEFAULT_COLUMN_METADATA_KEY))) {
val keywords: Array[String] = SQLConf.get.getConf(SQLConf.DEFAULT_COLUMN_ALLOWED_PROVIDERS)
.toLowerCase().split(",").map(_.trim)
val allowedTableProviders: Array[String] = keywords.map(_.stripSuffix("*"))
val addColumnExistingTableBannedProviders: Array[String] =
keywords.filter(_.endsWith("*")).map(_.stripSuffix("*"))
val givenTableProvider: String = tableProvider.getOrElse("").toLowerCase()
// Make sure that the target table has a provider that supports default column values.
if (!allowedTableProviders.contains(givenTableProvider)) {
throw QueryCompilationErrors.defaultReferencesNotAllowedInDataSource(
statementType, givenTableProvider)
}
if (addNewColumnToExistingTable &&
givenTableProvider.nonEmpty &&
addColumnExistingTableBannedProviders.contains(givenTableProvider)) {
throw QueryCompilationErrors.addNewDefaultColumnToExistingTableNotAllowed(
statementType, givenTableProvider)
}
}
}

/**
* Parses and analyzes the DEFAULT column text in `field`, returning an error upon failure.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -248,8 +248,9 @@ private[sql] object CatalogV2Util {
val (before, after) = schema.fields.splitAt(fieldIndex + 1)
StructType(before ++ (field +: after))
}
constantFoldCurrentDefaultsToExistDefaults(
validateTableProviderForDefaultValue(
newSchema, tableProvider, statementType, addNewColumnToExistingTable)
constantFoldCurrentDefaultsToExistDefaults(newSchema, statementType)
}

private def replace(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -749,13 +749,28 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase {
messageParameters = Map.empty)
}

def operationOnlySupportedWithV2TableError(
nameParts: Seq[String],
def unsupportedTableOperationError(
catalog: CatalogPlugin,
ident: Identifier,
operation: String): Throwable = {
unsupportedTableOperationError(
catalog.name +: ident.namespace :+ ident.name, operation)
}

def unsupportedTableOperationError(
ident: TableIdentifier,
operation: String): Throwable = {
unsupportedTableOperationError(
Seq(ident.catalog.get, ident.database.get, ident.table), operation)
}

private def unsupportedTableOperationError(
qualifiedTableName: Seq[String],
operation: String): Throwable = {
new AnalysisException(
errorClass = "UNSUPPORTED_FEATURE.TABLE_OPERATION",
messageParameters = Map(
"tableName" -> toSQLId(nameParts),
"tableName" -> toSQLId(qualifiedTableName),
"operation" -> operation))
}

Expand Down Expand Up @@ -3405,16 +3420,6 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase {
}
}

def generatedColumnsUnsupported(nameParts: Seq[String]): AnalysisException = {
new AnalysisException(
errorClass = "UNSUPPORTED_FEATURE.TABLE_OPERATION",
messageParameters = Map(
"tableName" -> toSQLId(nameParts),
"operation" -> "generated columns"
)
)
}

def ambiguousLateralColumnAliasError(name: String, numOfMatches: Int): Throwable = {
new AnalysisException(
errorClass = "AMBIGUOUS_LATERAL_COLUMN_ALIAS",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ abstract class SessionCatalogSuite extends AnalysisTest with Eventually {
// disabled.
withSQLConf(SQLConf.ENABLE_DEFAULT_COLUMNS.key -> "false") {
val result: StructType = ResolveDefaultColumns.constantFoldCurrentDefaultsToExistDefaults(
db1tbl3.schema, db1tbl3.provider, "CREATE TABLE", false)
db1tbl3.schema, "CREATE TABLE")
val columnEWithFeatureDisabled: StructField = findField("e", result)
// No constant-folding has taken place to the EXISTS_DEFAULT metadata.
assert(!columnEWithFeatureDisabled.metadata.contains("EXISTS_DEFAULT"))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,10 @@ class BasicInMemoryTableCatalog extends TableCatalog {
class InMemoryTableCatalog extends BasicInMemoryTableCatalog with SupportsNamespaces {

override def capabilities: java.util.Set[TableCatalogCapability] = {
Set(TableCatalogCapability.SUPPORTS_CREATE_TABLE_WITH_GENERATED_COLUMNS).asJava
Set(
TableCatalogCapability.SUPPORT_COLUMN_DEFAULT_VALUE,
TableCatalogCapability.SUPPORTS_CREATE_TABLE_WITH_GENERATED_COLUMNS
).asJava
}

protected def allNamespaces: Seq[Seq[String]] = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,8 @@ class ResolveSessionCatalog(val catalogManager: CatalogManager)
case AddColumns(ResolvedV1TableIdentifier(ident), cols) =>
cols.foreach { c =>
if (c.name.length > 1) {
throw QueryCompilationErrors.operationOnlySupportedWithV2TableError(
Seq(ident.catalog.get, ident.database.get, ident.table),
"ADD COLUMN with qualified column")
throw QueryCompilationErrors.unsupportedTableOperationError(
ident, "ADD COLUMN with qualified column")
}
if (!c.nullable) {
throw QueryCompilationErrors.addColumnWithV1TableCannotSpecifyNotNullError
Expand All @@ -64,24 +63,20 @@ class ResolveSessionCatalog(val catalogManager: CatalogManager)
AlterTableAddColumnsCommand(ident, cols.map(convertToStructField))

case ReplaceColumns(ResolvedV1TableIdentifier(ident), _) =>
throw QueryCompilationErrors.operationOnlySupportedWithV2TableError(
Seq(ident.catalog.get, ident.database.get, ident.table),
"REPLACE COLUMNS")
throw QueryCompilationErrors.unsupportedTableOperationError(ident, "REPLACE COLUMNS")

case a @ AlterColumn(ResolvedTable(catalog, ident, table: V1Table, _), _, _, _, _, _, _)
if isSessionCatalog(catalog) =>
if (a.column.name.length > 1) {
throw QueryCompilationErrors.operationOnlySupportedWithV2TableError(
Seq(catalog.name, ident.namespace()(0), ident.name),
"ALTER COLUMN with qualified column")
throw QueryCompilationErrors.unsupportedTableOperationError(
catalog, ident, "ALTER COLUMN with qualified column")
}
if (a.nullable.isDefined) {
throw QueryCompilationErrors.alterColumnWithV1TableCannotSpecifyNotNullError
}
if (a.position.isDefined) {
throw QueryCompilationErrors.operationOnlySupportedWithV2TableError(
Seq(catalog.name, ident.namespace()(0), ident.name),
"ALTER COLUMN ... FIRST | ALTER")
throw QueryCompilationErrors.unsupportedTableOperationError(
catalog, ident, "ALTER COLUMN ... FIRST | ALTER")
}
val builder = new MetadataBuilder
// Add comment to metadata
Expand All @@ -105,14 +100,10 @@ class ResolveSessionCatalog(val catalogManager: CatalogManager)
AlterTableChangeColumnCommand(table.catalogTable.identifier, colName, newColumn)

case RenameColumn(ResolvedV1TableIdentifier(ident), _, _) =>
throw QueryCompilationErrors.operationOnlySupportedWithV2TableError(
Seq(ident.catalog.get, ident.database.get, ident.table),
"RENAME COLUMN")
throw QueryCompilationErrors.unsupportedTableOperationError(ident, "RENAME COLUMN")

case DropColumns(ResolvedV1TableIdentifier(ident), _, _) =>
throw QueryCompilationErrors.operationOnlySupportedWithV2TableError(
Seq(ident.catalog.get, ident.database.get, ident.table),
"DROP COLUMN")
throw QueryCompilationErrors.unsupportedTableOperationError(ident, "DROP COLUMN")

case SetTableProperties(ResolvedV1TableIdentifier(ident), props) =>
AlterTableSetPropertiesCommand(ident, props, isView = false)
Expand Down Expand Up @@ -204,19 +195,17 @@ class ResolveSessionCatalog(val catalogManager: CatalogManager)
case c @ ReplaceTable(ResolvedV1Identifier(ident), _, _, _, _) =>
val provider = c.tableSpec.provider.getOrElse(conf.defaultDataSourceName)
if (!isV2Provider(provider)) {
throw QueryCompilationErrors.operationOnlySupportedWithV2TableError(
Seq(ident.catalog.get, ident.database.get, ident.table),
"REPLACE TABLE")
throw QueryCompilationErrors.unsupportedTableOperationError(
ident, "REPLACE TABLE")
} else {
c
}

case c @ ReplaceTableAsSelect(ResolvedV1Identifier(ident), _, _, _, _, _, _) =>
val provider = c.tableSpec.provider.getOrElse(conf.defaultDataSourceName)
if (!isV2Provider(provider)) {
throw QueryCompilationErrors.operationOnlySupportedWithV2TableError(
Seq(ident.catalog.get, ident.database.get, ident.table),
"REPLACE TABLE AS SELECT")
throw QueryCompilationErrors.unsupportedTableOperationError(
ident, "REPLACE TABLE AS SELECT")
} else {
c
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -289,8 +289,11 @@ case class AlterTableAddColumnsCommand(
sparkSession: SparkSession, tableProvider: Option[String]): Seq[StructField] = {
colsToAdd.map { col: StructField =>
if (col.metadata.contains(CURRENT_DEFAULT_COLUMN_METADATA_KEY)) {
val schema = StructType(Array(col))
ResolveDefaultColumns.validateTableProviderForDefaultValue(
schema, tableProvider, "ALTER TABLE ADD COLUMNS", true)
val foldedStructType = ResolveDefaultColumns.constantFoldCurrentDefaultsToExistDefaults(
StructType(Array(col)), tableProvider, "ALTER TABLE ADD COLUMNS", true)
schema, "ALTER TABLE ADD COLUMNS")
foldedStructType.fields(0)
} else {
col
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,14 +133,15 @@ case class DataSourceAnalysis(analyzer: Analyzer) extends Rule[LogicalPlan] {

override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
case CreateTable(tableDesc, mode, None) if DDLUtils.isDatasourceTable(tableDesc) =>
ResolveDefaultColumns.validateTableProviderForDefaultValue(
tableDesc.schema, tableDesc.provider, "CREATE TABLE", false)
val newSchema: StructType =
ResolveDefaultColumns.constantFoldCurrentDefaultsToExistDefaults(
tableDesc.schema, tableDesc.provider, "CREATE TABLE", false)
tableDesc.schema, "CREATE TABLE")

if (GeneratedColumn.hasGeneratedColumns(newSchema)) {
throw QueryCompilationErrors.generatedColumnsUnsupported(
Seq(tableDesc.identifier.catalog.get, tableDesc.identifier.database.get,
tableDesc.identifier.table))
throw QueryCompilationErrors.unsupportedTableOperationError(
tableDesc.identifier, "generated columns")
}

val newTableDesc = tableDesc.copy(schema = newSchema)
Expand Down
Loading