Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,8 @@ case class Except(

override protected lazy val validConstraints: ExpressionSet = leftConstraints

override def maxRows: Option[Long] = left.maxRows

override protected def withNewChildrenInternal(
newLeft: LogicalPlan, newRight: LogicalPlan): Except = copy(left = newLeft, right = newRight)
}
Expand Down Expand Up @@ -758,6 +760,9 @@ case class Sort(
child: LogicalPlan) extends UnaryNode {
override def output: Seq[Attribute] = child.output
override def maxRows: Option[Long] = child.maxRows
override def maxRowsPerPartition: Option[Long] = {
if (global) maxRows else child.maxRowsPerPartition
}
override def outputOrdering: Seq[SortOrder] = order
final override val nodePatterns: Seq[TreePattern] = Seq(SORT)
override protected def withNewChildInternal(newChild: LogicalPlan): Sort = copy(child = newChild)
Expand Down Expand Up @@ -1163,6 +1168,19 @@ case class Expand(
override lazy val references: AttributeSet =
AttributeSet(projections.flatten.flatMap(_.references))

override def maxRows: Option[Long] = child.maxRows match {
case Some(m) =>
val n = BigInt(m) * projections.length
if (n.isValidLong) Some(n.toLong) else None
case _ => None
}
override def maxRowsPerPartition: Option[Long] = child.maxRowsPerPartition match {
case Some(m) =>
val n = BigInt(m) * projections.length
if (n.isValidLong) Some(n.toLong) else None
case _ => maxRows
}

override def metadataOutput: Seq[Attribute] = Nil

override def producedAttributes: AttributeSet = AttributeSet(output diff child.output)
Expand Down Expand Up @@ -1432,11 +1450,15 @@ case class Sample(
s"Sampling fraction ($fraction) must be on interval [0, 1] without replacement")
}

// when withReplacement is true, PoissonSampler is applied in SampleExec,
// which may output more rows than child.
override def maxRows: Option[Long] = {
// when withReplacement is true, PoissonSampler is applied in SampleExec,
// which may output more rows than child.maxRows.
if (withReplacement) None else child.maxRows
}
override def maxRowsPerPartition: Option[Long] = {
if (withReplacement) None else child.maxRowsPerPartition
}

override def output: Seq[Attribute] = child.output

override protected def withNewChildInternal(newChild: LogicalPlan): Sample =
Expand Down Expand Up @@ -1626,6 +1648,8 @@ case class CollectMetrics(
name.nonEmpty && metrics.nonEmpty && metrics.forall(_.resolved) && childrenResolved
}

override def maxRows: Option[Long] = child.maxRows
override def maxRowsPerPartition: Option[Long] = child.maxRowsPerPartition
override def output: Seq[Attribute] = child.output

override protected def withNewChildInternal(newChild: LogicalPlan): CollectMetrics =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,4 +113,30 @@ class LogicalPlanSuite extends SparkFunSuite {
assert(query.maxRows.isEmpty)
assert(query.maxRowsPerPartition.isEmpty)
}

test("SPARK-37961: add maxRows/maxRowsPerPartition for some logical nodes") {
val range = Range(0, 100, 1, 3)
assert(range.maxRows === Some(100))
assert(range.maxRowsPerPartition === Some(34))

val sort = Sort(Seq('id.asc), false, range)
assert(sort.maxRows === Some(100))
assert(sort.maxRowsPerPartition === Some(34))
val sort2 = Sort(Seq('id.asc), true, range)
assert(sort2.maxRows === Some(100))
assert(sort2.maxRowsPerPartition === Some(100))

val c1 = Literal(1).as('a).toAttribute.newInstance().withNullability(true)
val c2 = Literal(2).as('b).toAttribute.newInstance().withNullability(true)
val expand = Expand(
Seq(Seq(Literal(null), 'b), Seq('a, Literal(null))),
Seq(c1, c2),
sort.select('id as 'a, 'id + 1 as 'b))
assert(expand.maxRows === Some(200))
assert(expand.maxRowsPerPartition === Some(68))

val sample = Sample(0.1, 0.9, false, 42, expand)
assert(sample.maxRows === Some(200))
assert(sample.maxRowsPerPartition === Some(68))
}
}