
Commit 487faf1

wangyum authored and cloud-fan committed
[SPARK-24117][SQL] Unified the getSizePerRow
## What changes were proposed in this pull request?

This PR unifies `getSizePerRow`, since the same per-row size computation is duplicated in many places. For example:

1. [LocalRelation.scala#L80](https://github.com/wangyum/spark/blob/f70f46d1e5bc503e9071707d837df618b7696d32/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala#L80)
2. [SizeInBytesOnlyStatsPlanVisitor.scala#L36](https://github.com/apache/spark/blob/76b8b840ddc951ee6203f9cccd2c2b9671c1b5e8/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala#L36)

## How was this patch tested?

Existing tests.

Author: Yuming Wang <[email protected]>

Closes #21189 from wangyum/SPARK-24117.
1 parent 2f6fe7d commit 487faf1

7 files changed: 21 additions & 19 deletions
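
As the commit message says, several call sites computed a per-row size by hand, and they did not even agree on whether to include the 8-byte row-object overhead. A minimal standalone sketch of the idea follows; a toy `Attr` type stands in for Catalyst's `Attribute`, and 4 and 20 are the `defaultSize` values of `IntegerType` and `StringType`. This is illustrative only, not the Spark sources.

```scala
// Standalone sketch only, not the Spark sources.
object GetSizePerRowSketch extends App {
  // Toy stand-in for Catalyst's Attribute: just a column name and a default size.
  final case class Attr(name: String, defaultSize: Int)

  val output = Seq(Attr("key", 4), Attr("value", 20))

  // Before: call sites spelled the computation out themselves, and inconsistently so
  // (LocalRelation omitted the 8-byte row overhead that the plan visitor added).
  val localRelationStyle = output.map(a => BigInt(a.defaultSize)).sum
  val planVisitorStyle = BigInt(output.map(_.defaultSize).sum + 8)

  // After: one shared helper, mirroring EstimationUtils.getSizePerRow
  // (the column-statistics branch is elided here).
  def getSizePerRow(attrs: Seq[Attr]): BigInt =
    BigInt(8) + attrs.map(a => BigInt(a.defaultSize)).sum

  println(s"$localRelationStyle $planVisitorStyle ${getSizePerRow(output)}") // 24 32 32
}
```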

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala
Lines changed: 2 additions & 1 deletion

@@ -21,6 +21,7 @@ import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.analysis
 import org.apache.spark.sql.catalyst.expressions.{Attribute, Literal}
+import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils
 import org.apache.spark.sql.types.{StructField, StructType}
 
 object LocalRelation {
@@ -77,7 +78,7 @@ case class LocalRelation(
   }
 
   override def computeStats(): Statistics =
-    Statistics(sizeInBytes = output.map(n => BigInt(n.dataType.defaultSize)).sum * data.length)
+    Statistics(sizeInBytes = EstimationUtils.getSizePerRow(output) * data.length)
 
   def toSQL(inlineTableName: String): String = {
     require(data.nonEmpty)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala
Lines changed: 8 additions & 6 deletions

@@ -17,15 +17,13 @@
 
 package org.apache.spark.sql.catalyst.plans.logical.statsEstimation
 
-import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
 import scala.math.BigDecimal.RoundingMode
 
 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap}
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.types.{DecimalType, _}
 
-
 object EstimationUtils {
 
   /** Check if each plan has rowCount in its statistics. */
@@ -73,13 +71,12 @@ object EstimationUtils {
     AttributeMap(output.flatMap(a => inputMap.get(a).map(a -> _)))
   }
 
-  def getOutputSize(
+  def getSizePerRow(
       attributes: Seq[Attribute],
-      outputRowCount: BigInt,
       attrStats: AttributeMap[ColumnStat] = AttributeMap(Nil)): BigInt = {
     // We assign a generic overhead for a Row object, the actual overhead is different for different
     // Row format.
-    val sizePerRow = 8 + attributes.map { attr =>
+    8 + attributes.map { attr =>
       if (attrStats.get(attr).map(_.avgLen.isDefined).getOrElse(false)) {
         attr.dataType match {
           case StringType =>
@@ -92,10 +89,15 @@ object EstimationUtils {
         attr.dataType.defaultSize
       }
     }.sum
+  }
 
+  def getOutputSize(
+      attributes: Seq[Attribute],
+      outputRowCount: BigInt,
+      attrStats: AttributeMap[ColumnStat] = AttributeMap(Nil)): BigInt = {
     // Output size can't be zero, or sizeInBytes of BinaryNode will also be zero
     // (simple computation of statistics returns product of children).
-    if (outputRowCount > 0) outputRowCount * sizePerRow else 1
+    if (outputRowCount > 0) outputRowCount * getSizePerRow(attributes, attrStats) else 1
   }
 
   /**
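
A standalone mirror of the refactored shape may help: `getOutputSize` now simply multiplies the row count by `getSizePerRow` and floors the result at 1. This is a simplified sketch under assumptions (a toy `Attr` type instead of Catalyst's `Attribute`, no column statistics), not the actual Catalyst code.

```scala
// Simplified, self-contained mirror of the refactored helpers. Assumptions: a toy
// Attr type replaces Catalyst's Attribute and the ColumnStat/avgLen branch is dropped.
object EstimationUtilsSketch extends App {
  final case class Attr(name: String, defaultSize: Int)

  // Per-row size: a generic 8-byte Row overhead plus each column's size.
  def getSizePerRow(attributes: Seq[Attr]): BigInt =
    BigInt(8) + attributes.map(a => BigInt(a.defaultSize)).sum

  // Output size delegates to getSizePerRow and is floored at 1, so that a zero-row
  // estimate cannot zero out a BinaryNode's product-of-children size.
  def getOutputSize(attributes: Seq[Attr], outputRowCount: BigInt): BigInt =
    if (outputRowCount > 0) outputRowCount * getSizePerRow(attributes) else 1

  val cols = Seq(Attr("key", 4), Attr("value", 20))
  println(getOutputSize(cols, 4)) // 4 * (8 + 4 + 20) = 128
  println(getOutputSize(cols, 0)) // 1, never zero
}
```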

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala
Lines changed: 2 additions & 2 deletions

@@ -33,8 +33,8 @@ object SizeInBytesOnlyStatsPlanVisitor extends LogicalPlanVisitor[Statistics] {
   private def visitUnaryNode(p: UnaryNode): Statistics = {
     // There should be some overhead in Row object, the size should not be zero when there is
     // no columns, this help to prevent divide-by-zero error.
-    val childRowSize = p.child.output.map(_.dataType.defaultSize).sum + 8
-    val outputRowSize = p.output.map(_.dataType.defaultSize).sum + 8
+    val childRowSize = EstimationUtils.getSizePerRow(p.child.output)
+    val outputRowSize = EstimationUtils.getSizePerRow(p.output)
     // Assume there will be the same number of rows as child has.
     var sizeInBytes = (p.child.stats.sizeInBytes * outputRowSize) / childRowSize
     if (sizeInBytes == 0) {
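
To see what the visitor's proportional estimate does with the unified helper, here is a toy walk-through. All numbers are assumed for illustration: a 1,000-byte child whose rows are an int plus a string, projected down to just the int column.

```scala
// Toy numbers only; this mirrors the arithmetic in visitUnaryNode, not the real visitor.
object UnaryNodeEstimateSketch extends App {
  val childSizeInBytes = BigInt(1000)    // assumed p.child.stats.sizeInBytes
  val childRowSize = BigInt(8 + 4 + 20)  // getSizePerRow(p.child.output): an int and a string
  val outputRowSize = BigInt(8 + 4)      // getSizePerRow(p.output): only the int survives

  // Assume the node keeps the child's row count and scale by the row-size ratio.
  var sizeInBytes = childSizeInBytes * outputRowSize / childRowSize
  if (sizeInBytes == 0) {
    sizeInBytes = 1 // never report zero; parents multiply child sizes together
  }
  println(sizeInBytes) // 1000 * 12 / 32 = 375
}
```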

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
Lines changed: 4 additions & 6 deletions

@@ -24,23 +24,21 @@ import javax.annotation.concurrent.GuardedBy
 
 import scala.collection.JavaConverters._
 import scala.collection.mutable.{ArrayBuffer, ListBuffer}
-import scala.reflect.ClassTag
 import scala.util.control.NonFatal
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql._
-import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder}
+import org.apache.spark.sql.catalyst.encoders.encoderFor
 import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow}
 import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics}
+import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils
 import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
-import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger
 import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory, SupportsScanUnsafeRow}
 import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset => OffsetV2}
-import org.apache.spark.sql.streaming.{OutputMode, Trigger}
+import org.apache.spark.sql.streaming.OutputMode
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.Utils
 
-
 object MemoryStream {
   protected val currentBlockId = new AtomicInteger(0)
   protected val memoryStreamId = new AtomicInteger(0)
@@ -307,7 +305,7 @@
 case class MemoryPlan(sink: MemorySink, output: Seq[Attribute]) extends LeafNode {
   def this(sink: MemorySink) = this(sink, sink.schema.toAttributes)
 
-  private val sizePerRow = sink.schema.toAttributes.map(_.dataType.defaultSize).sum
+  private val sizePerRow = EstimationUtils.getSizePerRow(sink.schema.toAttributes)
 
  override def computeStats(): Statistics = Statistics(sizePerRow * sink.allData.size)
 }

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala
Lines changed: 2 additions & 1 deletion

@@ -27,6 +27,7 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics}
+import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils
 import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.{Append, Complete, Update}
 import org.apache.spark.sql.execution.streaming.{MemorySinkBase, Sink}
 import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, StreamWriteSupport}
@@ -182,7 +183,7 @@ class MemoryDataWriter(partition: Int, outputMode: OutputMode)
  * Used to query the data that has been written into a [[MemorySinkV2]].
  */
 case class MemoryPlanV2(sink: MemorySinkV2, override val output: Seq[Attribute]) extends LeafNode {
-  private val sizePerRow = output.map(_.dataType.defaultSize).sum
+  private val sizePerRow = EstimationUtils.getSizePerRow(output)
 
   override def computeStats(): Statistics = Statistics(sizePerRow * sink.allData.size)
 }

sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala
Lines changed: 1 addition & 1 deletion

@@ -50,7 +50,7 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared
       }
 
       assert(sizes.size === 1, s"number of Join nodes is wrong:\n ${df.queryExecution}")
-      assert(sizes.head === BigInt(96),
+      assert(sizes.head === BigInt(128),
        s"expected exact size 96 for table 'test', got: ${sizes.head}")
     }
   }
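
Why 96 becomes 128: assuming the `test` table here is four rows with an integer column and a string column (an assumption inferred from the old 96-byte expectation, not stated in this diff), the new per-row estimate simply gains the 8-byte row-object overhead.

```scala
// Hedged arithmetic check; the four-row (int, string) shape of 'test' is assumed.
object TestTableSizeSketch extends App {
  val rows = 4
  val oldPerRow = 4 + 20     // sum of IntegerType and StringType default sizes
  val newPerRow = 8 + 4 + 20 // getSizePerRow also adds the generic 8-byte row overhead
  println(rows * oldPerRow)  // 96, the old expectation
  println(rows * newPerRow)  // 128, the new expectation
}
```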

sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala
Lines changed: 2 additions & 2 deletions

@@ -220,11 +220,11 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter {
 
     sink.addBatch(0, 1 to 3)
     plan.invalidateStatsCache()
-    assert(plan.stats.sizeInBytes === 12)
+    assert(plan.stats.sizeInBytes === 36)
 
     sink.addBatch(1, 4 to 6)
     plan.invalidateStatsCache()
-    assert(plan.stats.sizeInBytes === 24)
+    assert(plan.stats.sizeInBytes === 72)
   }
 
   ignore("stress test") {
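
The new expectations follow directly from the unified helper: the old 12- and 24-byte expectations imply a single `IntegerType` column at 4 bytes per row, and `getSizePerRow` now adds the 8-byte row overhead on top of that.

```scala
// Arithmetic behind the updated assertions. A single IntegerType column is assumed,
// which is what the old 12 = 4 * 3 expectation already implied.
object MemorySinkSizeSketch extends App {
  val perRow = 8 + 4  // getSizePerRow for one int column
  println(perRow * 3) // 36 after addBatch(0, 1 to 3)
  println(perRow * 6) // 72 once addBatch(1, 4 to 6) brings the total to six rows
}
```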
