-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-3042] [mllib] DecisionTree Filter top-down instead of bottom-up #1975
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 25 commits
a95bc22
511ec85
bcf874a
f61e9d2
3211f02
0f676e2
b2ed1f3
b914f3b
c1565a5
a87e08f
8464a6e
e66f1b1
d036089
430d782
356daba
26d10dd
2d2aaaf
6b5651e
5f2dec2
f40381c
797f68a
931a3a7
6a38f48
db0d773
ac0b9f8
3726d20
a0ed0da
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,97 @@ | ||
| /* | ||
| * Licensed to the Apache Software Foundation (ASF) under one or more | ||
| * contributor license agreements. See the NOTICE file distributed with | ||
| * this work for additional information regarding copyright ownership. | ||
| * The ASF licenses this file to You under the Apache License, Version 2.0 | ||
| * (the "License"); you may not use this file except in compliance with | ||
| * the License. You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
|
|
||
| package org.apache.spark.mllib.tree.impl | ||
|
|
||
| import scala.collection.mutable | ||
|
|
||
| import org.apache.spark.mllib.regression.LabeledPoint | ||
| import org.apache.spark.mllib.tree.configuration.Algo._ | ||
| import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ | ||
| import org.apache.spark.mllib.tree.configuration.Strategy | ||
| import org.apache.spark.mllib.tree.impurity.Impurity | ||
| import org.apache.spark.rdd.RDD | ||
|
|
||
|
|
||
| /** | ||
| * Learning and dataset metadata for DecisionTree. | ||
| * | ||
| * @param featureArity Map: categorical feature index --> arity. | ||
| * I.e., the feature takes values in {0, ..., arity - 1}. | ||
| */ | ||
| private[tree] class DTMetadata( | ||
| val numFeatures: Int, | ||
| val numExamples: Long, | ||
| val numClasses: Int, | ||
|
||
| val maxBins: Int, | ||
| val featureArity: Map[Int, Int], | ||
| val unorderedFeatures: Set[Int], | ||
| val impurity: Impurity, | ||
| val quantileStrategy: QuantileStrategy) extends Serializable { | ||
|
|
||
| def isUnordered(featureIndex: Int): Boolean = unorderedFeatures.contains(featureIndex) | ||
|
|
||
| def isClassification: Boolean = numClasses >= 2 | ||
|
|
||
| def isMulticlass: Boolean = numClasses > 2 | ||
|
|
||
| def isMulticlassWithCategoricalFeatures: Boolean = isMulticlass && (featureArity.size > 0) | ||
|
|
||
| def isCategorical(featureIndex: Int): Boolean = featureArity.contains(featureIndex) | ||
|
|
||
| def isContinuous(featureIndex: Int): Boolean = !featureArity.contains(featureIndex) | ||
|
|
||
| } | ||
|
|
||
| private[tree] object DTMetadata { | ||
|
|
||
| def buildMetadata(input: RDD[LabeledPoint], strategy: Strategy): DTMetadata = { | ||
|
|
||
| val numFeatures = input.take(1)(0).features.size | ||
| val numExamples = input.count() | ||
| val numClasses = strategy.algo match { | ||
| case Classification => strategy.numClassesForClassification | ||
| case Regression => 0 | ||
| } | ||
|
|
||
| val maxBins = math.min(strategy.maxBins, numExamples).toInt | ||
|
|
||
| val unorderedFeatures = new mutable.HashSet[Int]() | ||
| if (numClasses > 2) { | ||
| strategy.categoricalFeaturesInfo.foreach { case (f, k) => | ||
| val numUnorderedBins = (1 << k - 1) - 1 | ||
|
||
| if (numUnorderedBins < maxBins) { | ||
| unorderedFeatures.add(f) | ||
| } else { | ||
| // TODO: Allow this case, where we simply will know nothing about some categories? | ||
| require(k < maxBins, "maxBins should be greater than max categories " + | ||
| "in categorical features") | ||
| } | ||
| } | ||
| } else { | ||
| strategy.categoricalFeaturesInfo.foreach { case (f, k) => | ||
|
||
| require(k < maxBins, "maxBins should be greater than max categories " + | ||
| "in categorical features") | ||
| } | ||
| } | ||
|
|
||
| new DTMetadata(numFeatures, numExamples, numClasses, maxBins, | ||
| strategy.categoricalFeaturesInfo, unorderedFeatures.toSet, | ||
| strategy.impurity, strategy.quantileCalculationStrategy) | ||
| } | ||
|
|
||
| } | ||
This file was deleted.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Suggestion: DecisionTreeMetadata, TreeMetadata or Metadata as class name alternatives.