diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ResolveDefaultColumnsUtil.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ResolveDefaultColumnsUtil.scala
index 6b4f29bea757..8b7392e71249 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ResolveDefaultColumnsUtil.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ResolveDefaultColumnsUtil.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.analysis._
 import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, InMemoryCatalog, SessionCatalog}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.{Literal => ExprLiteral}
-import org.apache.spark.sql.catalyst.optimizer.{ConstantFolding, ReplaceExpressions}
+import org.apache.spark.sql.catalyst.optimizer.{ConstantFolding, Optimizer}
 import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException}
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.trees.TreePattern.PLAN_EXPRESSION
@@ -290,7 +290,9 @@ object ResolveDefaultColumns extends QueryErrorsBase
       val analyzer: Analyzer = DefaultColumnAnalyzer
       val analyzed = analyzer.execute(Project(Seq(Alias(parsed, colName)()), OneRowRelation()))
       analyzer.checkAnalysis(analyzed)
-      ConstantFolding(ReplaceExpressions(analyzed))
+      // Eagerly execute finish-analysis and constant-folding rules before checking whether the
+      // expression is foldable and resolved.
+      ConstantFolding(DefaultColumnOptimizer.FinishAnalysis(analyzed))
     } catch {
       case ex: AnalysisException =>
         throw QueryCompilationErrors.defaultValuesUnresolvedExprError(
@@ -517,6 +519,11 @@ object ResolveDefaultColumns extends QueryErrorsBase
     new CatalogManager(BuiltInFunctionCatalog, BuiltInFunctionCatalog.v1Catalog)) {
   }
 
+  /**
+   * This is an Optimizer for converting default column expressions to foldable literals.
+   */
+  object DefaultColumnOptimizer extends Optimizer(DefaultColumnAnalyzer.catalogManager)
+
   /**
    * This is a FunctionCatalog for performing analysis using built-in functions only. It is a helper
    * for the DefaultColumnAnalyzer above.
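
For context: ReplaceExpressions alone does not rewrite current_* expressions. The Optimizer's FinishAnalysis stage additionally applies rules such as ComputeCurrentTime and ReplaceCurrentLike(catalogManager), which fold current_user(), current_database(), and current_catalog() into literals, after which ConstantFolding can reduce the default expression. Below is a minimal, illustrative sketch of the resulting end-to-end behavior (not part of the patch; the table name t and the local master are assumptions):

import org.apache.spark.sql.SparkSession

// Illustrative sketch only: with this change, current_user() is folded to a
// string literal when the column default is analyzed, so the DDL below succeeds
// instead of being rejected as an unresolvable/non-foldable default expression.
object CurrentUserDefaultExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").getOrCreate()
    spark.sql("CREATE TABLE t(i INT, s STRING DEFAULT current_user()) USING parquet")
    spark.sql("INSERT INTO t(i) VALUES (0)")
    spark.sql("SELECT * FROM t").show() // expected row: (0, <submitting user>)
    spark.stop()
  }
}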
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ResolveDefaultColumnsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ResolveDefaultColumnsSuite.scala
index bca147279993..e3ebbadbb829 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ResolveDefaultColumnsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ResolveDefaultColumnsSuite.scala
@@ -287,4 +287,25 @@ class ResolveDefaultColumnsSuite extends QueryTest with SharedSparkSession {
       checkAnswer(sql("select v from t"), sql("select parse_json('1')").collect())
     }
   }
+
+  test("SPARK-49054: Create table with current_user() default") {
+    val tableName = "test_current_user"
+    val user = spark.sparkContext.sparkUser
+    withTable(tableName) {
+      sql(s"CREATE TABLE $tableName(i int, s string default current_user()) USING parquet")
+      sql(s"INSERT INTO $tableName (i) VALUES ((0))")
+      checkAnswer(sql(s"SELECT * FROM $tableName"), Seq(Row(0, user)))
+    }
+  }
+
+  test("SPARK-49054: Alter table with current_user() default") {
+    val tableName = "test_current_user"
+    val user = spark.sparkContext.sparkUser
+    withTable(tableName) {
+      sql(s"CREATE TABLE $tableName(i int, s string) USING parquet")
+      sql(s"ALTER TABLE $tableName ALTER COLUMN s SET DEFAULT current_user()")
+      sql(s"INSERT INTO $tableName (i) VALUES ((0))")
+      checkAnswer(sql(s"SELECT * FROM $tableName"), Seq(Row(0, user)))
+    }
+  }
 }