Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -618,6 +618,134 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product with Tre
}
}

/**
* Returns alternative copies of this node where `rule` has been recursively applied to it and all
* of its children (pre-order).
*
* @param rule a function used to generate alternatives for a node
* @return the stream of alternatives
*/
def multiTransformDown(
rule: PartialFunction[BaseType, Stream[BaseType]]): Stream[BaseType] = {
multiTransformDownWithPruning(AlwaysProcess.fn, UnknownRuleId)(rule)
}

/**
* Returns alternative copies of this node where `rule` has been recursively applied to it and all
* of its children (pre-order).
*
* As it is very easy to generate enormous number of alternatives when the input tree is huge or
* when the rule returns many alternatives for many nodes, this function returns the alternatives
* as a lazy `Stream` to be able to limit the number of alternatives generated at the caller side
* as needed.
*
* The rule should not apply or can return a one element stream of original node to indicate that
* the original node without any transformation is a valid alternative.
*
* The rule can return `Stream.empty` to indicate that the original node should be pruned. In this
* case `multiTransform()` returns an empty `Stream`.
*
* Please consider the following examples of `input.multiTransformDown(rule)`:
*
* We have an input expression:
* `Add(a, b)`
*
* 1.
* We have a simple rule:
* `a` => `Stream(1, 2)`
* `b` => `Stream(10, 20)`
* `Add(a, b)` => `Stream(11, 12, 21, 22)`
*
* The output is:
* `Stream(11, 12, 21, 22)`
*
* 2.
* In the previous example if we want to generate alternatives of `a` and `b` too then we need to
* explicitly add the original `Add(a, b)` expression to the rule:
* `a` => `Stream(1, 2)`
* `b` => `Stream(10, 20)`
* `Add(a, b)` => `Stream(11, 12, 21, 22, Add(a, b))`
*
* The output is:
* `Stream(11, 12, 21, 22, Add(1, 10), Add(2, 10), Add(1, 20), Add(2, 20))`
*
* @param rule a function used to generate alternatives for a node
* @param cond a Lambda expression to prune tree traversals. If `cond.apply` returns false
* on a TreeNode T, skips processing T and its subtree; otherwise, processes
* T and its subtree recursively.
* @param ruleId is a unique Id for `rule` to prune unnecessary tree traversals. When it is
* UnknownRuleId, no pruning happens. Otherwise, if `rule` (with id `ruleId`)
* has been marked as in effective on a TreeNode T, skips processing T and its
* subtree. Do not pass it if the rule is not purely functional and reads a
* varying initial state for different invocations.
* @return the stream of alternatives
*/
def multiTransformDownWithPruning(
cond: TreePatternBits => Boolean,
ruleId: RuleId = UnknownRuleId
)(rule: PartialFunction[BaseType, Stream[BaseType]]): Stream[BaseType] = {
if (!cond.apply(this) || isRuleIneffective(ruleId)) {
return Stream(this)
}

// We could return `Stream(this)` if the `rule` doesn't apply and handle both
// - the doesn't apply
// - and the rule returns a one element `Stream(originalNode)`
// cases together. But, unfortunately it doesn't seem like there is a way to match on a one
// element stream without eagerly computing the tail head. So this contradicts with the purpose
// of only taking the necessary elements from the alternatives. I.e. the
// "multiTransformDown is lazy" test case in `TreeNodeSuite` would fail.
// Please note that this behaviour has a downside as well that we can only mark the rule on the
// original node ineffective if the rule didn't match.
var ruleApplied = true
val afterRules = CurrentOrigin.withOrigin(origin) {
rule.applyOrElse(this, (_: BaseType) => {
ruleApplied = false
Stream.empty
})
}

val afterRulesStream = if (afterRules.isEmpty) {
if (ruleApplied) {
// If the rule returned with empty alternatives then prune
Stream.empty
} else {
// If the rule was not applied then keep the original node
this.markRuleAsIneffective(ruleId)
Stream(this)
}
} else {
// If the rule was applied then use the returned alternatives
afterRules.map { afterRule =>
if (this fastEquals afterRule) {
this
} else {
afterRule.copyTagsFrom(this)
afterRule
}
}
}

afterRulesStream.flatMap { afterRule =>
if (afterRule.containsChild.nonEmpty) {
generateChildrenSeq(
afterRule.children.map(_.multiTransformDownWithPruning(cond, ruleId)(rule)))
.map(afterRule.withNewChildren)
} else {
Stream(afterRule)
}
}
}

private def generateChildrenSeq[T](childrenStreams: Seq[Stream[T]]): Stream[Seq[T]] = {
childrenStreams.foldRight(Stream(Seq.empty[T]))((childrenStream, childrenSeqStream) =>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do we fold from right to left?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is to generate alternatives for the first children of an expression first.
E.g. if a + b is the input and a => Stream(a1, a2) and b => Stream(b1, b2) is the rule then I wanted to get the Stream(a1 + b1, a2 + b1, a1 + b2, a2 + b2) output in this order.

for {
childrenSeq <- childrenSeqStream
child <- childrenStream
} yield child +: childrenSeq
)
}

/**
* Returns a copy of this node where `f` has been applied to all the nodes in `children`.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -977,4 +977,108 @@ class TreeNodeSuite extends SparkFunSuite with SQLHelper {
assert(origin.context.summary.isEmpty)
}
}

private def newErrorAfterStream(es: Expression*) = {
es.toStream.append(
throw new NoSuchElementException("Stream should not return more elements")
)
}

test("multiTransformDown generates all alternatives") {
val e = Add(Add(Literal("a"), Literal("b")), Add(Literal("c"), Literal("d")))
val transformed = e.multiTransformDown {
case StringLiteral("a") => Stream(Literal(1), Literal(2), Literal(3))
case StringLiteral("b") => Stream(Literal(10), Literal(20), Literal(30))
case Add(StringLiteral("c"), StringLiteral("d"), _) =>
Stream(Literal(100), Literal(200), Literal(300))
}
val expected = for {
cd <- Seq(Literal(100), Literal(200), Literal(300))
b <- Seq(Literal(10), Literal(20), Literal(30))
a <- Seq(Literal(1), Literal(2), Literal(3))
} yield Add(Add(a, b), cd)
assert(transformed === expected)
}

test("multiTransformDown is lazy") {
val e = Add(Add(Literal("a"), Literal("b")), Add(Literal("c"), Literal("d")))
val transformed = e.multiTransformDown {
case StringLiteral("a") => Stream(Literal(1), Literal(2), Literal(3))
case StringLiteral("b") => newErrorAfterStream(Literal(10))
case Add(StringLiteral("c"), StringLiteral("d"), _) => newErrorAfterStream(Literal(100))
}
val expected = for {
a <- Seq(Literal(1), Literal(2), Literal(3))
} yield Add(Add(a, Literal(10)), Literal(100))
// We don't access alternatives for `b` after 10 and for `c` after 100
assert(transformed.take(3) == expected)
intercept[NoSuchElementException] {
transformed.take(3 + 1).toList
}

val transformed2 = e.multiTransformDown {
case StringLiteral("a") => Stream(Literal(1), Literal(2), Literal(3))
case StringLiteral("b") => Stream(Literal(10), Literal(20), Literal(30))
case Add(StringLiteral("c"), StringLiteral("d"), _) => newErrorAfterStream(Literal(100))
}
val expected2 = for {
b <- Seq(Literal(10), Literal(20), Literal(30))
a <- Seq(Literal(1), Literal(2), Literal(3))
} yield Add(Add(a, b), Literal(100))
// We don't access alternatives for `c` after 100
assert(transformed2.take(3 * 3) === expected2)
intercept[NoSuchElementException] {
transformed.take(3 * 3 + 1).toList
}
}

test("multiTransformDown rule return this") {
val e = Add(Add(Literal("a"), Literal("b")), Add(Literal("c"), Literal("d")))
val transformed = e.multiTransformDown {
case s @ StringLiteral("a") => Stream(Literal(1), Literal(2), s)
case s @ StringLiteral("b") => Stream(Literal(10), Literal(20), s)
case a @ Add(StringLiteral("c"), StringLiteral("d"), _) =>
Stream(Literal(100), Literal(200), a)
}
val expected = for {
cd <- Seq(Literal(100), Literal(200), Add(Literal("c"), Literal("d")))
b <- Seq(Literal(10), Literal(20), Literal("b"))
a <- Seq(Literal(1), Literal(2), Literal("a"))
} yield Add(Add(a, b), cd)
assert(transformed == expected)
}

test("multiTransformDown doesn't stop generating alternatives of descendants when non-leaf is " +
"transformed and itself is in the alternatives") {
val e = Add(Add(Literal("a"), Literal("b")), Add(Literal("c"), Literal("d")))
val transformed = e.multiTransformDown {
case a @ Add(StringLiteral("a"), StringLiteral("b"), _) =>
Stream(Literal(11), Literal(12), Literal(21), Literal(22), a)
case StringLiteral("a") => Stream(Literal(1), Literal(2))
case StringLiteral("b") => Stream(Literal(10), Literal(20))
case Add(StringLiteral("c"), StringLiteral("d"), _) => Stream(Literal(100), Literal(200))
}
val expected = for {
cd <- Seq(Literal(100), Literal(200))
ab <- Seq(Literal(11), Literal(12), Literal(21), Literal(22)) ++
(for {
b <- Seq(Literal(10), Literal(20))
a <- Seq(Literal(1), Literal(2))
} yield Add(a, b))
} yield Add(ab, cd)
assert(transformed == expected)
}

test("multiTransformDown can prune") {
val e = Add(Add(Literal("a"), Literal("b")), Add(Literal("c"), Literal("d")))
val transformed = e.multiTransformDown {
case StringLiteral("a") => Stream.empty
}
assert(transformed.isEmpty)

val transformed2 = e.multiTransformDown {
case Add(StringLiteral("c"), StringLiteral("d"), _) => Stream.empty
}
assert(transformed2.isEmpty)
}
}