Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -626,7 +626,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product with Tre
* @return the stream of alternatives
*/
def multiTransformDown(
rule: PartialFunction[BaseType, Stream[BaseType]]): Stream[BaseType] = {
rule: PartialFunction[BaseType, Seq[BaseType]]): Stream[BaseType] = {
// Delegates to `multiTransformDownWithPruning`, passing `AlwaysProcess.fn` as the pruning
// condition and `UnknownRuleId` as the rule id.
multiTransformDownWithPruning(AlwaysProcess.fn, UnknownRuleId)(rule)
}

Expand All @@ -639,10 +639,15 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product with Tre
* as a lazy `Stream` to be able to limit the number of alternatives generated at the caller side
* as needed.
*
* The rule should not apply or can return a one element stream of original node to indicate that
* The purpose of this function is to access the alternatives returned by the rule only if they
* are needed, so the rule can return a `Stream` whose elements are also lazily calculated.
* E.g. `multiTransform*` calls can be nested with the help of
* `MultiTransform.generateCartesianProduct()`.
*
* The rule should not apply or can return a one element `Seq` of original node to indicate that
* the original node without any transformation is a valid alternative.
*
* The rule can return `Stream.empty` to indicate that the original node should be pruned. In this
* The rule can return `Seq.empty` to indicate that the original node should be pruned. In this
* case `multiTransform()` returns an empty `Stream`.
*
* Please consider the following examples of `input.multiTransformDown(rule)`:
Expand All @@ -652,19 +657,19 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product with Tre
*
* 1.
* We have a simple rule:
* `a` => `Stream(1, 2)`
* `b` => `Stream(10, 20)`
* `Add(a, b)` => `Stream(11, 12, 21, 22)`
* `a` => `Seq(1, 2)`
* `b` => `Seq(10, 20)`
* `Add(a, b)` => `Seq(11, 12, 21, 22)`
*
* The output is:
* `Stream(11, 12, 21, 22)`
*
* 2.
* In the previous example if we want to generate alternatives of `a` and `b` too then we need to
* explicitly add the original `Add(a, b)` expression to the rule:
* `a` => `Stream(1, 2)`
* `b` => `Stream(10, 20)`
* `Add(a, b)` => `Stream(11, 12, 21, 22, Add(a, b))`
* `a` => `Seq(1, 2)`
* `b` => `Seq(10, 20)`
* `Add(a, b)` => `Seq(11, 12, 21, 22, Add(a, b))`
*
* The output is:
* `Stream(11, 12, 21, 22, Add(1, 10), Add(2, 10), Add(1, 20), Add(2, 20))`
Expand All @@ -683,25 +688,25 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product with Tre
def multiTransformDownWithPruning(
cond: TreePatternBits => Boolean,
ruleId: RuleId = UnknownRuleId
)(rule: PartialFunction[BaseType, Stream[BaseType]]): Stream[BaseType] = {
)(rule: PartialFunction[BaseType, Seq[BaseType]]): Stream[BaseType] = {
if (!cond.apply(this) || isRuleIneffective(ruleId)) {
return Stream(this)
}

// We could return `Stream(this)` if the `rule` doesn't apply and handle both
// We could return `Seq(this)` if the `rule` doesn't apply and handle both
// - the rule doesn't apply
// - and the rule returns a one element `Stream(originalNode)`
// cases together. But, unfortunately it doesn't seem like there is a way to match on a one
// element stream without eagerly computing the tail head. So this contradicts with the purpose
// of only taking the necessary elements from the alternatives. I.e. the
// "multiTransformDown is lazy" test case in `TreeNodeSuite` would fail.
// - and the rule returns a one element `Seq(originalNode)`
// cases together. The returned `Seq` can be a `Stream` and unfortunately it doesn't seem like
// there is a way to match on a one element stream without eagerly computing the tail's head.
// This contradicts the purpose of only taking the necessary elements from the
// alternatives, i.e. the "multiTransformDown is lazy" test case in `TreeNodeSuite` would fail.
// Please note that this behaviour has a downside as well that we can only mark the rule on the
// original node ineffective if the rule didn't match.
var ruleApplied = true
val afterRules = CurrentOrigin.withOrigin(origin) {
rule.applyOrElse(this, (_: BaseType) => {
ruleApplied = false
Stream.empty
Seq.empty
})
}

Expand All @@ -716,7 +721,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product with Tre
}
} else {
// If the rule was applied then use the returned alternatives
afterRules.map { afterRule =>
afterRules.toStream.map { afterRule =>
if (this fastEquals afterRule) {
this
} else {
Expand All @@ -728,7 +733,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product with Tre

afterRulesStream.flatMap { afterRule =>
if (afterRule.containsChild.nonEmpty) {
generateChildrenSeq(
MultiTransform.generateCartesianProduct(
afterRule.children.map(_.multiTransformDownWithPruning(cond, ruleId)(rule)))
.map(afterRule.withNewChildren)
} else {
Expand All @@ -737,15 +742,6 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product with Tre
}
}

/**
 * Builds every combination that picks exactly one element from each of the given child streams.
 *
 * The streams are folded from the right, so in the resulting stream the elements of the first
 * child stream vary fastest.
 *
 * @param childrenStreams the alternatives of each child, one stream per child
 * @return the stream of children sequences, each holding one alternative per child
 */
private def generateChildrenSeq[T](childrenStreams: Seq[Stream[T]]): Stream[Seq[T]] = {
  // The product of zero streams is a single empty combination.
  val noChildren: Stream[Seq[T]] = Stream(Seq.empty[T])
  childrenStreams.foldRight(noChildren) { (alternatives, combinedTails) =>
    // Prepend each alternative of the current child to every combination built so far.
    combinedTails.flatMap(tail => alternatives.map(head => head +: tail))
  }
}

/**
* Returns a copy of this node where `f` has been applied to all the nodes in `children`.
*/
Expand Down Expand Up @@ -1368,3 +1364,21 @@ trait QuaternaryLike[T <: TreeNode[T]] { self: TreeNode[T] =>

protected def withNewChildrenInternal(newFirst: T, newSecond: T, newThird: T, newFourth: T): T
}

object MultiTransform {

  /**
   * Enumerates the cartesian product of the given sequences as a `Stream`.
   *
   * Every generated `Seq` holds one element from each input sequence, in the order of the input
   * sequences. The input is folded from the right, so the elements of the first sequence vary
   * fastest in the output.
   *
   * @param elementSeqs a list of sequences to build the cartesian product from
   * @return the stream of generated `Seq` elements
   */
  def generateCartesianProduct[T](elementSeqs: Seq[Seq[T]]): Stream[Seq[T]] = {
    // The product of zero sequences is a single empty combination.
    val emptyProduct: Stream[Seq[T]] = Stream(Seq.empty[T])
    elementSeqs.foldRight(emptyProduct) { (elements, elementTails) =>
      // Prepend each element of the current sequence to every tail generated so far.
      elementTails.flatMap(elementTail => elements.map(element => element +: elementTail))
    }
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -987,10 +987,10 @@ class TreeNodeSuite extends SparkFunSuite with SQLHelper {
test("multiTransformDown generates all alternatives") {
val e = Add(Add(Literal("a"), Literal("b")), Add(Literal("c"), Literal("d")))
val transformed = e.multiTransformDown {
case StringLiteral("a") => Stream(Literal(1), Literal(2), Literal(3))
case StringLiteral("b") => Stream(Literal(10), Literal(20), Literal(30))
case StringLiteral("a") => Seq(Literal(1), Literal(2), Literal(3))
case StringLiteral("b") => Seq(Literal(10), Literal(20), Literal(30))
case Add(StringLiteral("c"), StringLiteral("d"), _) =>
Stream(Literal(100), Literal(200), Literal(300))
Seq(Literal(100), Literal(200), Literal(300))
}
val expected = for {
cd <- Seq(Literal(100), Literal(200), Literal(300))
Expand All @@ -1003,7 +1003,7 @@ class TreeNodeSuite extends SparkFunSuite with SQLHelper {
test("multiTransformDown is lazy") {
val e = Add(Add(Literal("a"), Literal("b")), Add(Literal("c"), Literal("d")))
val transformed = e.multiTransformDown {
case StringLiteral("a") => Stream(Literal(1), Literal(2), Literal(3))
case StringLiteral("a") => Seq(Literal(1), Literal(2), Literal(3))
case StringLiteral("b") => newErrorAfterStream(Literal(10))
case Add(StringLiteral("c"), StringLiteral("d"), _) => newErrorAfterStream(Literal(100))
}
Expand All @@ -1017,8 +1017,8 @@ class TreeNodeSuite extends SparkFunSuite with SQLHelper {
}

val transformed2 = e.multiTransformDown {
case StringLiteral("a") => Stream(Literal(1), Literal(2), Literal(3))
case StringLiteral("b") => Stream(Literal(10), Literal(20), Literal(30))
case StringLiteral("a") => Seq(Literal(1), Literal(2), Literal(3))
case StringLiteral("b") => Seq(Literal(10), Literal(20), Literal(30))
case Add(StringLiteral("c"), StringLiteral("d"), _) => newErrorAfterStream(Literal(100))
}
val expected2 = for {
Expand All @@ -1035,10 +1035,9 @@ class TreeNodeSuite extends SparkFunSuite with SQLHelper {
test("multiTransformDown rule return this") {
val e = Add(Add(Literal("a"), Literal("b")), Add(Literal("c"), Literal("d")))
val transformed = e.multiTransformDown {
case s @ StringLiteral("a") => Stream(Literal(1), Literal(2), s)
case s @ StringLiteral("b") => Stream(Literal(10), Literal(20), s)
case a @ Add(StringLiteral("c"), StringLiteral("d"), _) =>
Stream(Literal(100), Literal(200), a)
case s @ StringLiteral("a") => Seq(Literal(1), Literal(2), s)
case s @ StringLiteral("b") => Seq(Literal(10), Literal(20), s)
case a @ Add(StringLiteral("c"), StringLiteral("d"), _) => Seq(Literal(100), Literal(200), a)
}
val expected = for {
cd <- Seq(Literal(100), Literal(200), Add(Literal("c"), Literal("d")))
Expand All @@ -1053,10 +1052,10 @@ class TreeNodeSuite extends SparkFunSuite with SQLHelper {
val e = Add(Add(Literal("a"), Literal("b")), Add(Literal("c"), Literal("d")))
val transformed = e.multiTransformDown {
case a @ Add(StringLiteral("a"), StringLiteral("b"), _) =>
Stream(Literal(11), Literal(12), Literal(21), Literal(22), a)
case StringLiteral("a") => Stream(Literal(1), Literal(2))
case StringLiteral("b") => Stream(Literal(10), Literal(20))
case Add(StringLiteral("c"), StringLiteral("d"), _) => Stream(Literal(100), Literal(200))
Seq(Literal(11), Literal(12), Literal(21), Literal(22), a)
case StringLiteral("a") => Seq(Literal(1), Literal(2))
case StringLiteral("b") => Seq(Literal(10), Literal(20))
case Add(StringLiteral("c"), StringLiteral("d"), _) => Seq(Literal(100), Literal(200))
}
val expected = for {
cd <- Seq(Literal(100), Literal(200))
Expand All @@ -1072,12 +1071,12 @@ class TreeNodeSuite extends SparkFunSuite with SQLHelper {
// Checks that when the rule returns no alternatives for a matched node, the whole transformation
// is expected to produce no alternatives.
test("multiTransformDown can prune") {
val e = Add(Add(Literal("a"), Literal("b")), Add(Literal("c"), Literal("d")))
// Prune at a leaf: the rule yields nothing for the "a" literal.
val transformed = e.multiTransformDown {
case StringLiteral("a") => Stream.empty
case StringLiteral("a") => Seq.empty
}
assert(transformed.isEmpty)

// Prune at an inner node: the rule yields nothing for the `Add` of the "c" and "d" literals.
val transformed2 = e.multiTransformDown {
case Add(StringLiteral("c"), StringLiteral("d"), _) => Stream.empty
case Add(StringLiteral("c"), StringLiteral("d"), _) => Seq.empty
}
assert(transformed2.isEmpty)
}
Expand Down