diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index 11d682940239..32045ff5a521 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -251,6 +251,8 @@ case class Except(
 
   override protected lazy val validConstraints: ExpressionSet = leftConstraints
 
+  override def maxRows: Option[Long] = left.maxRows
+
   override protected def withNewChildrenInternal(
     newLeft: LogicalPlan, newRight: LogicalPlan): Except = copy(left = newLeft, right = newRight)
 }
@@ -758,6 +760,9 @@ case class Sort(
     child: LogicalPlan) extends UnaryNode {
   override def output: Seq[Attribute] = child.output
   override def maxRows: Option[Long] = child.maxRows
+  override def maxRowsPerPartition: Option[Long] = {
+    if (global) maxRows else child.maxRowsPerPartition
+  }
   override def outputOrdering: Seq[SortOrder] = order
   final override val nodePatterns: Seq[TreePattern] = Seq(SORT)
   override protected def withNewChildInternal(newChild: LogicalPlan): Sort = copy(child = newChild)
@@ -1163,6 +1168,19 @@ case class Expand(
   override lazy val references: AttributeSet =
     AttributeSet(projections.flatten.flatMap(_.references))
 
+  override def maxRows: Option[Long] = child.maxRows match {
+    case Some(m) =>
+      val n = BigInt(m) * projections.length
+      if (n.isValidLong) Some(n.toLong) else None
+    case _ => None
+  }
+  override def maxRowsPerPartition: Option[Long] = child.maxRowsPerPartition match {
+    case Some(m) =>
+      val n = BigInt(m) * projections.length
+      if (n.isValidLong) Some(n.toLong) else None
+    case _ => maxRows
+  }
+
   override def metadataOutput: Seq[Attribute] = Nil
 
   override def producedAttributes: AttributeSet = AttributeSet(output diff child.output)
@@ -1432,11 +1450,15 @@ case class Sample(
       s"Sampling fraction ($fraction) must be on interval [0, 1] without replacement")
   }
 
+  // when withReplacement is true, PoissonSampler is applied in SampleExec,
+  // which may output more rows than child.
   override def maxRows: Option[Long] = {
-    // when withReplacement is true, PoissonSampler is applied in SampleExec,
-    // which may output more rows than child.maxRows.
     if (withReplacement) None else child.maxRows
   }
+  override def maxRowsPerPartition: Option[Long] = {
+    if (withReplacement) None else child.maxRowsPerPartition
+  }
+
   override def output: Seq[Attribute] = child.output
 
   override protected def withNewChildInternal(newChild: LogicalPlan): Sample =
@@ -1626,6 +1648,8 @@ case class CollectMetrics(
     name.nonEmpty && metrics.nonEmpty && metrics.forall(_.resolved) && childrenResolved
   }
 
+  override def maxRows: Option[Long] = child.maxRows
+  override def maxRowsPerPartition: Option[Long] = child.maxRowsPerPartition
   override def output: Seq[Attribute] = child.output
 
   override protected def withNewChildInternal(newChild: LogicalPlan): CollectMetrics =
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala
index 5dac35a33a6b..1d533e9d0d41 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala
@@ -113,4 +113,30 @@ class LogicalPlanSuite extends SparkFunSuite {
     assert(query.maxRows.isEmpty)
     assert(query.maxRowsPerPartition.isEmpty)
   }
+
+  test("SPARK-37961: add maxRows/maxRowsPerPartition for some logical nodes") {
+    val range = Range(0, 100, 1, 3)
+    assert(range.maxRows === Some(100))
+    assert(range.maxRowsPerPartition === Some(34))
+
+    val sort = Sort(Seq('id.asc), false, range)
+    assert(sort.maxRows === Some(100))
+    assert(sort.maxRowsPerPartition === Some(34))
+    val sort2 = Sort(Seq('id.asc), true, range)
+    assert(sort2.maxRows === Some(100))
+    assert(sort2.maxRowsPerPartition === Some(100))
+
+    val c1 = Literal(1).as('a).toAttribute.newInstance().withNullability(true)
+    val c2 = Literal(2).as('b).toAttribute.newInstance().withNullability(true)
+    val expand = Expand(
+      Seq(Seq(Literal(null), 'b), Seq('a, Literal(null))),
+      Seq(c1, c2),
+      sort.select('id as 'a, 'id + 1 as 'b))
+    assert(expand.maxRows === Some(200))
+    assert(expand.maxRowsPerPartition === Some(68))
+
+    val sample = Sample(0.1, 0.9, false, 42, expand)
+    assert(sample.maxRows === Some(200))
+    assert(sample.maxRowsPerPartition === Some(68))
+  }
 }
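To make the bound propagation above easier to follow, here is a minimal, self-contained Scala sketch, separate from the patch itself: every name in it is made up for illustration, and only the arithmetic mirrors the new overrides. It walks the same Range -> Sort -> Expand -> Sample chain as the LogicalPlanSuite test and reproduces its expected numbers.

```scala
// Standalone sketch, independent of the Spark codebase: models only the
// bound arithmetic introduced by this patch. All names are illustrative;
// the numbers match the SPARK-37961 test case.
object MaxRowsBoundsSketch {

  // Expand emits one output row per projection for each input row; the patch
  // widens the multiplication to BigInt so a Long overflow yields None
  // ("unknown bound") rather than a wrong bound.
  def expandBound(childBound: Option[Long], numProjections: Int): Option[Long] =
    childBound.flatMap { m =>
      val n = BigInt(m) * numProjections
      if (n.isValidLong) Some(n.toLong) else None
    }

  // Sampling without replacement can only drop rows, so the child's bound is
  // preserved; with replacement (PoissonSampler) a row may repeat, so no bound.
  def sampleBound(withReplacement: Boolean, childBound: Option[Long]): Option[Long] =
    if (withReplacement) None else childBound

  def main(args: Array[String]): Unit = {
    // Range(0, 100, 1, 3): 100 rows over 3 slices, ceil(100 / 3) = 34 per slice.
    val rangeMax: Option[Long] = Some(100L)
    val rangePerPartition: Option[Long] = Some(34L)

    // A global sort can, under worst-case skew, place every row in a single
    // range partition, so its per-partition bound collapses to the global
    // maxRows; a local sort keeps the child's per-partition bound.
    val globalSortPerPartition = rangeMax          // Some(100)
    val localSortPerPartition = rangePerPartition  // Some(34)
    assert(globalSortPerPartition.contains(100L))
    assert(localSortPerPartition.contains(34L))

    // Two projections double both bounds: 100 -> 200 and 34 -> 68, matching
    // expand.maxRows / expand.maxRowsPerPartition in the test.
    assert(expandBound(rangeMax, 2).contains(200L))
    assert(expandBound(localSortPerPartition, 2).contains(68L))

    // Bernoulli sampling keeps both bounds; Poisson sampling discards them.
    assert(sampleBound(withReplacement = false, Some(200L)).contains(200L))
    assert(sampleBound(withReplacement = true, Some(200L)).isEmpty)

    // Overflow: Long.MaxValue * 2 does not fit in a Long, so the bound is None.
    assert(expandBound(Some(Long.MaxValue), 2).isEmpty)

    println("bounds compose as in the SPARK-37961 test: maxRows=200, perPartition=68")
  }
}
```

The design choice worth noting is that these methods return an upper bound or None, never an estimate: overriding them only tightens what the optimizer may assume, which is why Sample must answer None under replacement even though it usually shrinks its input.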