diff --git a/spark/src/main/java/org/apache/iceberg/spark/SparkSchemaUtil.java b/spark/src/main/java/org/apache/iceberg/spark/SparkSchemaUtil.java
index b503ba634d85..321050dceb74 100644
--- a/spark/src/main/java/org/apache/iceberg/spark/SparkSchemaUtil.java
+++ b/spark/src/main/java/org/apache/iceberg/spark/SparkSchemaUtil.java
@@ -30,6 +30,7 @@
 import org.apache.iceberg.relocated.com.google.common.base.Splitter;
 import org.apache.iceberg.relocated.com.google.common.collect.ImmutableSet;
 import org.apache.iceberg.relocated.com.google.common.collect.Lists;
+import org.apache.iceberg.relocated.com.google.common.math.LongMath;
 import org.apache.iceberg.types.Type;
 import org.apache.iceberg.types.TypeUtil;
 import org.apache.iceberg.types.Types;
@@ -37,7 +38,6 @@
 import org.apache.spark.sql.SparkSession;
 import org.apache.spark.sql.catalog.Column;
 import org.apache.spark.sql.types.DataType;
-import org.apache.spark.sql.types.StructField;
 import org.apache.spark.sql.types.StructType;
 
 /**
@@ -280,23 +280,23 @@ private static PartitionSpec identitySpec(Schema schema, List<String> partitionNames) {
   }
 
   /**
-   * estimate approximate table size based on spark schema and total records.
+   * Estimate approximate table size based on Spark schema and total records.
    *
-   * @param tableSchema spark schema
+   * @param tableSchema Spark schema
    * @param totalRecords total records in the table
-   * @return approxiate size based on table schema
+   * @return approximate size based on table schema
    */
   public static long estimateSize(StructType tableSchema, long totalRecords) {
     if (totalRecords == Long.MAX_VALUE) {
       return totalRecords;
     }
 
-    long approximateSize = 0;
-    for (StructField sparkField : tableSchema.fields()) {
-      approximateSize += sparkField.dataType().defaultSize();
+    long result;
+    try {
+      result = LongMath.checkedMultiply(tableSchema.defaultSize(), totalRecords);
+    } catch (ArithmeticException e) {
+      result = Long.MAX_VALUE;
     }
-
-    long result = approximateSize * totalRecords;
-    return result > 0 ? result : Long.MAX_VALUE;
+    return result;
   }
 }
diff --git a/spark3/src/main/java/org/apache/iceberg/spark/source/SparkBatchScan.java b/spark3/src/main/java/org/apache/iceberg/spark/source/SparkBatchScan.java
index acd02037aced..e9eda0b29394 100644
--- a/spark3/src/main/java/org/apache/iceberg/spark/source/SparkBatchScan.java
+++ b/spark3/src/main/java/org/apache/iceberg/spark/source/SparkBatchScan.java
@@ -207,22 +207,20 @@ public Statistics estimateStatistics() {
       LOG.debug("using table metadata to estimate table statistics");
       long totalRecords = PropertyUtil.propertyAsLong(table.currentSnapshot().summary(),
           SnapshotSummary.TOTAL_RECORDS_PROP, Long.MAX_VALUE);
-      Schema projectedSchema = expectedSchema != null ? expectedSchema : table.schema();
       return new Stats(
-          SparkSchemaUtil.estimateSize(SparkSchemaUtil.convert(projectedSchema), totalRecords),
+          SparkSchemaUtil.estimateSize(readSchema(), totalRecords),
           totalRecords);
     }
 
-    long sizeInBytes = 0L;
     long numRows = 0L;
 
     for (CombinedScanTask task : tasks()) {
       for (FileScanTask file : task.files()) {
-        sizeInBytes += file.length();
         numRows += file.file().recordCount();
       }
     }
 
+    long sizeInBytes = SparkSchemaUtil.estimateSize(readSchema(), numRows);
     return new Stats(sizeInBytes, numRows);
   }
 