diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/expressions/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/expressions/expressions.scala
index e3eab6f6730f1..996b2566eeb7b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/expressions/expressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/expressions/expressions.scala
@@ -351,6 +351,10 @@ private[sql] object FieldReference {
   def apply(column: String): NamedReference = {
     LogicalExpressions.parseReference(column)
   }
+
+  def column(name: String) : NamedReference = {
+    FieldReference(Seq(name))
+  }
 }
 
 private[sql] final case class SortValue(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
index 564937feb608d..e734de32d232f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
@@ -706,41 +706,45 @@ object DataSourceStrategy
     if (agg.filter.isEmpty) {
       agg.aggregateFunction match {
         case aggregate.Min(PushableColumnWithoutNestedColumn(name)) =>
-          Some(new Min(FieldReference(name)))
+          Some(new Min(FieldReference.column(name)))
         case aggregate.Max(PushableColumnWithoutNestedColumn(name)) =>
-          Some(new Max(FieldReference(name)))
+          Some(new Max(FieldReference.column(name)))
         case count: aggregate.Count if count.children.length == 1 =>
           count.children.head match {
             // COUNT(any literal) is the same as COUNT(*)
             case Literal(_, _) => Some(new CountStar())
             case PushableColumnWithoutNestedColumn(name) =>
-              Some(new Count(FieldReference(name), agg.isDistinct))
+              Some(new Count(FieldReference.column(name), agg.isDistinct))
             case _ => None
           }
         case aggregate.Sum(PushableColumnWithoutNestedColumn(name), _) =>
-          Some(new Sum(FieldReference(name), agg.isDistinct))
+          Some(new Sum(FieldReference.column(name), agg.isDistinct))
         case aggregate.Average(PushableColumnWithoutNestedColumn(name), _) =>
-          Some(new GeneralAggregateFunc("AVG", agg.isDistinct, Array(FieldReference(name))))
+          Some(new GeneralAggregateFunc("AVG", agg.isDistinct, Array(FieldReference.column(name))))
         case aggregate.VariancePop(PushableColumnWithoutNestedColumn(name), _) =>
-          Some(new GeneralAggregateFunc("VAR_POP", agg.isDistinct, Array(FieldReference(name))))
+          Some(new GeneralAggregateFunc(
+            "VAR_POP", agg.isDistinct, Array(FieldReference.column(name))))
         case aggregate.VarianceSamp(PushableColumnWithoutNestedColumn(name), _) =>
-          Some(new GeneralAggregateFunc("VAR_SAMP", agg.isDistinct, Array(FieldReference(name))))
+          Some(new GeneralAggregateFunc(
+            "VAR_SAMP", agg.isDistinct, Array(FieldReference.column(name))))
         case aggregate.StddevPop(PushableColumnWithoutNestedColumn(name), _) =>
-          Some(new GeneralAggregateFunc("STDDEV_POP", agg.isDistinct, Array(FieldReference(name))))
+          Some(new GeneralAggregateFunc(
+            "STDDEV_POP", agg.isDistinct, Array(FieldReference.column(name))))
         case aggregate.StddevSamp(PushableColumnWithoutNestedColumn(name), _) =>
-          Some(new GeneralAggregateFunc("STDDEV_SAMP", agg.isDistinct, Array(FieldReference(name))))
+          Some(new GeneralAggregateFunc(
+            "STDDEV_SAMP", agg.isDistinct, Array(FieldReference.column(name))))
         case aggregate.CovPopulation(PushableColumnWithoutNestedColumn(left),
           PushableColumnWithoutNestedColumn(right), _) =>
           Some(new GeneralAggregateFunc("COVAR_POP", agg.isDistinct,
-            Array(FieldReference(left), FieldReference(right))))
+            Array(FieldReference.column(left), FieldReference.column(right))))
         case aggregate.CovSample(PushableColumnWithoutNestedColumn(left),
           PushableColumnWithoutNestedColumn(right), _) =>
           Some(new GeneralAggregateFunc("COVAR_SAMP", agg.isDistinct,
-            Array(FieldReference(left), FieldReference(right))))
+            Array(FieldReference.column(left), FieldReference.column(right))))
         case aggregate.Corr(PushableColumnWithoutNestedColumn(left),
           PushableColumnWithoutNestedColumn(right), _) =>
           Some(new GeneralAggregateFunc("CORR", agg.isDistinct,
-            Array(FieldReference(left), FieldReference(right))))
+            Array(FieldReference.column(left), FieldReference.column(right))))
         case _ => None
       }
     } else {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala
index 39374b6924820..db7b3dc7248f3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala
@@ -118,7 +118,7 @@ object PushDownUtils extends PredicateHelper {
 
     def columnAsString(e: Expression): Option[FieldReference] = e match {
       case PushableColumnWithoutNestedColumn(name) =>
-        Some(FieldReference(name).asInstanceOf[FieldReference])
+        Some(FieldReference.column(name).asInstanceOf[FieldReference])
       case _ => None
     }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
index 3ae9d0322c6ee..8808321323602 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
@@ -80,6 +80,17 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
         .executeUpdate()
       conn.prepareStatement("INSERT INTO \"test\".\"employee\" VALUES (6, 'jen', 12000, 1200)")
         .executeUpdate()
+      conn.prepareStatement(
+        "CREATE TABLE \"test\".\"dept\" (\"dept id\" INTEGER NOT NULL)").executeUpdate()
+      conn.prepareStatement("INSERT INTO \"test\".\"dept\" VALUES (1)").executeUpdate()
+      conn.prepareStatement("INSERT INTO \"test\".\"dept\" VALUES (2)").executeUpdate()
+
+      // scalastyle:off
+      conn.prepareStatement(
+        "CREATE TABLE \"test\".\"person\" (\"名\" INTEGER NOT NULL)").executeUpdate()
+      // scalastyle:on
+      conn.prepareStatement("INSERT INTO \"test\".\"person\" VALUES (1)").executeUpdate()
+      conn.prepareStatement("INSERT INTO \"test\".\"person\" VALUES (2)").executeUpdate()
     }
   }
 
@@ -305,7 +316,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
   test("show tables") {
     checkAnswer(sql("SHOW TABLES IN h2.test"),
       Seq(Row("test", "people", false), Row("test", "empty_table", false),
-        Row("test", "employee", false)))
+        Row("test", "employee", false), Row("test", "dept", false), Row("test", "person", false)))
   }
 
   test("SQL API: create table as select") {
@@ -831,4 +842,32 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
     checkAnswer(df, Seq(Row("alex", 1), Row("amy", 1), Row("cathy", 1),
       Row("david", 1), Row("jen", 1)))
   }
+
+  test("column name with composite field") {
+    checkAnswer(sql("SELECT `dept id` FROM h2.test.dept"), Seq(Row(1), Row(2)))
+    val df = sql("SELECT COUNT(`dept id`) FROM h2.test.dept")
+    checkAggregateRemoved(df)
+    df.queryExecution.optimizedPlan.collect {
+      case _: DataSourceV2ScanRelation =>
+        val expected_plan_fragment =
+          "PushedAggregates: [COUNT(`dept id`)]"
+        checkKeywordsExistsInExplain(df, expected_plan_fragment)
+    }
+    checkAnswer(df, Seq(Row(2)))
+  }
+
+  test("column name with non-ascii") {
+    // scalastyle:off
+    checkAnswer(sql("SELECT `名` FROM h2.test.person"), Seq(Row(1), Row(2)))
+    val df = sql("SELECT COUNT(`名`) FROM h2.test.person")
+    checkAggregateRemoved(df)
+    df.queryExecution.optimizedPlan.collect {
+      case _: DataSourceV2ScanRelation =>
+        val expected_plan_fragment =
+          "PushedAggregates: [COUNT(`名`)]"
+        checkKeywordsExistsInExplain(df, expected_plan_fragment)
+    }
+    checkAnswer(df, Seq(Row(2)))
+    // scalastyle:on
+  }
 }