Commits (34)
710f257
initial fpm
YY-OnCall Jun 15, 2016
b8d117d
add ass rule
YY-OnCall Jun 16, 2016
bc9b830
Merge remote-tracking branch 'upstream/master' into mlfpm
hhbyyh Jul 10, 2016
5764e59
Merge remote-tracking branch 'upstream/master' into mlfpm
hhbyyh Jul 11, 2016
1ca4e45
Merge remote-tracking branch 'upstream/master' into mlfpm
hhbyyh Jul 11, 2016
c30dd7c
refine code
hhbyyh Jul 11, 2016
c0b5291
Merge remote-tracking branch 'upstream/master' into mlfpm
hhbyyh Oct 8, 2016
77e1f93
Merge remote-tracking branch 'upstream/master' into mlfpm
hhbyyh Oct 10, 2016
2f1a08c
add ut
hhbyyh Oct 10, 2016
388adaf
Merge remote-tracking branch 'upstream/master' into mlfpm
YY-OnCall Oct 10, 2016
7eabd31
Merge remote-tracking branch 'upstream/master' into mlfpm
hhbyyh Oct 14, 2016
3730e1b
Merge remote-tracking branch 'upstream/master' into mlfpm
YY-OnCall Oct 19, 2016
3a000df
Merge remote-tracking branch 'upstream/master' into mlfpm
hhbyyh Nov 2, 2016
e5574be
refine and add unit test
hhbyyh Nov 2, 2016
ed1f91e
Merge remote-tracking branch 'upstream/master' into mlfpm
YY-OnCall Dec 16, 2016
2afdf48
Merge branch 'mlfpm' of https://github.com/hhbyyh/spark into mlfpm
YY-OnCall Dec 16, 2016
63eaf08
fpgrowth version
YY-OnCall Dec 16, 2016
0837b55
Merge remote-tracking branch 'upstream/master' into mlfpm
YY-OnCall Jan 19, 2017
3273b76
add numPartitions and change version
YY-OnCall Jan 19, 2017
d4d8ac2
Merge remote-tracking branch 'upstream/master' into mlfpm
YY-OnCall Jan 24, 2017
fbac43f
Merge remote-tracking branch 'upstream/master' into mlfpm
YY-OnCall Jan 31, 2017
57c9437
use association rules to transform
YY-OnCall Jan 31, 2017
049e1a3
make AssociationRules private and use join
YY-OnCall Feb 8, 2017
5d7881c
Merge remote-tracking branch 'upstream/master' into mlfpm
YY-OnCall Feb 15, 2017
e141776
make assocationrules static and support generic
YY-OnCall Feb 15, 2017
140885d
Merge remote-tracking branch 'upstream/master' into mlfpm
YY-OnCall Feb 19, 2017
06e5c69
Merge remote-tracking branch 'upstream/master' into mlfpm
YY-OnCall Feb 20, 2017
dfdf85d
transform optimize and code refine
YY-OnCall Feb 20, 2017
453ae5b
Merge remote-tracking branch 'upstream/master' into mlfpm
YY-OnCall Feb 20, 2017
5050bd3
Merge remote-tracking branch 'upstream/master' into mlfpm
YY-OnCall Feb 21, 2017
d8e4884
add numpartitions
YY-OnCall Feb 23, 2017
bfcef4a
remove sort
YY-OnCall Feb 24, 2017
3d7ed0b
back to broadcast
YY-OnCall Feb 25, 2017
9940c47
ut minor rename
YY-OnCall Feb 25, 2017
113 changes: 113 additions & 0 deletions mllib/src/main/scala/org/apache/spark/ml/fpm/AssociationRules.scala
@@ -0,0 +1,113 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.ml.fpm

import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.param.{DoubleParam, Param, ParamMap, Params}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.mllib.fpm.{AssociationRules => MLlibAssociationRules}
import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset
import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}

/**
* :: Experimental ::
*
* Generates association rules from frequent itemsets ("items", "freq"). This method only generates
* association rules which have a single item as the consequent.
*/
@Since("2.1.0")
Contributor: should be 2.2.0

@Experimental
class AssociationRules(override val uid: String) extends Params {
Contributor: Since AssociationRules transforms a DataFrame of freqItemsets into a DataFrame of rules, can it be a subclass of Transformer?

Contributor (author): freqItemsets and rules do not have a one-to-one mapping, so this would probably violate the principles of Transformer.


@Since("2.1.0")
def this() = this(Identifiable.randomUID("AssociationRules"))

/**
* Param for items column name. Items must be array of Integers.
* Default: "items"
* @group param
*/
final val itemsCol: Param[String] = new Param[String](this, "itemsCol", "items column name")


/** @group getParam */
@Since("2.1.0")
final def getItemsCol: String = $(itemsCol)

/** @group setParam */
@Since("2.1.0")
def setItemsCol(value: String): this.type = set(itemsCol, value)

/**
* Param for frequency column name. Data type should be Long.
* Default: "freq"
* @group param
*/
final val freqCol: Param[String] = new Param[String](this, "freqCol", "frequency column name")


/** @group getParam */
@Since("2.1.0")
final def getFreqCol: String = $(freqCol)

/** @group setParam */
@Since("2.1.0")
def setFreqCol(value: String): this.type = set(freqCol, value)

/**
* Param for minimum confidence, range [0.0, 1.0].
* @group param
*/
final val minConfidence: DoubleParam = new DoubleParam(this, "minConfidence", "min confidence")
Contributor: there should be a ParamValidators.inRange(...)
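A minimal sketch of the declaration the reviewer is likely asking for (illustrative only; it assumes ParamValidators is added to the existing org.apache.spark.ml.param import):

final val minConfidence: DoubleParam = new DoubleParam(this, "minConfidence",
  "minimal confidence for generating association rules, must be in range [0, 1]",
  ParamValidators.inRange(0.0, 1.0)) // rejects values outside [0, 1] when the param is set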


/** @group getParam */
@Since("2.1.0")
final def getMinConfidence: Double = $(minConfidence)

/** @group setParam */
@Since("2.1.0")
def setMinConfidence(value: Double): this.type = set(minConfidence, value)

setDefault(itemsCol -> "items", freqCol -> "freq", minConfidence -> 0.8)

/**
* Computes the association rules with confidence above [[minConfidence]].
* @param freqItemsets DataFrame containing frequent itemset obtained from algorithms like
* [[FPGrowth]]. Users can set itemsCol (frequent itemSet, Array[String])
Contributor: Array[String] conflicts with Array[Int] in https://github.com/apache/spark/pull/15415/files#diff-0a641720038f962d333ef38402a02207R41. Also, is there some way to support general types?

* and freqCol (appearance count, Long) names in the DataFrame.
* @return a DataFrame("antecedent", "consequent", "confidence") containing the association
Member: returns a model, not DataFrame

* rules.
*
*/
@Since("2.1.0")
def run(freqItemsets: Dataset[_]): DataFrame = {
@zhengruifeng (Contributor, Jan 5, 2017): If inheriting Transformer, this should be override def transform(dataset: Dataset[_]): DataFrame

val freqItemSetRdd = freqItemsets.select($(itemsCol), $(freqCol)).rdd
.map(row => new FreqItemset(row.getSeq[String](0).toArray, row.getLong(1)))

val sqlContext = SparkSession.builder().getOrCreate()
Contributor: Since val sqlContext is of type SparkSession, what about renaming it to spark?

import sqlContext.implicits._
new MLlibAssociationRules()
.setMinConfidence($(minConfidence))
.run(freqItemSetRdd)
.map(r => (r.antecedent, r.consequent, r.confidence))
.toDF("antecedent", "consequent", "confidence")
}

override def copy(extra: ParamMap): AssociationRules = defaultCopy(extra)

}
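For orientation, a hedged usage sketch of the class above, mirroring the unit test at the end of this diff (spark is an active SparkSession):

val freqItems = spark.createDataFrame(Seq(
  (Array("a", "b"), 3L),
  (Array("a"), 3L),
  (Array("b"), 3L))
).toDF("items", "freq")

val rules = new AssociationRules()
  .setMinConfidence(0.8)
  .setItemsCol("items")
  .setFreqCol("freq")
  .run(freqItems)  // DataFrame("antecedent", "consequent", "confidence")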
232 changes: 232 additions & 0 deletions mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala
@@ -0,0 +1,232 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.ml.fpm

import org.apache.hadoop.fs.Path

import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.param.{DoubleParam, ParamMap, Params}
import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasPredictionCol}
import org.apache.spark.ml.util._
import org.apache.spark.mllib.fpm.{FPGrowth => MLlibFPGrowth, FPGrowthModel => MLlibFPGrowthModel}
import org.apache.spark.sql.{DataFrame, _}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{ArrayType, StringType, StructType}

/**
* Common params for FPGrowth and FPGrowthModel
*/
private[fpm] trait FPGrowthParams extends Params with HasFeaturesCol with HasPredictionCol {

/**
* Validates and transforms the input schema.
* @param schema input schema
* @return output schema
*/
protected def validateAndTransformSchema(schema: StructType): StructType = {
Member: Since annotation

SchemaUtils.checkColumnType(schema, $(featuresCol), new ArrayType(StringType, false))
SchemaUtils.appendColumn(schema, $(predictionCol), new ArrayType(StringType, false))
}

/**
* the minimal support level of the frequent pattern
* Default: 0.3
* @group param
*/
@Since("2.2.0")
val minSupport: DoubleParam = new DoubleParam(this, "minSupport",
"the minimal support level of the frequent pattern (Default: 0.3)")
Contributor: also need a ParamValidator here


/** @group getParam */
@Since("2.2.0")
def getMinSupport: Double = $(minSupport)

Contributor: MLlib's FPGrowth has a param numPartitions; will it be included here?

}
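A hedged sketch of what the two review comments above suggest for this trait: a range check on minSupport and a numPartitions param mirroring mllib.fpm.FPGrowth.setNumPartitions. Names and doc strings are illustrative, and IntParam plus ParamValidators would need to be imported:

@Since("2.2.0")
val minSupport: DoubleParam = new DoubleParam(this, "minSupport",
  "the minimal support level of the frequent pattern, must be in range [0, 1]",
  ParamValidators.inRange(0.0, 1.0))

@Since("2.2.0")
val numPartitions: IntParam = new IntParam(this, "numPartitions",
  "number of partitions (at least 1) used by parallel FP-growth",
  ParamValidators.gtEq[Int](1))

/** @group getParam */
@Since("2.2.0")
def getNumPartitions: Int = $(numPartitions)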

/**
* :: Experimental ::
* A parallel FP-growth algorithm to mine frequent itemsets.
*
Member: Here or elsewhere, comment that null featuresCol values are ignored during fit() and are treated as empty sets during transform().

* @see [[http://dx.doi.org/10.1145/1454008.1454027 Li et al., PFP: Parallel FP-Growth for Query
Member: Could you please go ahead and copy the relevant text and links from the Scaladoc string for mllib.fpm.FPGrowth?

* Recommendation]]
*/
@Since("2.2.0")
@Experimental
class FPGrowth @Since("2.2.0") (
@Since("2.2.0") override val uid: String)
extends Estimator[FPGrowthModel] with FPGrowthParams with DefaultParamsWritable {

@Since("2.2.0")
def this() = this(Identifiable.randomUID("FPGrowth"))
Member: Use lowercase name to match other algs: "fpgrowth"


/** @group setParam */
@Since("2.2.0")
def setMinSupport(value: Double): this.type = set(minSupport, value)
setDefault(minSupport -> 0.3)

/** @group setParam */
@Since("2.2.0")
def setFeaturesCol(value: String): this.type = set(featuresCol, value)
Contributor: I prefer using setInputCol and inputCol instead of this.

Contributor (author): Thanks. Let's collect more feedback about it.


/** @group setParam */
@Since("2.2.0")
def setPredictionCol(value: String): this.type = set(predictionCol, value)
Contributor: ditto, setOutputCol


def fit(dataset: Dataset[_]): FPGrowthModel = {
Member: override

val data = dataset.select($(featuresCol)).rdd.map(r => r.getSeq[String](0).toArray)
val parentModel = new MLlibFPGrowth().setMinSupport($(minSupport)).run(data)
copyValues(new FPGrowthModel(uid, parentModel))
Member: setParent

}

@Since("2.2.0")
override def transformSchema(schema: StructType): StructType = {
validateAndTransformSchema(schema)
}

override def copy(extra: ParamMap): FPGrowth = defaultCopy(extra)
Member: Since annotation

}


@Since("2.2.0")
object FPGrowth extends DefaultParamsReadable[FPGrowth] {

@Since("2.2.0")
override def load(path: String): FPGrowth = super.load(path)
}

/**
* :: Experimental ::
* Model fitted by FPGrowth.
*
* @param parentModel a model trained by spark.mllib.fpm.FPGrowth
*/
@Since("2.2.0")
@Experimental
class FPGrowthModel private[ml] (
@Since("2.2.0") override val uid: String,
private val parentModel: MLlibFPGrowthModel[_])
extends Model[FPGrowthModel] with FPGrowthParams with MLWritable {

/**
* minimal confidence for generating Association Rule
* Default: 0.8
* @group param
*/
@Since("2.2.0")
val minConfidence: DoubleParam = new DoubleParam(this, "minConfidence",
"minimal confidence for generating Association Rule (Default: 0.8)")
Contributor: ParamValidators.inRange(0, 1, ...)

setDefault(minConfidence -> 0.8)

/** @group getParam */
@Since("2.2.0")
def getMinConfidence: Double = $(minConfidence)

/** @group setParam */
@Since("2.2.0")
def setMinConfidence(value: Double): this.type = set(minConfidence, value)

@Since("2.2.0")
override def transform(dataset: Dataset[_]): DataFrame = {
Contributor: Is there some discussion about the behavior of transform here? It seems like a new feature.

Contributor (author): Indeed, it's a new feature that makes FPGrowthModel a Transformer. I've added it to the PR description to help draw more attention.

val associationRules = getAssociationRules.rdd.map(r =>
Contributor: Are we worried about recomputing the association rules every time we do a transform?

Contributor (author): Good point, thanks @aray. I've tried to leverage a lazy val to avoid unnecessary computation.

(r.getSeq[String](0), r.getSeq[String](1))
).collect()
Contributor: Should we guard against this being too large?


// For each rule, examine the input items and summarize the consequents
val predictUDF = udf((items: Seq[String]) => associationRules.flatMap( r =>
if (r._1.forall(items.contains(_))) r._2 else Array.empty[String]
Contributor: So if there is a rule {a, b} -> c and I pass it the set {a, b, c}, it's going to "predict" c? Also, should we include the confidence number in the result?

).distinct)
dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
Contributor: Why are we including featuresCol as a new column here?

}
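A hedged sketch of the caching idea discussed in the transform thread above (not necessarily the exact change the author made): compute and collect the rules once per model instance, then reuse them across transform() calls.

// Illustrative only: cache the collected (antecedent, consequent) pairs lazily
// so repeated transform() calls do not regenerate the association rules.
@transient private lazy val cachedRules: Array[(Seq[String], Seq[String])] =
  getAssociationRules.rdd
    .map(r => (r.getSeq[String](0), r.getSeq[String](1)))
    .collect()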

@Since("2.2.0")
override def transformSchema(schema: StructType): StructType = {
validateAndTransformSchema(schema)
}

@Since("2.2.0")
override def copy(extra: ParamMap): FPGrowthModel = {
val copied = new FPGrowthModel(uid, parentModel)
copyValues(copied, extra)
}

/**
* Get frequent items fitted by FPGrowth, in the format of DataFrame("items", "freq")
*/
@Since("2.2.0")
def getFreqItems: DataFrame = {
val sqlContext = SparkSession.builder().getOrCreate()
import sqlContext.implicits._
parentModel.freqItemsets.map(f => (f.items.map(_.toString), f.freq))
.toDF("items", "freq")
}

/**
* Get association rules fitted by AssociationRules using the minConfidence, in the format
* of DataFrame("antecedent", "consequent", "confidence")
*/
@Since("2.2.0")
def getAssociationRules: DataFrame = {
val freqItems = getFreqItems

val associationRules = new AssociationRules()
.setMinConfidence($(minConfidence))
.setItemsCol("items")
.setFreqCol("freq")
associationRules.run(freqItems)
}

@Since("2.2.0")
override def write: MLWriter = new FPGrowthModel.FPGrowthModelWriter(this)
}
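To make the estimator/model flow above concrete, a hedged end-to-end usage sketch (column names and toy data are illustrative):

val dataset = spark.createDataFrame(Seq(
  (0, Array("a", "b", "c")),
  (1, Array("a", "b")),
  (2, Array("a")))
).toDF("id", "features")

val model = new FPGrowth()
  .setMinSupport(0.5)
  .setFeaturesCol("features")
  .setPredictionCol("prediction")
  .fit(dataset)

model.getFreqItems.show()                               // DataFrame("items", "freq")
model.setMinConfidence(0.8).getAssociationRules.show()  // DataFrame("antecedent", "consequent", "confidence")
model.transform(dataset).show()                         // appends the "prediction" column of consequents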

object FPGrowthModel extends MLReadable[FPGrowthModel] {
Member: Add Since annotation to object

@Since("2.2.0")
Member: newline

override def read: MLReader[FPGrowthModel] = new FPGrowthModelReader

@Since("2.2.0")
override def load(path: String): FPGrowthModel = super.load(path)

/** [[MLWriter]] instance for [[FPGrowthModel]] */
private[FPGrowthModel]
class FPGrowthModelWriter(instance: FPGrowthModel) extends MLWriter {

override protected def saveImpl(path: String): Unit = {
// Save metadata and Params
Member: remove this doc line

DefaultParamsWriter.saveMetadata(instance, path, sc)
val dataPath = new Path(path, "data").toString
instance.parentModel.save(sc, dataPath)
}
}

private class FPGrowthModelReader extends MLReader[FPGrowthModel] {

/** Checked against metadata when loading model */
private val className = classOf[FPGrowthModel].getName

override def load(path: String): FPGrowthModel = {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
val dataPath = new Path(path, "data").toString
val mllibModel = MLlibFPGrowthModel.load(sc, dataPath)
val model = new FPGrowthModel(metadata.uid, mllibModel)
DefaultParamsReader.getAndSetParams(model, metadata)
model
}
}
}
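A brief hedged sketch of how the writer/reader defined above would be exercised (the path is illustrative):

model.write.overwrite().save("/tmp/fpgrowth-model")
val restored = FPGrowthModel.load("/tmp/fpgrowth-model")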

46 changes: 46 additions & 0 deletions mllib/src/test/scala/org/apache/spark/ml/fpm/AssociationRulesSuite.scala
@@ -0,0 +1,46 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.ml.fpm

import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.mllib.util.MLlibTestSparkContext

class AssociationRulesSuite
extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {

test("association rules using String type") {
val freqItems = spark.createDataFrame(Seq(
(Array("a", "b"), 3L),
(Array("a"), 3L),
(Array("b"), 3L))
).toDF("items", "freq")

val associationRules = new AssociationRules()
.setMinConfidence(0.8)
.setItemsCol("items")
.setFreqCol("freq")
val rules = associationRules.run(freqItems)

val expectedRules = spark.createDataFrame(Seq(
(Array("a"), Array("b"), 1.0),
(Array("b"), Array("a"), 1.0))
).toDF("antecedent", "consequent", "confidence")
assert(rules.sort("antecedent").rdd.collect() ===
expectedRules.sort("antecedent").rdd.collect())
}
}