diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/FormatTableStatistics.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/FormatTableStatistics.scala
new file mode 100644
index 000000000000..cfe3185bacf9
--- /dev/null
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/FormatTableStatistics.scala
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.paimon.spark
+
+import org.apache.paimon.table.format.FormatDataSplit
+import org.apache.paimon.types.RowType
+
+import org.apache.spark.sql.connector.read.Statistics
+
+import java.util.OptionalLong
+
+import scala.collection.JavaConverters._
+
+case class FormatTableStatistics[T <: PaimonFormatTableBaseScan](scan: T) extends Statistics {
+
+  // Total size in bytes of all files covered by this scan.
+  private lazy val fileTotalSize: Long =
+    scan.getOriginSplits.map(_.asInstanceOf[FormatDataSplit]).map(_.length()).sum
+
+  override def sizeInBytes(): OptionalLong = {
+    // Scale the on-disk size by the ratio of the estimated row size of the
+    // pruned read schema to that of the full table schema.
+    val size = fileTotalSize /
+      estimateRowSize(scan.tableRowType) *
+      estimateRowSize(scan.readTableRowType)
+    OptionalLong.of(size)
+  }
+
+  private def estimateRowSize(rowType: RowType): Long = {
+    rowType.getFields.asScala.map(_.`type`().defaultSize().toLong).sum
+  }
+
+  override def numRows(): OptionalLong = OptionalLong.empty()
+}
diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonFormatTableBaseScan.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonFormatTableBaseScan.scala
index b04e21c99af4..e47b8aedc2bf 100644
--- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonFormatTableBaseScan.scala
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonFormatTableBaseScan.scala
@@ -24,7 +24,7 @@ import org.apache.paimon.table.FormatTable
 import org.apache.paimon.table.source.Split
 
 import org.apache.spark.sql.connector.metric.{CustomMetric, CustomTaskMetric}
-import org.apache.spark.sql.connector.read.Batch
+import org.apache.spark.sql.connector.read.{Batch, Statistics, SupportsReportStatistics}
 import org.apache.spark.sql.types.StructType
 
 import scala.collection.JavaConverters._
@@ -36,6 +36,7 @@ abstract class PaimonFormatTableBaseScan(
     filters: Seq[Predicate],
     pushDownLimit: Option[Int])
   extends ColumnPruningAndPushDown
+  with SupportsReportStatistics
   with ScanHelper {
 
   override val coreOptions: CoreOptions = CoreOptions.fromMap(table.options())
@@ -65,6 +66,10 @@ abstract class PaimonFormatTableBaseScan(
     PaimonBatch(lazyInputPartitions, readBuilder, coreOptions.blobAsDescriptor(), metadataColumns)
   }
 
+  override def estimateStatistics(): Statistics = {
+    FormatTableStatistics(this)
+  }
+
   override def supportedCustomMetrics: Array[CustomMetric] = {
     Array(
       PaimonNumSplitMetric(),
diff --git a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/FormatTableTestBase.scala b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/FormatTableTestBase.scala
index cb469f4da0b3..23b2873817d2 100644
--- a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/FormatTableTestBase.scala
+++ b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/FormatTableTestBase.scala
@@ -155,4 +155,15 @@ abstract class FormatTableTestBase extends PaimonHiveTestBase {
       assert(row.toString().contains("'csv.field-delimiter' = ';'"))
     }
   }
+
+  test("Format table: broadcast join for small table") {
+    withTable("t1", "t2") {
+      sql("CREATE TABLE t1 (f0 INT, f1 INT) USING CSV TBLPROPERTIES ('file.compression'='none')")
+      sql("CREATE TABLE t2 (f0 INT, f2 INT) USING CSV TBLPROPERTIES ('file.compression'='none')")
+      sql("INSERT INTO t1 VALUES (1, 1)")
+      sql("INSERT INTO t2 VALUES (1, 1)")
+      val df = sql("SELECT t1.f0, t1.f1, t2.f2 FROM t1, t2 WHERE t1.f0 = t2.f0")
+      assert(df.queryExecution.executedPlan.toString().contains("BroadcastExchange"))
+    }
+  }
 }