diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ResolveDefaultColumnsUtil.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ResolveDefaultColumnsUtil.scala index f55fa2d8f5e8..0d947258e655 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ResolveDefaultColumnsUtil.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ResolveDefaultColumnsUtil.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, InMemoryCatalog, SessionCatalog} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.{Literal => ExprLiteral} -import org.apache.spark.sql.catalyst.optimizer.ConstantFolding +import org.apache.spark.sql.catalyst.optimizer.{ConstantFolding, Optimizer} import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.trees.TreePattern.PLAN_EXPRESSION @@ -285,7 +285,9 @@ object ResolveDefaultColumns extends QueryErrorsBase with ResolveDefaultColumnsU val analyzer: Analyzer = DefaultColumnAnalyzer val analyzed = analyzer.execute(Project(Seq(Alias(parsed, colName)()), OneRowRelation())) analyzer.checkAnalysis(analyzed) - ConstantFolding(analyzed) + // Eagerly execute finish-analysis and constant-folding rules before checking whether the + // expression is foldable and resolved. + ConstantFolding(DefaultColumnOptimizer.FinishAnalysis(analyzed)) } catch { case ex: AnalysisException => throw QueryCompilationErrors.defaultValuesUnresolvedExprError( @@ -452,6 +454,11 @@ object ResolveDefaultColumns extends QueryErrorsBase with ResolveDefaultColumnsU new CatalogManager(BuiltInFunctionCatalog, BuiltInFunctionCatalog.v1Catalog)) { } + /** + * This is an Optimizer for converting default column expressions to foldable literals.
+ */ + object DefaultColumnOptimizer extends Optimizer(DefaultColumnAnalyzer.catalogManager) + /** * This is a FunctionCatalog for performing analysis using built-in functions only. It is a helper * for the DefaultColumnAnalyzer above. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ResolveDefaultColumnsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ResolveDefaultColumnsSuite.scala index 00529559a485..79b2f517b060 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ResolveDefaultColumnsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ResolveDefaultColumnsSuite.scala @@ -215,4 +215,25 @@ class ResolveDefaultColumnsSuite extends QueryTest with SharedSparkSession { } } } + + test("SPARK-49054: Create table with current_user() default") { + val tableName = "test_current_user" + val user = spark.sparkContext.sparkUser + withTable(tableName) { + sql(s"CREATE TABLE $tableName(i int, s string default current_user()) USING parquet") + sql(s"INSERT INTO $tableName (i) VALUES ((0))") + checkAnswer(sql(s"SELECT * FROM $tableName"), Seq(Row(0, user))) + } + } + + test("SPARK-49054: Alter table with current_user() default") { + val tableName = "test_current_user" + val user = spark.sparkContext.sparkUser + withTable(tableName) { + sql(s"CREATE TABLE $tableName(i int, s string) USING parquet") + sql(s"ALTER TABLE $tableName ALTER COLUMN s SET DEFAULT current_user()") + sql(s"INSERT INTO $tableName (i) VALUES ((0))") + checkAnswer(sql(s"SELECT * FROM $tableName"), Seq(Row(0, user))) + } + } }