Skip to content
8 changes: 4 additions & 4 deletions docs/mllib-ensembles.md
Original file line number Diff line number Diff line change
Expand Up @@ -456,7 +456,7 @@ val (trainingData, testData) = (splits(0), splits(1))

// Train a GradientBoostedTrees model.
// The defaultParams for Classification use LogLoss by default.
val boostingStrategy = BoostingStrategy.defaultParams("Classification")
val boostingStrategy = BoostingStrategy.defaultParams(Algo.Classification)
boostingStrategy.numIterations = 3 // Note: Use more iterations in practice.
boostingStrategy.treeStrategy.numClassesForClassification = 2
boostingStrategy.treeStrategy.maxDepth = 5
Expand Down Expand Up @@ -506,7 +506,7 @@ JavaRDD<LabeledPoint> testData = splits[1];

// Train a GradientBoostedTrees model.
// The defaultParams for Classification use LogLoss by default.
BoostingStrategy boostingStrategy = BoostingStrategy.defaultParams("Classification");
BoostingStrategy boostingStrategy = BoostingStrategy.defaultParams(Algo.Classification);

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This doesn't work in Java. You need at least "Algo.Classification()`.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems that I am a little careless about the docs of Java version. I will update it.

boostingStrategy.setNumIterations(3); // Note: Use more iterations in practice.
boostingStrategy.getTreeStrategy().setNumClassesForClassification(2);
boostingStrategy.getTreeStrategy().setMaxDepth(5);
Expand Down Expand Up @@ -564,7 +564,7 @@ val (trainingData, testData) = (splits(0), splits(1))

// Train a GradientBoostedTrees model.
// The defaultParams for Regression use SquaredError by default.
val boostingStrategy = BoostingStrategy.defaultParams("Regression")
val boostingStrategy = BoostingStrategy.defaultParams(Algo.Regression)
boostingStrategy.numIterations = 3 // Note: Use more iterations in practice.
boostingStrategy.treeStrategy.maxDepth = 5
// Empty categoricalFeaturesInfo indicates all features are continuous.
Expand Down Expand Up @@ -614,7 +614,7 @@ JavaRDD<LabeledPoint> testData = splits[1];

// Train a GradientBoostedTrees model.
// The defaultParams for Regression use SquaredError by default.
BoostingStrategy boostingStrategy = BoostingStrategy.defaultParams("Regression");
BoostingStrategy boostingStrategy = BoostingStrategy.defaultParams(Algo.Regression);
boostingStrategy.setNumIterations(3); // Note: Use more iterations in practice.
boostingStrategy.getTreeStrategy().setMaxDepth(5);
// Empty categoricalFeaturesInfo indicates all features are continuous.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,4 +88,25 @@ object BoostingStrategy {
throw new IllegalArgumentException(s"$algo is not supported by boosting.")
}
}

/**
* Returns default configuration for the boosting algorithm
* @param algo Learning goal. Supported:
* [[org.apache.spark.mllib.tree.configuration.Algo.Classification]],
* [[org.apache.spark.mllib.tree.configuration.Algo.Regression]]
* @return Configuration for boosting algorithm
*/
def defaultParams(algo: Algo): BoostingStrategy = {

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we update the defaultParams(algoStr: String) to use the implementation here instead of maintaining two copies.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So you mean that remove the defaultParams(algoStr: String) and with the defaultParams(algo: Algo) instead?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or change the defaultParams(algo: String) to defaultParams(algoStr: String)?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, I thought the param name is algoStr instead of algo. We don't need to rename the param. I was suggesting that we can keep defaultParams(algo: String) but inside it call defaultParams(algo: Algo) so we don't need to maintain two copies of the implementation.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Got it. Please check the new commits. Thanks.

val treeStragtegy = Strategy.defaultStategy(algo)
treeStragtegy.maxDepth = 3
algo match {
case Algo.Classification =>
treeStragtegy.numClasses = 2
new BoostingStrategy(treeStragtegy, LogLoss)
case Algo.Regression =>
new BoostingStrategy(treeStragtegy, SquaredError)
case _ =>
throw new IllegalArgumentException(s"$algo is not supported by boosting.")
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -181,4 +181,17 @@ object Strategy {
new Strategy(algo = Regression, impurity = Variance, maxDepth = 10,
numClasses = 0)
}

/**
* Construct a default set of parameters for [[org.apache.spark.mllib.tree.DecisionTree]]
* @param algo Algo.Classification or Algo.Regression
*/
def defaultStategy(algo: Algo): Strategy = algo match {
case Algo.Classification =>
new Strategy(algo = Classification, impurity = Gini, maxDepth = 10,
numClasses = 2)
case Algo.Regression =>
new Strategy(algo = Regression, impurity = Variance, maxDepth = 10,
numClasses = 0)
}
}