@@ -44,6 +44,13 @@ import org.apache.spark.sql.{SQLContext, DataFrame}
* {{{
* ./bin/run-example ml.DecisionTreeExample [options]
* }}}
* Note that Decision Trees can take a large amount of memory. If the run-example command above
* fails, try running via spark-submit and specifying the amount of memory as at least 1g.
* For local mode, run
* {{{
* ./bin/spark-submit --class org.apache.spark.examples.ml.DecisionTreeExample --driver-memory 1g
* [examples JAR path] [options]
* }}}
* If you use it as a template to create your own app, please use `spark-submit` to submit your app.
*/
object DecisionTreeExample {
@@ -70,7 +77,7 @@ object DecisionTreeExample {
val parser = new OptionParser[Params]("DecisionTreeExample") {
head("DecisionTreeExample: an example decision tree app.")
opt[String]("algo")
.text(s"algorithm (Classification, Regression), default: ${defaultParams.algo}")
.text(s"algorithm (classification, regression), default: ${defaultParams.algo}")
.action((x, c) => c.copy(algo = x))
opt[Int]("maxDepth")
.text(s"max depth of the tree, default: ${defaultParams.maxDepth}")
@@ -222,18 +229,23 @@ object DecisionTreeExample {
// (1) For classification, re-index classes.
val labelColName = if (algo == "classification") "indexedLabel" else "label"
if (algo == "classification") {
val labelIndexer = new StringIndexer().setInputCol("labelString").setOutputCol(labelColName)
val labelIndexer = new StringIndexer()
.setInputCol("labelString")
.setOutputCol(labelColName)
stages += labelIndexer
}
// (2) Identify categorical features using VectorIndexer.
// Features with more than maxCategories values will be treated as continuous.
val featuresIndexer = new VectorIndexer().setInputCol("features")
.setOutputCol("indexedFeatures").setMaxCategories(10)
val featuresIndexer = new VectorIndexer()
.setInputCol("features")
.setOutputCol("indexedFeatures")
.setMaxCategories(10)
stages += featuresIndexer
// (3) Learn DecisionTree
val dt = algo match {
case "classification" =>
new DecisionTreeClassifier().setFeaturesCol("indexedFeatures")
new DecisionTreeClassifier()
.setFeaturesCol("indexedFeatures")
.setLabelCol(labelColName)
.setMaxDepth(params.maxDepth)
.setMaxBins(params.maxBins)
@@ -242,7 +254,8 @@
.setCacheNodeIds(params.cacheNodeIds)
.setCheckpointInterval(params.checkpointInterval)
case "regression" =>
new DecisionTreeRegressor().setFeaturesCol("indexedFeatures")
new DecisionTreeRegressor()
.setFeaturesCol("indexedFeatures")
.setLabelCol(labelColName)
.setMaxDepth(params.maxDepth)
.setMaxBins(params.maxBins)
114 changes: 114 additions & 0 deletions mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
@@ -0,0 +1,114 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.ml.feature

import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml._
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util.SchemaUtils
import org.apache.spark.mllib.feature
import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
import org.apache.spark.sql._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.StructType

/**
* Params for [[IDF]] and [[IDFModel]].
*/
private[feature] trait IDFBase extends Params with HasInputCol with HasOutputCol {

/**
* The minimum number of documents in which a term should appear.
* @group param
*/
final val minDocFreq = new IntParam(
this, "minDocFreq", "minimum of documents in which a term should appear for filtering")

Review comment (Author): Should add `setDefault(minDocFreq -> 0)`
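A minimal sketch of what that suggestion amounts to, using the same Spark 1.x `Params` API as the trait above; the trait name `MinDocFreqParam` is hypothetical and not part of the PR:

```scala
import org.apache.spark.ml.param.{IntParam, Params}

// Hypothetical trait name, for illustration only: the minDocFreq param with
// the reviewer's suggested default registered via setDefault.
private[feature] trait MinDocFreqParam extends Params {

  final val minDocFreq = new IntParam(
    this, "minDocFreq", "minimum number of documents in which a term should appear for filtering")

  // The reviewer's suggestion: an explicit default of 0, so getMinDocFreq
  // works even when the caller never sets the param.
  setDefault(minDocFreq -> 0)

  /** @group getParam */
  def getMinDocFreq: Int = getOrDefault(minDocFreq)
}
```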

/** @group getParam */
def getMinDocFreq: Int = getOrDefault(minDocFreq)

/** @group setParam */
def setMinDocFreq(value: Int): this.type = set(minDocFreq, value)

/**
* Validate and transform the input schema.
*/
protected def validateAndTransformSchema(schema: StructType, paramMap: ParamMap): StructType = {
val map = extractParamMap(paramMap)
SchemaUtils.checkColumnType(schema, map(inputCol), new VectorUDT)
SchemaUtils.appendColumn(schema, map(outputCol), new VectorUDT)
}
}

/**
* :: AlphaComponent ::
* Compute the Inverse Document Frequency (IDF) given a collection of documents.
*/
@AlphaComponent
final class IDF extends Estimator[IDFModel] with IDFBase {

/** @group setParam */
def setInputCol(value: String): this.type = set(inputCol, value)

/** @group setParam */
def setOutputCol(value: String): this.type = set(outputCol, value)

override def fit(dataset: DataFrame, paramMap: ParamMap): IDFModel = {
transformSchema(dataset.schema, paramMap, logging = true)
val map = extractParamMap(paramMap)
val input = dataset.select(map(inputCol)).map { case Row(v: Vector) => v }
val idf = new feature.IDF(getMinDocFreq).fit(input)
val model = new IDFModel(this, map, idf)
Params.inheritValues(map, this, model)
model
}

override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
validateAndTransformSchema(schema, paramMap)
}
}

/**
* :: AlphaComponent ::
* Model fitted by [[IDF]].
*/
@AlphaComponent
class IDFModel private[ml] (
override val parent: IDF,
override val fittingParamMap: ParamMap,
idfModel: feature.IDFModel)
extends Model[IDFModel] with IDFBase {

/** @group setParam */
def setInputCol(value: String): this.type = set(inputCol, value)

/** @group setParam */
def setOutputCol(value: String): this.type = set(outputCol, value)

override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
transformSchema(dataset.schema, paramMap, logging = true)
val map = extractParamMap(paramMap)
val idf = udf { vec: Vector => idfModel.transform(vec) }
Review comment (Author): `udf { idfModel.transform _ }`
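For context, a tiny plain-Scala illustration of the eta-expansion this suggestion relies on; the object and method names below are made up for the example and are not Spark APIs:

```scala
object EtaExpansionSketch {
  // Stand-in for a one-argument method such as idfModel.transform(v: Vector): Vector.
  def scale(x: Double): Double = x * 2.0

  def main(args: Array[String]): Unit = {
    val wrapped: Double => Double = x => scale(x) // explicit lambda, as written in the diff
    val eta: Double => Double = scale _           // eta-expanded method reference, as suggested
    assert(wrapped(3.0) == eta(3.0))
  }
}
```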

dataset.withColumn(map(outputCol), idf(col(map(inputCol))))
}

override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
validateAndTransformSchema(schema, paramMap)
}
}
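For reference, a minimal usage sketch of the new estimator, assuming an existing SparkContext `sc`; the toy data and column names below are illustrative and not taken from the PR:

```scala
import org.apache.spark.ml.feature.IDF
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.sql.SQLContext

val sqlContext = new SQLContext(sc)

// Toy term-frequency vectors in a "features" column (illustrative only).
val df = sqlContext.createDataFrame(Seq(
  Tuple1(Vectors.dense(0.0, 1.0, 2.0, 3.0)),
  Tuple1(Vectors.sparse(4, Array(1), Array(1.0)))
)).toDF("features")

// Fit the estimator, then apply the fitted model to rescale the features.
val idfModel = new IDF()
  .setInputCol("features")
  .setOutputCol("idfFeatures")
  .setMinDocFreq(1)
  .fit(df)

idfModel.transform(df).show()
```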
@@ -117,7 +117,7 @@ private[ml] trait DecisionTreeParams extends PredictorParams {
def setMaxDepth(value: Int): this.type = {
require(value >= 0, s"maxDepth parameter must be >= 0. Given bad value: $value")
set(maxDepth, value)
this.asInstanceOf[this.type]
this
}

/** @group getParam */
@@ -283,7 +283,7 @@ private[ml] trait TreeRegressorParams extends Params {
def getImpurity: String = getOrDefault(impurity)

/** Convert new impurity to old impurity. */
protected def getOldImpurity: OldImpurity = {
private[ml] def getOldImpurity: OldImpurity = {
getImpurity match {
case "variance" => OldVariance
case _ =>
7 changes: 4 additions & 3 deletions mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala
@@ -38,7 +38,7 @@ sealed trait Split extends Serializable {
private[tree] def toOld: OldSplit
}

private[ml] object Split {
private[tree] object Split {

def fromOld(oldSplit: OldSplit, categoricalFeatures: Map[Int, Int]): Split = {
oldSplit.featureType match {
@@ -58,7 +58,7 @@ private[ml] object Split {
* left. Otherwise, it goes right.
* @param numCategories Number of categories for this feature.
*/
final class CategoricalSplit(
final class CategoricalSplit private[ml] (
override val featureIndex: Int,
leftCategories: Array[Double],
private val numCategories: Int)
@@ -130,7 +130,8 @@ final class CategoricalSplit(
* @param threshold If the feature value is <= this threshold, then the split goes left.
* Otherwise, it goes right.
*/
final class ContinuousSplit(override val featureIndex: Int, val threshold: Double) extends Split {
final class ContinuousSplit private[ml] (override val featureIndex: Int, val threshold: Double)
extends Split {

override private[ml] def shouldGoLeft(features: Vector): Boolean = {
features(featureIndex) <= threshold
112 changes: 112 additions & 0 deletions mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala
@@ -0,0 +1,112 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.ml.feature

import org.scalatest.FunSuite

import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.sql.{DataFrame, Row, SQLContext}

class IDFSuite extends FunSuite with MLlibTestSparkContext {

@transient var sqlContext: SQLContext = _

override def beforeAll(): Unit = {
super.beforeAll()
sqlContext = new SQLContext(sc)
}

def getResultFromDF(result: DataFrame): Array[Vector] = {
result.select("idf_value").collect().map {
case Row(features: Vector) => features
}
}

def assertValues(lhs: Array[Vector], rhs: Array[Vector]): Unit = {
assert((lhs, rhs).zipped.forall { (vector1, vector2) =>
vector1 ~== vector2 absTol 1E-5
}, "The vector value is not correct after IDF.")
}

def getResultFromVector(dataSet: Array[Vector], model: Vector): Array[Vector] = {
dataSet.map {
case data: DenseVector =>
val res = data.toArray.zip(model.toArray).map { case (x, y) => x * y }
Vectors.dense(res)
case data: SparseVector =>
val res = data.indices.zip(data.values).map { case (id, value) =>
(id, value * model(id))
}
Vectors.sparse(data.size, res)
}
}

test("Normalization with default parameter") {
val numOfFeatures = 4
val data = Array(
Vectors.sparse(numOfFeatures, Array(1, 3), Array(1.0, 2.0)),
Vectors.dense(0.0, 1.0, 2.0, 3.0),
Vectors.sparse(numOfFeatures, Array(1), Array(1.0)))
val numOfData = data.length
val idf = Vectors.dense(Array(0, 3, 1, 2).map { x =>
math.log((numOfData + 1.0) / (x + 1.0))
})
val expected = getResultFromVector(data, idf)

val df = sqlContext.createDataFrame(data.zip(expected)).toDF("features", "expected")

val idfModel = new IDF()
.setInputCol("features")
.setOutputCol("idf_value")
.fit(df)

idfModel.transform(df).select("idf_value", "expected").collect().foreach {
case Row(x: Vector, y: Vector) =>
assert(x ~== y absTol 1e-5)
}
}

test("Normalization with setter") {
val numOfFeatures = 4
val data = Array(
Vectors.sparse(numOfFeatures, Array(1, 3), Array(1.0, 2.0)),
Vectors.dense(0.0, 1.0, 2.0, 3.0),
Vectors.sparse(numOfFeatures, Array(1), Array(1.0)))
val numOfData = data.length

val sqlContext = new SQLContext(sc)
import sqlContext.implicits._
val dataFrame = sc.parallelize(data, 2).map(Tuple1.apply).toDF("features")

val idfModel = new IDF()
.setInputCol("features")
.setOutputCol("idf_value")
.setMinDocFreq(1)
.fit(dataFrame)

val expectedModel = Vectors.dense(Array(0, 3, 1, 2).map { x =>
if (x > 0) math.log((numOfData + 1.0) / (x + 1.0)) else 0
})

assertValues(
getResultFromDF(idfModel.transform(dataFrame)),
getResultFromVector(data, expectedModel))
}
}
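As a reading aid, the expected values in both tests restate the smoothed IDF formula that the underlying mllib `feature.IDF` computes (the `math.log((numOfData + 1.0) / (x + 1.0))` expression above):

\mathrm{idf}(t) = \log\frac{m + 1}{\mathrm{df}(t) + 1}

where m is the number of documents and df(t) is the number of documents containing term t. With minDocFreq set, the second test expects an idf of 0 for the term whose document frequency falls below that threshold.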