diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
index d99c170fae579..03b42a760eac0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
@@ -29,6 +29,8 @@ import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.{CannotReplaceMissingTableException, NoSuchDatabaseException, NoSuchNamespaceException, TableAlreadyExistsException}
 import org.apache.spark.sql.catalyst.parser.ParseException
+import org.apache.spark.sql.catalyst.plans.logical.ColumnStat
+import org.apache.spark.sql.catalyst.statsEstimation.StatsEstimationTestBase
 import org.apache.spark.sql.catalyst.util.{DateTimeUtils, ResolveDefaultColumns}
 import org.apache.spark.sql.connector.catalog._
 import org.apache.spark.sql.connector.catalog.CatalogManager.SESSION_CATALOG_NAME
@@ -47,7 +49,7 @@ import org.apache.spark.unsafe.types.UTF8String
 abstract class DataSourceV2SQLSuite
   extends InsertIntoTests(supportsDynamicOverwrite = true, includeSQLOnlyTests = true)
-  with DeleteFromTests with DatasourceV2SQLBase {
+  with DeleteFromTests with DatasourceV2SQLBase with StatsEstimationTestBase {
 
   protected val v2Source = classOf[FakeV2Provider].getName
   override protected val v2Format = v2Source
 
@@ -2779,17 +2781,16 @@ class DataSourceV2SQLSuiteV1Filter
       " (4, null), (5, 'test5')")
 
     val df = spark.sql("select * from testcat.test")
+    val expectedColumnStats = Seq(
+      "id" -> ColumnStat(Some(5), None, None, Some(0), None, None, None, 2),
+      "data" -> ColumnStat(Some(3), None, None, Some(3), None, None, None, 2))
     df.queryExecution.optimizedPlan.collect {
       case scan: DataSourceV2ScanRelation =>
         val stats = scan.stats
         assert(stats.sizeInBytes == 200)
         assert(stats.rowCount.get == 5)
-        val colStats = stats.attributeStats.values.toArray
-        assert(colStats.length == 2)
-        assert(colStats(0).distinctCount.get == 3)
-        assert(colStats(0).nullCount.get == 3)
-        assert(colStats(1).distinctCount.get == 5)
-        assert(colStats(1).nullCount.get == 0)
+        assert(stats.attributeStats ==
+          toAttributeMap(expectedColumnStats, df.queryExecution.optimizedPlan))
     }
   }
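
For context, the rewritten assertion relies on the toAttributeMap helper inherited from StatsEstimationTestBase, which is why that trait is mixed into the suite in the first hunk. A minimal sketch of what such a helper looks like, assuming it resolves the expected (column name, ColumnStat) pairs against the optimized plan's output attributes:

    import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap}
    import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LogicalPlan}

    // Sketch of the helper the test calls (the real one lives in
    // StatsEstimationTestBase): turn (column name, ColumnStat) pairs into an
    // AttributeMap keyed by the plan's output attributes, so the expected
    // stats can be compared directly to Statistics.attributeStats.
    def toAttributeMap(
        colStats: Seq[(String, ColumnStat)],
        plan: LogicalPlan): AttributeMap[ColumnStat] = {
      val nameToAttr: Map[String, Attribute] = plan.output.map(a => a.name -> a).toMap
      AttributeMap(colStats.map { case (name, stat) => nameToAttr(name) -> stat })
    }

Comparing whole AttributeMap values this way ties each expected ColumnStat to its column by name, instead of depending on the iteration order of stats.attributeStats.values as the removed index-based assertions did.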