diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
index dc64e5e25605..c8df2086a72c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
@@ -626,7 +626,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product with Tre
    * @return the stream of alternatives
    */
   def multiTransformDown(
-      rule: PartialFunction[BaseType, Stream[BaseType]]): Stream[BaseType] = {
+      rule: PartialFunction[BaseType, Seq[BaseType]]): Stream[BaseType] = {
     multiTransformDownWithPruning(AlwaysProcess.fn, UnknownRuleId)(rule)
   }
 
@@ -639,10 +639,15 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product with Tre
    * as a lazy `Stream` to be able to limit the number of alternatives generated at the caller side
    * as needed.
    *
-   * The rule should not apply or can return a one element stream of original node to indicate that
+   * The purpose of this function is to access the alternatives returned by the rule only if they
+   * are needed, so the rule can return a `Stream` whose elements are also lazily calculated.
+   * E.g. `multiTransform*` calls can be nested with the help of
+   * `MultiTransform.generateCartesianProduct()`.
+   *
+   * The rule should not apply or can return a one-element `Seq` of original node to indicate that
    * the original node without any transformation is a valid alternative.
    *
-   * The rule can return `Stream.empty` to indicate that the original node should be pruned. In this
+   * The rule can return `Seq.empty` to indicate that the original node should be pruned. In this
    * case `multiTransform()` returns an empty `Stream`.
    *
    * Please consider the following examples of `input.multiTransformDown(rule)`:
@@ -652,9 +657,9 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product with Tre
    *
    * 1.
    * We have a simple rule:
-   *   `a` => `Stream(1, 2)`
-   *   `b` => `Stream(10, 20)`
-   *   `Add(a, b)` => `Stream(11, 12, 21, 22)`
+   *   `a` => `Seq(1, 2)`
+   *   `b` => `Seq(10, 20)`
+   *   `Add(a, b)` => `Seq(11, 12, 21, 22)`
    *
    * The output is:
    *   `Stream(11, 12, 21, 22)`
@@ -662,9 +667,9 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product with Tre
    * 2.
    * In the previous example if we want to generate alternatives of `a` and `b` too then we need to
    * explicitly add the original `Add(a, b)` expression to the rule:
-   *   `a` => `Stream(1, 2)`
-   *   `b` => `Stream(10, 20)`
-   *   `Add(a, b)` => `Stream(11, 12, 21, 22, Add(a, b))`
+   *   `a` => `Seq(1, 2)`
+   *   `b` => `Seq(10, 20)`
+   *   `Add(a, b)` => `Seq(11, 12, 21, 22, Add(a, b))`
    *
    * The output is:
    *   `Stream(11, 12, 21, 22, Add(1, 10), Add(2, 10), Add(1, 20), Add(2, 20))`
@@ -683,25 +688,25 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product with Tre
   def multiTransformDownWithPruning(
       cond: TreePatternBits => Boolean,
       ruleId: RuleId = UnknownRuleId
-    )(rule: PartialFunction[BaseType, Stream[BaseType]]): Stream[BaseType] = {
+    )(rule: PartialFunction[BaseType, Seq[BaseType]]): Stream[BaseType] = {
     if (!cond.apply(this) || isRuleIneffective(ruleId)) {
       return Stream(this)
     }
 
-    // We could return `Stream(this)` if the `rule` doesn't apply and handle both
+    // We could return `Seq(this)` if the `rule` doesn't apply and handle both
     // - the doesn't apply
-    // - and the rule returns a one element `Stream(originalNode)`
-    // cases together. But, unfortunately it doesn't seem like there is a way to match on a one
-    // element stream without eagerly computing the tail head. So this contradicts with the purpose
-    // of only taking the necessary elements from the alternatives. I.e. the
-    // "multiTransformDown is lazy" test case in `TreeNodeSuite` would fail.
+    // - and the rule returns a one-element `Seq(originalNode)`
+    // cases together. The returned `Seq` can be a `Stream`, and unfortunately there doesn't seem
+    // to be a way to match on a one-element stream without eagerly computing the tail's head.
+    // This contradicts the purpose of only taking the necessary elements from the
+    // alternatives. I.e. the "multiTransformDown is lazy" test case in `TreeNodeSuite` would fail.
     // Please note that this behaviour has a downside as well that we can only mark the rule on the
     // original node ineffective if the rule didn't match.
     var ruleApplied = true
     val afterRules = CurrentOrigin.withOrigin(origin) {
       rule.applyOrElse(this, (_: BaseType) => {
         ruleApplied = false
-        Stream.empty
+        Seq.empty
       })
     }
 
@@ -716,7 +721,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product with Tre
       }
     } else {
       // If the rule was applied then use the returned alternatives
-      afterRules.map { afterRule =>
+      afterRules.toStream.map { afterRule =>
         if (this fastEquals afterRule) {
           this
         } else {
@@ -728,7 +733,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product with Tre
 
     afterRulesStream.flatMap { afterRule =>
       if (afterRule.containsChild.nonEmpty) {
-        generateChildrenSeq(
+        MultiTransform.generateCartesianProduct(
           afterRule.children.map(_.multiTransformDownWithPruning(cond, ruleId)(rule)))
           .map(afterRule.withNewChildren)
       } else {
@@ -737,15 +742,6 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product with Tre
     }
   }
 
-  private def generateChildrenSeq[T](childrenStreams: Seq[Stream[T]]): Stream[Seq[T]] = {
-    childrenStreams.foldRight(Stream(Seq.empty[T]))((childrenStream, childrenSeqStream) =>
-      for {
-        childrenSeq <- childrenSeqStream
-        child <- childrenStream
-      } yield child +: childrenSeq
-    )
-  }
-
   /**
    * Returns a copy of this node where `f` has been applied to all the nodes in `children`.
   */
@@ -1368,3 +1364,21 @@ trait QuaternaryLike[T <: TreeNode[T]] { self: TreeNode[T] =>
 
   protected def withNewChildrenInternal(newFirst: T, newSecond: T, newThird: T, newFourth: T): T
 }
+
+object MultiTransform {
+
+  /**
+   * Returns the stream of `Seq` elements by generating the cartesian product of sequences.
+   *
+   * @param elementSeqs a list of sequences to build the cartesian product from
+   * @return the stream of generated `Seq` elements
+   */
+  def generateCartesianProduct[T](elementSeqs: Seq[Seq[T]]): Stream[Seq[T]] = {
+    elementSeqs.foldRight(Stream(Seq.empty[T]))((elements, elementTails) =>
+      for {
+        elementTail <- elementTails
+        element <- elements
+      } yield element +: elementTail
+    )
+  }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
index ac28917675e6..e4adf59b392b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
@@ -987,10 +987,10 @@ class TreeNodeSuite extends SparkFunSuite with SQLHelper {
   test("multiTransformDown generates all alternatives") {
     val e = Add(Add(Literal("a"), Literal("b")), Add(Literal("c"), Literal("d")))
     val transformed = e.multiTransformDown {
-      case StringLiteral("a") => Stream(Literal(1), Literal(2), Literal(3))
-      case StringLiteral("b") => Stream(Literal(10), Literal(20), Literal(30))
+      case StringLiteral("a") => Seq(Literal(1), Literal(2), Literal(3))
+      case StringLiteral("b") => Seq(Literal(10), Literal(20), Literal(30))
       case Add(StringLiteral("c"), StringLiteral("d"), _) =>
-        Stream(Literal(100), Literal(200), Literal(300))
+        Seq(Literal(100), Literal(200), Literal(300))
     }
     val expected = for {
       cd <- Seq(Literal(100), Literal(200), Literal(300))
@@ -1003,7 +1003,7 @@ class TreeNodeSuite extends SparkFunSuite with SQLHelper {
   test("multiTransformDown is lazy") {
     val e = Add(Add(Literal("a"), Literal("b")), Add(Literal("c"), Literal("d")))
     val transformed = e.multiTransformDown {
-      case StringLiteral("a") => Stream(Literal(1), Literal(2), Literal(3))
+      case StringLiteral("a") => Seq(Literal(1), Literal(2), Literal(3))
       case StringLiteral("b") => newErrorAfterStream(Literal(10))
       case Add(StringLiteral("c"), StringLiteral("d"), _) => newErrorAfterStream(Literal(100))
     }
@@ -1017,8 +1017,8 @@ class TreeNodeSuite extends SparkFunSuite with SQLHelper {
     }
 
     val transformed2 = e.multiTransformDown {
-      case StringLiteral("a") => Stream(Literal(1), Literal(2), Literal(3))
-      case StringLiteral("b") => Stream(Literal(10), Literal(20), Literal(30))
+      case StringLiteral("a") => Seq(Literal(1), Literal(2), Literal(3))
+      case StringLiteral("b") => Seq(Literal(10), Literal(20), Literal(30))
       case Add(StringLiteral("c"), StringLiteral("d"), _) => newErrorAfterStream(Literal(100))
     }
     val expected2 = for {
@@ -1035,10 +1035,9 @@ class TreeNodeSuite extends SparkFunSuite with SQLHelper {
   test("multiTransformDown rule return this") {
     val e = Add(Add(Literal("a"), Literal("b")), Add(Literal("c"), Literal("d")))
     val transformed = e.multiTransformDown {
-      case s @ StringLiteral("a") => Stream(Literal(1), Literal(2), s)
-      case s @ StringLiteral("b") => Stream(Literal(10), Literal(20), s)
-      case a @ Add(StringLiteral("c"), StringLiteral("d"), _) =>
-        Stream(Literal(100), Literal(200), a)
+      case s @ StringLiteral("a") => Seq(Literal(1), Literal(2), s)
+      case s @ StringLiteral("b") => Seq(Literal(10), Literal(20), s)
+      case a @ Add(StringLiteral("c"), StringLiteral("d"), _) => Seq(Literal(100), Literal(200), a)
     }
     val expected = for {
       cd <- Seq(Literal(100), Literal(200), Add(Literal("c"), Literal("d")))
@@ -1053,10 +1052,10 @@ class TreeNodeSuite extends SparkFunSuite with SQLHelper {
     val e = Add(Add(Literal("a"), Literal("b")), Add(Literal("c"), Literal("d")))
     val transformed = e.multiTransformDown {
       case a @ Add(StringLiteral("a"), StringLiteral("b"), _) =>
-        Stream(Literal(11), Literal(12), Literal(21), Literal(22), a)
-      case StringLiteral("a") => Stream(Literal(1), Literal(2))
-      case StringLiteral("b") => Stream(Literal(10), Literal(20))
-      case Add(StringLiteral("c"), StringLiteral("d"), _) => Stream(Literal(100), Literal(200))
+        Seq(Literal(11), Literal(12), Literal(21), Literal(22), a)
+      case StringLiteral("a") => Seq(Literal(1), Literal(2))
+      case StringLiteral("b") => Seq(Literal(10), Literal(20))
+      case Add(StringLiteral("c"), StringLiteral("d"), _) => Seq(Literal(100), Literal(200))
     }
     val expected = for {
       cd <- Seq(Literal(100), Literal(200))
@@ -1072,12 +1071,12 @@ class TreeNodeSuite extends SparkFunSuite with SQLHelper {
   test("multiTransformDown can prune") {
     val e = Add(Add(Literal("a"), Literal("b")), Add(Literal("c"), Literal("d")))
     val transformed = e.multiTransformDown {
-      case StringLiteral("a") => Stream.empty
+      case StringLiteral("a") => Seq.empty
    }
     assert(transformed.isEmpty)
 
     val transformed2 = e.multiTransformDown {
-      case Add(StringLiteral("c"), StringLiteral("d"), _) => Stream.empty
+      case Add(StringLiteral("c"), StringLiteral("d"), _) => Seq.empty
     }
     assert(transformed2.isEmpty)
   }