Conversation

Member

@asolimando asolimando commented Feb 17, 2018

What changes were proposed in this pull request?

Added subtree pruning in the translation from LearningNode to Node: a learning node whose subtree's leaves all share a single prediction value is translated into a LeafNode, instead of a (redundant) InternalNode.

How was this patch tested?

Added two unit tests under "mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala":

  • test("SPARK-3159 tree model redundancy - classification")
  • test("SPARK-3159 tree model redundancy - regression")

Four existing unit tests relying on the tree structure (the existence of a specific redundant subtree) had to be adapted, as the tested components of the output tree are now pruned (fixed by adding an extra prune parameter that can be used to disable pruning for testing).

Member

@srowen srowen left a comment

I think I have a lot of style quibbles with this, even if the direction is OK.

But the main problem here is that you do one top-down pass, right? This misses opportunities to further prune the tree that you'd get if you did it bottom-up. Multiple top-down passes would also catch those, but that isn't efficient.

Member

Nit: this is scaladoc comment syntax. Use // for inline comments

Member

Nit: space after if. I think scalastyle will require it.

Member

You can write things like leftChild.foreach(predBuffer ++= _.leafPredictions), or something about like that
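(For illustration, a tiny self-contained example of the Option.foreach pattern; the names below are made up, not the PR's:)

    import scala.collection.mutable
    val predBuffer = mutable.Set.empty[Double]
    // stand-in for an optional child's set of leaf predictions
    val leftChild: Option[Set[Double]] = Some(Set(0.0, 1.0))
    // foreach on an Option runs the body only when a value is present,
    // replacing the isDefined/get pair
    leftChild.foreach(predBuffer ++= _)
    assert(predBuffer == Set(0.0, 1.0))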

Member Author

I originally wanted to write something similar, but given that there is no getChildren method, I would have to put leftChild and rightChild in a collection, filter on _.isDefined, and finally call foreach. Overall, the version proposed in the PR seemed cleaner to me. What do you think?

Member

No, I'm just thinking you use Option.foreach(...) twice. It's a matter of taste I guess. I tend to prefer avoiding the isDefined and get calls for simple one-liners.

Member Author

Ok, now I get what you mean. In this case, though, we keep the if to apply "short-circuiting" as suggested below.

Member

Is this only used in shouldBeLeaf? Just collapse it? You can then short-circuit when you know, for example, that the left subtree already has more than one distinct value.

Member Author

@asolimando asolimando Feb 17, 2018

Short-circuiting is indeed possible and might avoid a lot of extra work. I have modified
if (rightChild.isDefined) {
into
if (predBuffer.size <= 1 && rightChild.isDefined) {

I have also replaced shouldBeLeaf by turning leafPredictions into a lazy val and testing it directly in the toNode method.
I think readability has also improved with this inlining.
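(As an aside, a minimal self-contained sketch of the short-circuited collection, using a toy node type rather than the actual LearningNode API:)

    import scala.collection.mutable

    // illustrative stand-in for LearningNode; None means "no child"
    case class TreeN(pred: Double, left: Option[TreeN] = None, right: Option[TreeN] = None) {
      def leafPredictions: Set[Double] = {
        val predBuffer = mutable.Set.empty[Double]
        def collect(n: TreeN): Unit = {
          if (n.left.isEmpty && n.right.isEmpty) {
            predBuffer += n.pred
          } else {
            n.left.foreach(collect)
            // short-circuit: once two distinct predictions are seen, the
            // subtree can never collapse into one leaf, so skip the right child
            if (predBuffer.size <= 1) n.right.foreach(collect)
          }
        }
        collect(this)
        predBuffer.toSet
      }
    }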

Member

@srowen srowen Feb 18, 2018

Yeah, and short-circuiting involves an if statement anyway, so maybe no value in my previous suggestion about foreach.

Now this saves the whole Set[Double] of predictions when it's never used again. It's probably worth not hogging that memory. The method here doesn't actually need to form the set of all predictions, just decide whether there's more than 1. For example, in the case that there are no distinct values after evaluating children, no point in adding the leaf's prediction to it just to decide it's got 1 unique value.

EDIT: oh, this is just in the LearningNode. I see. Each subtree is evaluated because each is generated by toNode.

Member Author

For construction we need the distinct predictions, as "redundancy" is non-monotonic: from being true on subtrees, it can become false.
The left and right subtrees of a node can each be redundant (left predicts only 0, right only 1) while the node itself is not (it can predict 0 or 1, depending on the split). This can be ascertained only by comparing the prediction values; you cannot derive it from the redundancy of the two subtrees alone (as they can be redundant w.r.t. different values).

Consider this example:
1
--1.1
----1.1.1 (pred 0)
----1.1.2 (pred 0)
--1.2
----1.2.1 (pred 1)
----1.2.2 (pred 1)

This can be pruned to
1
--1.1 (pred 0)
--1.2 (pred 1)

As 1.1 and 1.2 are both redundant, but 1 is not (and to verify this, you need their prediction values).
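(A tiny self-contained illustration of this point, with hypothetical names rather than PR code:)

    // both children can be "redundant" on their own while the parent is not;
    // the boolean flag alone loses the prediction values needed to decide
    case class Sub(redundant: Boolean, prediction: Double)
    val left = Sub(redundant = true, prediction = 0.0)
    val right = Sub(redundant = true, prediction = 1.0)
    val parentRedundant =
      left.redundant && right.redundant && left.prediction == right.prediction
    assert(!parentRedundant)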

It is a different story if we go for storing "isRedundant" (the shouldBeLeaf of the original PR), which can instead be safely used to decide whether to prune.

As you pointed out this extra information is stored only in LearningNode, so I think it doesn't make a significant difference if we "store" the set of predictions or shouldBeLeaf, as these objects will be disposed anyway at the end of the translation, while what matters is to save computation time and multiple explorations (as you originally suggested).

I would summarise it as follows, we either:

  1. store the set of predictions, or
  2. store shouldBeLeaf

(1) requires more memory but doesn't introduce extra variables (less code to understand); (2) has the dual pros and cons.

Does it make sense to you? Which option would you suggest?

Member

isRedundant seems ... redundant here. Just assert(!RandomForestSuite.isRedundant(dt)). Similar comment for other similar constructs around here.

Member Author

You are right, I have removed isRedundant and rnd.

Member

You don't need to doc a private method, but you can. But for my taste this is redundant. Just keep the @return doc

Member Author

I have to admit that I have overdone it a bit with the documentation here; I have tried to reduce it all over the PR.

Member

Likewise an example of the kind of comment that seems to just restate the method name.

Member

Just: rnd.nextInt(numericMax - numericMin + 1) + numericMin

Member Author

@asolimando asolimando Feb 17, 2018

Isn't there an issue given that numericMax and numericMin are double values?
I have changed it to:
rnd.nextDouble() * (numericMax - numericMin) + numericMin
I think it is in line with what you proposed (modulo the "+ 1" that I don't understand).
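(A short sketch contrasting the two variants may help; the "+ 1" matters only for integers, because nextInt's bound is exclusive:)

    import scala.util.Random
    val rnd = new Random(42)
    // integer in [min, max]: "+ 1" compensates for nextInt's exclusive bound
    def intIn(min: Int, max: Int): Int = rnd.nextInt(max - min + 1) + min
    // double in [min, max): no such correction is needed
    def doubleIn(min: Double, max: Double): Double =
      rnd.nextDouble() * (max - min) + min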

Member

Oh right, disregard that. I had in mind that they were ints.

Member

This method doesn't seem to add enough to be a whole new method with 4 lines of docs

Member Author

I have removed the documentation (except the @return) for all the auxiliary methods here.

Member

Aren't these methods repeating code in the implementation? I'm not following why all these little methods are needed.

Member Author

@asolimando asolimando Feb 17, 2018

They are, but unfortunately here we work with (subclasses of) the Node class (composing the decision tree), while the other
methods performing the optimization are defined at the LearningNode level, on objects that are lost when the decision tree is given as output.

For this reason I had to re-implement similar logic in the test as well, which is also less optimized (negligible for the size of the decision trees we generate for testing).

Member

I suppose I haven't read this code thoroughly, but rather than write similar logic differently in the test, to test the original logic, why not just construct cases with known correct answers and compare the pruned tree to the expected one?

Member Author

I have updated the code to explicitly test against the expected tree structure.
I have removed all the auxiliary methods; it actually saved a lot of space, and the tests are more explicit and clear, thanks.

Member Author

asolimando commented Feb 17, 2018

Hello Sean,
here is my understanding of the problem and the main intuition of the proposed solution:
We want to have a tree such that it does not contain any redundant subtree.
A subtree is redundant iff:

  1. it does not consist of a single leaf
  2. all the leaves belonging to it agree on the prediction (its leafPredictions set has at most a single distinct value).

Given that for any subtree there is exactly one root node, we "store" this information in its root (learning) node.

A redundant subtree can be replaced by a single leaf, predicting the single prediction value reachable via that subtree (by definition).

So, there are basically two cases in our top-down exploration:

  1. The subtree is redundant -> the whole subtree is explored (linearly) to collect the predictions, the node is translated into a LeafNode, and thus the toNode method is never called on its children, so no extra exploration
  2. The subtree is not redundant -> the whole subtree is explored (linearly) to collect the predictions, the node is translated into an InternalNode, and toNode is recursively called on its children. However, at this point, the lazy value shouldBeLeaf (private lazy val shouldBeLeaf: Boolean) has already been set by the previous call, so there won't be any further exploration. Each further step down the recursion path will be able to "reason" locally on stored shouldBeLeaf values.

Indeed, I was about to switch from the first top-down version to a bottom-up one to avoid the repeated explorations, but then I realized that "lazy" could achieve the same result while keeping the top-down exploration, which is, in my opinion, more straightforward to read.
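(A minimal sketch of the memoization that lazy provides, on a toy class rather than the actual LearningNode:)

    // leafPredictions is computed at most once per node, on first access;
    // later accesses, e.g. from a parent's toNode call, reuse the stored set
    class LNode(val pred: Double, val left: Option[LNode], val right: Option[LNode]) {
      lazy val leafPredictions: Set[Double] =
        if (left.isEmpty && right.isEmpty) Set(pred)
        else left.map(_.leafPredictions).getOrElse(Set.empty) ++
          right.map(_.leafPredictions).getOrElse(Set.empty)
    }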

I have updated the PR according to your suggestions; I haven't squashed the commits yet in case there are further comments.

I will reply to your comments below so I hope I won't miss anything.

Thanks for your help.

Member

@srowen srowen left a comment

Yes, I understand why the nodes keep the set of predictions now that I see how toNode is called on every one of them. This is looking better; still a few suggestions about the code style.

Member

Can these tests be adapted, or are they really irrelevant? I want to make sure this isn't leaving uncovered some cases that the tests covered before, if possible.

Member Author

I am sorry, two out of three of the adapted tests were erroneously not included in the PR, despite being referred to in the description; my bad for not realising until now.

I will describe what I did for those tests that I have added to my latest commit:

  • "mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala"
    1. "Multiclass classification stump with 10-ary (ordered) categorical features"
    2. "Multiclass classification tree with 10-ary (ordered) categorical features with just enough bins"

I had to "adapt" the tests as they where checking properties related to the "learning phase" on the DecisionTree itself, but they were checking on nodes that are not there anymore (on those test trees) due to the pruning.

So, mimicking what is done extensively in the existing test suite, I have extracted the needed information using the "internal" methods of RandomForest (e.g., RandomForest.findBestSplits) and checked it directly.

So the tests have been moved to ".../ml/tree/impl/RandomForestSuite.scala" and adapted to call the internal methods of the "ml" version of RandomForest/DecisionTree.

Internally, the mllib version of DecisionTree calls the ml version anyway. By looking at several other tests in ".../ml/tree/impl/RandomForestSuite.scala" (the "Avoid aggregation on the last level" test, for instance), I noticed that they were ports as well, so I have followed their style.

To double-check the adaptation I did a debugging pass for comparing the values of the relevant objects in the two cases, and they were identical.

The only removed test was "Use soft prediction for binary classification with ordered categorical features"
in "mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala".

This test was a duplicate of the test having the same name and already located in ".../ml/tree/impl/RandomForestSuite.scala",
and given that DecisionTree is only a wrapper for RandomForest with numTree=1, they were doing exactly the same thing (modulo the call inside the wrapper), and the coverage of the test cases is left untouched by its removal.

"Use soft prediction for binary classification with ordered categorical features" in ".../ml/tree/impl/RandomForestSuite.scala" had to be adapted too for the same reasons (it was checking information in nodes that are now pruned).

PS: I didn't check extensively as it was out of scope for the PR, but I remember that this was not the only case; there might be more redundant tests. If you think it could be a good idea to remove them, I can re-check and open a JIRA ticket with the gathered information.

Member

These 1-line single-use methods strike me as overhead you don't need; I don't feel strongly about it but inlining it would be at least as readable IMHO. This could also be rnd.nextInt(2)

Member Author

I have removed generateBinaryValue and used nextInt in place of "nextBoolean + if".

I have also inlined the label generation, as it was a similar case.

However, I would keep generateNumericValue as a separate method, as I am afraid inlining it would harm readability.

Member

Array.fill(numBinary)(generateBinaryValue(rnd)) is simpler

Member Author

I have updated the generation of both binary and numeric values accordingly; it is indeed simpler this way.
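(For example; numBinary and the seed below are illustrative:)

    import scala.util.Random
    val rnd = new Random(42)
    val numBinary = 5
    // Array.fill re-evaluates its by-name argument for every element,
    // so each slot gets an independent draw
    val binaryValues: Array[Double] = Array.fill(numBinary)(rnd.nextInt(2).toDouble)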

Member

@srowen srowen left a comment

OK I think this is looking fine.

Member

BTW is the logic here that you don't need to add the node's prediction if it's not a leaf, because its prediction must be contained in one of its children, and those predictions were already added?

Member Author

Yes, it is correct.

Member

Nit: I think Map(a -> b) syntax is clearer than Map((a, b)) syntax. (You use the former a few lines later too.)
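(For example:)

    // both expressions build the same Map; the arrow form reads as key -> value
    val arrow = Map("nodeId" -> 1.0)
    val tuple = Map(("nodeId", 1.0))
    assert(arrow == tuple)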

Member Author

I have updated the syntax, not only in the tests I have moved here myself, but everywhere in "ml/tree/impl/RandomForestSuite.scala", in order to have a uniform style at least throughout this file.

Member

srowen commented Feb 19, 2018

Another one that @jkbradley or @MLnick might want to look at, but seems like a nice win.


SparkQA commented Feb 19, 2018

Test build #4100 has finished for PR 20632 at commit fd3b5df.

  • This patch fails Scala style tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@asolimando
Member Author

Given that we are converging, I have squashed the commits into a single one.

My local mvn scalastyle:check was passing (as well as the check done via the Scala plugin for IntelliJ), while the one triggered in Jenkins failed.

I have fixed the specific issue, but there might be more. Any suggestion on how to reproduce the same setup locally, to be sure there won't be additional issues on the scalastyle side?

Member

srowen commented Feb 20, 2018

Run ./dev/scalastyle to replicate what Jenkins will do.


SparkQA commented Feb 21, 2018

Test build #4103 has finished for PR 20632 at commit eeb3cc9.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

Contributor

This will store a potentially very large collection at each node. For deep regression trees the storage cost could be quite large. We can accomplish the same thing without storing them:

  def toNode: Node = {

    // convert to an inner node only when:
    //  -) the node is not a leaf, and
    //  -) the subtree rooted at this node cannot be replaced by a single leaf
    //     (i.e., at least two different leaf predictions appear in the subtree)
    if (!isLeafNode) {
      assert(leftChild.nonEmpty && rightChild.nonEmpty && split.nonEmpty && stats != null,
        "Unknown error during Decision Tree learning.  Could not convert LearningNode to Node.")
      (leftChild.get.toNode, rightChild.get.toNode) match {
        case (l: LeafNode, r: LeafNode) if l.prediction == r.prediction =>
          new LeafNode(l.prediction, stats.impurity, stats.impurityCalculator)
        case (l, r) =>
          new InternalNode(stats.impurityCalculator.predict, stats.impurity, stats.gain,
            l, r, split.get, stats.impurityCalculator)
      }
    } else {
      if (stats.valid) {
        new LeafNode(stats.impurityCalculator.predict, stats.impurity,
          stats.impurityCalculator)
      } else {
        // Here we want to keep same behavior with the old mllib.DecisionTreeModel
        new LeafNode(stats.impurityCalculator.predict, -1.0, stats.impurityCalculator)
      }
    }
  }

Member

It's only stored during training, though. I agree it could turn into a problem, but I wondered how many distinct predictions there could be? In the case of regression, maybe a lot, hm.

Doesn't this only prune cases where a node has two leaf-node children with the same prediction?

We should be able to prune much more than that. You could do so in multiple passes. The other way is roughly what's done here, to remember the predictions in each learning node.

Hm, yeah now I'm wondering about the regression case here.

Contributor

What other cases are there that I'm missing? For one thing, the current tests pass with this change.

This will prune a subtree of arbitrary depth that has a single distinct prediction in its leaf nodes.

Member

Imagine a full binary tree of depth 2 (7 nodes) with all identical predictions. It can be pruned down to one node, the root. But I think this logic would only prune the lowest level of nodes. Right? You're matching on nodes that have LeafNode children only.

Contributor

This is a recursive method. It essentially does this:

def prune(node: Node): Node = {
  val left = prune(node.left)
  val right = prune(node.right)
  if (left.isLeaf && right.isLeaf && left.prediction == right.prediction) {
    new LeafNode
  } else {
    new InternalNode
  }
}

It starts pruning at the bottom and then works its way up.
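(To make that concrete, here is a runnable sketch on a toy ADT; the case classes below stand in for Spark's LeafNode/InternalNode and are not the actual classes. The 7-node tree from the earlier comment collapses all the way to the root:)

    sealed trait SimpleNode
    case class SimpleLeaf(prediction: Double) extends SimpleNode
    case class SimpleInternal(left: SimpleNode, right: SimpleNode) extends SimpleNode

    def prune(node: SimpleNode): SimpleNode = node match {
      case leaf: SimpleLeaf => leaf
      case SimpleInternal(l, r) =>
        // prune the children first; collapse when both reduce to the same leaf
        (prune(l), prune(r)) match {
          case (pl: SimpleLeaf, pr: SimpleLeaf) if pl.prediction == pr.prediction => pl
          case (pl, pr) => SimpleInternal(pl, pr)
        }
    }

    val sevenNodeTree = SimpleInternal(
      SimpleInternal(SimpleLeaf(0.0), SimpleLeaf(0.0)),
      SimpleInternal(SimpleLeaf(0.0), SimpleLeaf(0.0)))
    assert(prune(sevenNodeTree) == SimpleLeaf(0.0))  // pruned to a single leaf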

Member

Oh right, you call toNode in here. My doubt there was that toNode gets called on each node to turn learning nodes into final ones. Wouldn't this recompute the subtrees over and over? I might have misread how this part works though.

Member Author

The variant proposed by @sethah looks complete to me, and it was indeed not necessary to store the predictions.

If I got @srowen's counterexample right, the new code would behave on it in this way:

(input tree)
1: InnerNode (not built yet)
--1.1: InnerNode (not built yet)
----1.1.1 (label x): LeafNode (not built yet)
----1.1.2 (label x): LeafNode (not built yet)
--1.2: InnerNode (not built yet)
----1.2.1 (label x): LeafNode (not built yet)
----1.2.2 (label x): LeafNode (not built yet)

toNode on 1.1.1 and 1.1.2 builds two LeafNodes, so when the recursion comes back to the context of 1.1, we have the following situation:

1: InnerNode (not built yet)
--1.1: InnerNode (not built yet)
----1.1.1 (label x): LeafNode
----1.1.2 (label x): LeafNode
--1.2: InnerNode (not built yet)
----1.2.1 (label x): LeafNode (not built yet)
----1.2.2 (label x): LeafNode (not built yet)

Therefore, the first case is triggered, and 1.1 becomes a LeafNode, giving us this situation:

1: InnerNode (not built yet)
--1.1 (label x): LeafNode
--1.2: InnerNode (not built yet)
----1.2.1 (label x): LeafNode (not built yet)
----1.2.2 (label x): LeafNode (not built yet)

Then, toNode is called on 1.2, and similarly to what happened to the subtree rooted at 1.1, we have a match on the first case, resulting in this situation:

1: InnerNode (not built yet)
--1.1 (label x): LeafNode
--1.2 (label x): LeafNode

Now toNode on 1.1 and 1.2 has built two LeafNodes, and the first case also matches in the context of node 1, which is finally translated into a single LeafNode, as expected.

Contributor

The recursion pattern is the same as before, since it still just calls toNode on left and right when they are defined.

Member

Got it, I see how it works now. toNode is called recursively once from the root, not serially on each node, and that's how it always worked. Right, this is much simpler then.

Contributor

@sethah sethah left a comment

I still haven't looked at the actual tests for this patch. Will get to it once the other suggestions are incorporated.

Contributor

make this a private val seed = 42 member of RandomForestSuite

Contributor

This isn't the purpose of the test, I'd get rid of it.

Member Author

That is true, but I need to call these internal methods to initialise the structure correctly, including rootNode.

I have removed the only lines that did not look necessary to me:
assert(topNode.isLeaf === false)
assert(topNode.stats === null)

What do you think?

Contributor

I'm a fan of just calling foreach like:

topNode.split.foreach { split =>
    assert(split.isInstanceOf[CategoricalSplit])
    assert(split.toOld.categories === Array(1.0))
}

Contributor

nit: there's an extra space in there. Also, you should either remove rootNodeSplit or use it everywhere

Contributor

Regarding this test - it fails now for a silly reason. Because of the data, the tree built winds up with a right node with equal labels of 1.0 and 2.0. It breaks the tie by prediction 1.0, which left node also predicts. You can modify the data generating method to:

  def generateCategoricalDataPointsForMulticlassForOrderedFeatures():
    Array[LabeledPoint] = {
    val arr = new Array[LabeledPoint](3000)
    for (i <- 0 until 3000) {
      if (i < 1001) {
        arr(i) = new LabeledPoint(2.0, Vectors.dense(2.0, 2.0))
      } else if (i < 2000) {
        arr(i) = new LabeledPoint(1.0, Vectors.dense(1.0, 2.0))
      } else {
        arr(i) = new LabeledPoint(1.0, Vectors.dense(2.0, 2.0))
      }
    }
    arr
  }

so that 2.0 will be predicted. I slightly prefer this, assuming all other tests pass (I checked some of the suites). The less stuff we can move around that is mostly unrelated to this change, the better.

Contributor

Here again. You can fix this by inverting the labels. Probably an easier fix than moving and re-writing the test.

Member Author

That's true, I have modified the input data for both tests as suggested, and "moved back" the two tests from .../ml/tree/impl/RandomForestSuite.scala to .../mllib/tree/DecisionTreeSuite.scala where they originally were. The whole suite of tests for mllib passes.

As a recap, 2 tests have been adapted by slightly changing the input data:

  • "Multiclass classification stump with 10-ary (ordered) categorical features"
  • "do not choose split that does not satisfy min instance per node requirements"

"Use soft prediction for binary classification with ordered categorical features" was present in two files:

  1. .../ml/classification/DecisionTreeClassifierSuite.scala
  2. .../ml/tree/impl/RandomForestSuite.scala

The one in 1. has been removed because it had to be adapted and it was redundant, while the one in 2. has been adapted following the same principle as other tests in that file (the "Avoid aggregation on the last level" test, for instance).

Contributor

These are fine, but I slightly prefer leaving stuff like this out. These aren't strictly style violations, and it distracts reviewers from the actual changes.

Member

That was at my request: because we made that change in some new code, it then made sense to improve this too, for consistency.

Member Author

Two tests previously moved here have now been moved back; there is still
"Use soft prediction for binary classification with ordered categorical features", to which I have applied @srowen's comment, so the consistency argument still holds (even if weakened a bit).

Contributor

This method returns Unit. It doesn't make sense to assign its output.

Contributor

I wouldn't mind just removing this change. What constitutes a leaf node is now fuzzy, and if you just inline it in the one place it's used, there is no confusion. At any rate, you don't need the parentheses after the method name.

Contributor

This comment seems out of place now. You might just say // when both children make the same prediction, collapse into single leaf or something similar below the first case statement.

Contributor

unused

Contributor

this is the type of thing that will puzzle someone down the line. I'm ok with it, though. 😝

Member Author

I agree, I have added a comment to explain the change:

[SPARK-3159] 1000 instead of 1001 to adapt "Multiclass classification stump with 10-ary (ordered) categorical features" test (different predictions prevent subtree pruning)

Could this be helpful?

Contributor

Ok, trying to understand these tests. From what I can tell, you've written a data generator that generates random points and, somewhat by chance, generates redundant tree nodes if the tree is not pruned. You're relying on the random seed to give you a tree which should have exactly 12 descendants after pruning.

I think these may be overly complicated. IMO, we just need to test the situation that causes this by creating a simple dataset that can improve the impurity by splitting, but which does not change the prediction. For example, for the Gini impurity you might have the following:

    val data = Array(
      LabeledPoint(0.0, Vectors.dense(1.0)),
      LabeledPoint(0.0, Vectors.dense(1.0)),
      LabeledPoint(0.0, Vectors.dense(1.0)),
      LabeledPoint(0.0, Vectors.dense(0.0)),
      LabeledPoint(0.0, Vectors.dense(0.0)),
      LabeledPoint(1.0, Vectors.dense(0.0))
    )
    val rdd = sc.parallelize(data)
    val strategy = new OldStrategy(algo = OldAlgo.Regression, impurity = Variance, maxDepth = 4,
      numClasses = 0, maxBins = 32)
    val tree = RandomForest.run(rdd, strategy, numTrees = 1, featureSubsetStrategy = "auto",
      seed = 42, instr = None).head
    assert(tree.rootNode.numDescendants === 0)

Before this patch, you'd get a tree with two descendants. Actually, I think it might be nice to introduce a prune parameter just for testing, then you can do the following:

    val data = Array(
      LabeledPoint(0.0, Vectors.dense(1.0)),
      LabeledPoint(0.0, Vectors.dense(1.0)),
      LabeledPoint(0.0, Vectors.dense(1.0)),
      LabeledPoint(0.0, Vectors.dense(0.0)),
      LabeledPoint(0.0, Vectors.dense(0.0)),
      LabeledPoint(1.0, Vectors.dense(0.0))
    )
    val rdd = sc.parallelize(data)
    val strategy = new OldStrategy(algo = OldAlgo.Regression, impurity = Variance, maxDepth = 4,
      numClasses = 0, maxBins = 32)
    val prunedTree = RandomForest.run(rdd, strategy, numTrees = 1, featureSubsetStrategy = "auto",
      seed = 42, prune = true, instr = None).head
    val unprunedTree = RandomForest.run(rdd, strategy, numTrees = 1, featureSubsetStrategy = "auto",
      seed = 42, prune = false, instr = None).head
    assert(prunedTree.rootNode.numDescendants === 0)
    assert(unprunedTree.rootNode.numDescendants === 2)

That's a sanity check that the patch has actually made the difference. WDYT?

Contributor

The tree tests are already so long and complicated that I think it's important to simplify where possible. These tests are useful as they are, but it probably won't be obvious to future devs why/how they work. Also, if we can avoid adding data generation code, that would be nice (there's already tons of code like that lying around the test suites).

Member Author

I agree on all points, and here are the modifications I have made to align with the comments:

  • I have simplified the tests as much as possible (minimal input data, resulting into minimal trees)
  • I have removed the data generator methods (now the input data are explicitly given)
  • Conditions to be tested are simple (just the number of nodes as you suggested), and they test the pruned vs unpruned version
  • I have added the extra debugging parameter prune (to RandomForest.run(...) and Node.toNode())
  • Now that we have prune, I have reverted the adaptation of "Use soft prediction for binary classification with ordered categorical features", because it is now sufficient to add prune = false to the invocation of RandomForest.run(...)

Contributor

@sethah sethah left a comment

Yes, this is looking better now. Thanks for all the updates! Left a few more small comments.

Contributor

If you just overload the method then you don't need to change the existing function calls.

def toNode: Node = toNode(prune = true)

Contributor

I think leaving this out is fine. The tests will fail if it gets changed.

Contributor

Why do we need to test binary and multiclass separately?

Member Author

Honestly, for the original (and more involved!) version of the "fix" it was less evident that the two cases would behave identically (due to explicit computation of the predictions for a whole subtree), but I understand that with the current version a single test can be sufficient.

If so, which of the two (binary/multiclass) classification tests would you keep?
Binary is simpler and more immediate to figure out (also by looking at the data points).
But I am afraid that dropping the multiclass test might raise doubts about test coverage.

Contributor

I think it's fine to just leave the binary classification test.

Member

srowen commented Feb 27, 2018

PS @sethah feel free to merge this one when you think it's ready.

Contributor

@sethah sethah left a comment

Just minor stuff.

Contributor

how about a comment here: prune: Boolean = true, // exposed for testing only, real trees are always pruned

Contributor

On second thought, I'm not sure the comment is useful since it just explains what the code does. I vote either no comment or we explain why this happens, i.e. you can improve impurity without changing the prediction.

Contributor

sethah commented Feb 28, 2018

Jenkins test this please.

Contributor

sethah commented Feb 28, 2018

@srowen Do we need you to trigger the tests? I'm not sure why they haven't been run...


SparkQA commented Feb 28, 2018

Test build #4136 has finished for PR 20632 at commit d02ff2a.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

Contributor

sethah commented Feb 28, 2018

@asolimando Can you change the title to include [ML] and also shorten it? Maybe just: Add decision tree pruning

@asolimando asolimando changed the title [SPARK-3159] added subtree pruning in the translation from LearningNode to Node, added unit tests for tree redundancy and adapted existing ones that were affected [SPARK-3159][ML] Add decision tree pruning Feb 28, 2018
@asolimando
Member Author

@sethah I have shortened the name as suggested and squashed the commits into a single one; let me know if that's ok!

Contributor

sethah commented Feb 28, 2018

Squashing makes it impossible to review the history of the code review, so I don't think it's a good idea. It's fine for now.

This LGTM. Let's see if @srowen or @jkbradley have any thoughts.

Contributor

sethah commented Feb 28, 2018

Jenkins retest this please.

Contributor

sethah commented Mar 2, 2018

Jenkins test this please.


val unprunedTree = RandomForest.run(rdd, strategy, numTrees = 1, featureSubsetStrategy = "auto",
seed = 42, instr = None, prune = false).head

Contributor

Would you mind adding a check in both tests to make sure that the count of all the leaf nodes sums to the total count (i.e. 6)? That way we make sure we don't lose information when merging the leaves. You can do it via leafNode.impurityStats.count.

Contributor

sethah commented Mar 2, 2018

@asolimando I thought of one more thing on the tests that I'd like to have. Other than that I think this is ready.

@srowen For some reason the tests won't run... Do you have any insight into this?

Member

srowen commented Mar 2, 2018

Jenkins add to whitelist

Member

srowen commented Mar 2, 2018

@sethah not sure -- I just use https://spark-prs.appspot.com/ in most cases to trigger.


SparkQA commented Mar 2, 2018

Test build #4138 has finished for PR 20632 at commit 624321a.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.


SparkQA commented Mar 2, 2018

Test build #87898 has finished for PR 20632 at commit 624321a.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@asolimando
Member Author

Thanks for the suggestion @sethah, I have updated the PR with the extra check (both tests).

}

@tailrec
private def getSumLeafCounters(nodes: List[Node], acc: Long = 0): Long =
Contributor

Need to enclose the function body in curly braces

Member Author

Sorry, I have added them
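(For reference, a sketch of how the braced helper might look, assuming it lives in the Spark test sources, where LeafNode.impurityStats is accessible:)

    import scala.annotation.tailrec
    import org.apache.spark.ml.tree.{InternalNode, LeafNode, Node}

    @tailrec
    private def getSumLeafCounters(nodes: List[Node], acc: Long = 0): Long = {
      nodes match {
        case Nil => acc
        case (leaf: LeafNode) :: tail =>
          // leaves contribute the instance counts of their impurity stats
          getSumLeafCounters(tail, acc + leaf.impurityStats.count)
        case (internal: InternalNode) :: tail =>
          // internal nodes push their children onto the work list
          getSumLeafCounters(internal.leftChild :: internal.rightChild :: tail, acc)
        case _ :: tail =>
          getSumLeafCounters(tail, acc)
      }
    }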


SparkQA commented Mar 2, 2018

Test build #87908 has finished for PR 20632 at commit 0183678.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.


SparkQA commented Mar 2, 2018

Test build #87910 has finished for PR 20632 at commit 3dfe86c.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@asfgit asfgit closed this in 9e26473 Mar 3, 2018
Contributor

sethah commented Mar 3, 2018

Merged with master. Thanks!

@asolimando
Member Author

Thanks to you @srowen and @sethah for all your feedback, which significantly improved the PR!

Contributor

tengpeng commented Apr 9, 2018

(A note to me & future readers) It seems this is actually for [SPARK-3159] Check for reducible DecisionTree, rather than SPARK-3155
Support DecisionTree pruning. The title is confusing to me because decision tree pruning is a well-defined term which refers to a different process.
