-
Notifications
You must be signed in to change notification settings - Fork 29.1k
[SPARK-14503][ML] spark.ml API for FPGrowth #15415
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 17 commits
710f257
b8d117d
bc9b830
5764e59
1ca4e45
c30dd7c
c0b5291
77e1f93
2f1a08c
388adaf
7eabd31
3730e1b
3a000df
e5574be
ed1f91e
2afdf48
63eaf08
0837b55
3273b76
d4d8ac2
fbac43f
57c9437
049e1a3
5d7881c
e141776
140885d
06e5c69
dfdf85d
453ae5b
5050bd3
d8e4884
bfcef4a
3d7ed0b
9940c47
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,113 @@ | ||
| /* | ||
| * Licensed to the Apache Software Foundation (ASF) under one or more | ||
| * contributor license agreements. See the NOTICE file distributed with | ||
| * this work for additional information regarding copyright ownership. | ||
| * The ASF licenses this file to You under the Apache License, Version 2.0 | ||
| * (the "License"); you may not use this file except in compliance with | ||
| * the License. You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
|
|
||
| package org.apache.spark.ml.fpm | ||
|
|
||
| import org.apache.spark.annotation.{Experimental, Since} | ||
| import org.apache.spark.ml.param.{DoubleParam, Param, ParamMap, Params} | ||
| import org.apache.spark.ml.util.Identifiable | ||
| import org.apache.spark.mllib.fpm.{AssociationRules => MLlibAssociationRules} | ||
| import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset | ||
| import org.apache.spark.sql.{DataFrame, Dataset, SparkSession} | ||
|
|
||
/**
 * :: Experimental ::
 *
 * Generates association rules from frequent itemsets ("items", "freq"). This method only generates
 * association rules which have a single item as the consequent.
 */
@Since("2.2.0")
@Experimental
class AssociationRules(override val uid: String) extends Params {

  @Since("2.2.0")
  def this() = this(Identifiable.randomUID("AssociationRules"))

  /**
   * Param for items column name. Items must be an array of Strings
   * (run() reads the column via getSeq[String]).
   * Default: "items"
   * @group param
   */
  final val itemsCol: Param[String] = new Param[String](this, "itemsCol", "items column name")

  /** @group getParam */
  @Since("2.2.0")
  final def getItemsCol: String = $(itemsCol)

  /** @group setParam */
  @Since("2.2.0")
  def setItemsCol(value: String): this.type = set(itemsCol, value)

  /**
   * Param for frequency column name. Data type should be Long.
   * Default: "freq"
   * @group param
   */
  final val freqCol: Param[String] = new Param[String](this, "freqCol", "frequency column name")

  /** @group getParam */
  @Since("2.2.0")
  final def getFreqCol: String = $(freqCol)

  /** @group setParam */
  @Since("2.2.0")
  def setFreqCol(value: String): this.type = set(freqCol, value)

  /**
   * Param for minimum confidence, range [0.0, 1.0].
   * Default: 0.8
   * @group param
   */
  final val minConfidence: DoubleParam = new DoubleParam(this, "minConfidence", "min confidence")

  /** @group getParam */
  @Since("2.2.0")
  final def getMinConfidence: Double = $(minConfidence)

  /** @group setParam */
  @Since("2.2.0")
  def setMinConfidence(value: Double): this.type = set(minConfidence, value)

  setDefault(itemsCol -> "items", freqCol -> "freq", minConfidence -> 0.8)

  /**
   * Computes the association rules with confidence above [[minConfidence]].
   * @param freqItemsets DataFrame containing frequent itemsets obtained from algorithms like
   *                     [[FPGrowth]]. Users can set itemsCol (frequent itemset, Array[String])
   *                     and freqCol (appearance count, Long) names in the DataFrame.
   * @return a DataFrame("antecedent", "consequent", "confidence") containing the association
   *         rules.
   */
  @Since("2.2.0")
  def run(freqItemsets: Dataset[_]): DataFrame = {
    val freqItemSetRdd = freqItemsets.select($(itemsCol), $(freqCol)).rdd
      .map(row => new FreqItemset(row.getSeq[String](0).toArray, row.getLong(1)))

    // Reuse the session that owns the input Dataset instead of building a new/default one,
    // so the result DataFrame is created in the caller's session.
    val sparkSession = freqItemsets.sparkSession
    import sparkSession.implicits._
    new MLlibAssociationRules()
      .setMinConfidence($(minConfidence))
      .run(freqItemSetRdd)
      .map(r => (r.antecedent, r.consequent, r.confidence))
      .toDF("antecedent", "consequent", "confidence")
  }

  override def copy(extra: ParamMap): AssociationRules = defaultCopy(extra)
}
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,232 @@ | ||
| /* | ||
| * Licensed to the Apache Software Foundation (ASF) under one or more | ||
| * contributor license agreements. See the NOTICE file distributed with | ||
| * this work for additional information regarding copyright ownership. | ||
| * The ASF licenses this file to You under the Apache License, Version 2.0 | ||
| * (the "License"); you may not use this file except in compliance with | ||
| * the License. You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
|
|
||
| package org.apache.spark.ml.fpm | ||
|
|
||
| import org.apache.hadoop.fs.Path | ||
|
|
||
| import org.apache.spark.annotation.{Experimental, Since} | ||
| import org.apache.spark.ml.{Estimator, Model} | ||
| import org.apache.spark.ml.param.{DoubleParam, ParamMap, Params} | ||
| import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasPredictionCol} | ||
| import org.apache.spark.ml.util._ | ||
| import org.apache.spark.mllib.fpm.{FPGrowth => MLlibFPGrowth, FPGrowthModel => MLlibFPGrowthModel} | ||
| import org.apache.spark.sql.{DataFrame, _} | ||
| import org.apache.spark.sql.functions._ | ||
| import org.apache.spark.sql.types.{ArrayType, StringType, StructType} | ||
|
|
||
/**
 * Params shared by [[FPGrowth]] and [[FPGrowthModel]].
 */
private[fpm] trait FPGrowthParams extends Params with HasFeaturesCol with HasPredictionCol {

  /**
   * Validates and transforms the input schema.
   * @param schema input schema
   * @return output schema
   */
  protected def validateAndTransformSchema(schema: StructType): StructType = {
    // The features column must hold non-null arrays of strings; the prediction
    // column, with the same element type, is appended to the schema.
    SchemaUtils.checkColumnType(schema, $(featuresCol), new ArrayType(StringType, false))
    SchemaUtils.appendColumn(schema, $(predictionCol), new ArrayType(StringType, false))
  }

  /**
   * The minimal support level of the frequent pattern.
   * Default: 0.3
   * @group param
   */
  @Since("2.2.0")
  val minSupport: DoubleParam = new DoubleParam(this, "minSupport",
    "the minimal support level of the frequent pattern (Default: 0.3)")

  /** @group getParam */
  @Since("2.2.0")
  def getMinSupport: Double = $(minSupport)
}
|
|
||
/**
 * :: Experimental ::
 * A parallel FP-growth algorithm to mine frequent itemsets.
 *
 * @see [[http://dx.doi.org/10.1145/1454008.1454027 Li et al., PFP: Parallel FP-Growth for Query
 *      Recommendation]]
 */
@Since("2.2.0")
@Experimental
class FPGrowth @Since("2.2.0") (
    @Since("2.2.0") override val uid: String)
  extends Estimator[FPGrowthModel] with FPGrowthParams with DefaultParamsWritable {

  @Since("2.2.0")
  def this() = this(Identifiable.randomUID("FPGrowth"))

  /** @group setParam */
  @Since("2.2.0")
  def setMinSupport(value: Double): this.type = set(minSupport, value)
  setDefault(minSupport -> 0.3)

  /** @group setParam */
  @Since("2.2.0")
  def setFeaturesCol(value: String): this.type = set(featuresCol, value)

  /** @group setParam */
  @Since("2.2.0")
  def setPredictionCol(value: String): this.type = set(predictionCol, value)

  @Since("2.2.0")
  override def fit(dataset: Dataset[_]): FPGrowthModel = {
    // Fail fast on an incompatible input schema before triggering any Spark job.
    transformSchema(dataset.schema, logging = true)
    val data = dataset.select($(featuresCol)).rdd.map(r => r.getSeq[String](0).toArray)
    val parentModel = new MLlibFPGrowth().setMinSupport($(minSupport)).run(data)
    // Record this Estimator as the model's parent, per the Estimator/Model contract.
    copyValues(new FPGrowthModel(uid, parentModel)).setParent(this)
  }

  @Since("2.2.0")
  override def transformSchema(schema: StructType): StructType = {
    validateAndTransformSchema(schema)
  }

  @Since("2.2.0")
  override def copy(extra: ParamMap): FPGrowth = defaultCopy(extra)
}
|
|
||
/** Companion providing [[DefaultParamsReadable]] persistence for [[FPGrowth]]. */
@Since("2.2.0")
object FPGrowth extends DefaultParamsReadable[FPGrowth] {

  @Since("2.2.0")
  override def load(path: String): FPGrowth = super.load(path)
}
|
|
||
/**
 * :: Experimental ::
 * Model fitted by FPGrowth.
 *
 * @param parentModel a model trained by spark.mllib.fpm.FPGrowth
 */
@Since("2.2.0")
@Experimental
class FPGrowthModel private[ml] (
    @Since("2.2.0") override val uid: String,
    private val parentModel: MLlibFPGrowthModel[_])
  extends Model[FPGrowthModel] with FPGrowthParams with MLWritable {

  /**
   * Minimal confidence for generating Association Rule.
   * Default: 0.8
   * @group param
   */
  @Since("2.2.0")
  val minConfidence: DoubleParam = new DoubleParam(this, "minConfidence",
    "minimal confidence for generating Association Rule (Default: 0.8)")
  setDefault(minConfidence -> 0.8)

  /** @group getParam */
  @Since("2.2.0")
  def getMinConfidence: Double = $(minConfidence)

  /** @group setParam */
  @Since("2.2.0")
  def setMinConfidence(value: Double): this.type = set(minConfidence, value)

  @Since("2.2.0")
  override def transform(dataset: Dataset[_]): DataFrame = {
    // Fail fast on an incompatible input schema before collecting rules to the driver.
    transformSchema(dataset.schema, logging = true)

    // NOTE(review): the rules are collected to the driver and captured in the UDF closure;
    // for models with very many rules this may be expensive — consider broadcasting.
    val associationRules = getAssociationRules.rdd.map(r =>
      (r.getSeq[String](0), r.getSeq[String](1))
    ).collect()

    // For each input itemset, emit the distinct consequents of every rule whose entire
    // antecedent is contained in the itemset.
    val predictUDF = udf((items: Seq[String]) => associationRules.flatMap( r =>
      if (r._1.forall(items.contains(_))) r._2 else Array.empty[String]
    ).distinct)
    dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
  }

  @Since("2.2.0")
  override def transformSchema(schema: StructType): StructType = {
    validateAndTransformSchema(schema)
  }

  @Since("2.2.0")
  override def copy(extra: ParamMap): FPGrowthModel = {
    val copied = new FPGrowthModel(uid, parentModel)
    // Preserve the parent Estimator reference, which copyValues does not carry over.
    copyValues(copied, extra).setParent(this.parent)
  }

  /**
   * Get frequent items fitted by FPGrowth, in the format of DataFrame("items", "freq").
   */
  @Since("2.2.0")
  def getFreqItems: DataFrame = {
    // The model holds no input Dataset, so fall back to the active/default session.
    val spark = SparkSession.builder().getOrCreate()
    import spark.implicits._
    parentModel.freqItemsets.map(f => (f.items.map(_.toString), f.freq))
      .toDF("items", "freq")
  }

  /**
   * Get association rules fitted by AssociationRules using the minConfidence, in the format
   * of DataFrame("antecedent", "consequent", "confidence").
   */
  @Since("2.2.0")
  def getAssociationRules: DataFrame = {
    val freqItems = getFreqItems
    val associationRules = new AssociationRules()
      .setMinConfidence($(minConfidence))
      .setItemsCol("items")
      .setFreqCol("freq")
    associationRules.run(freqItems)
  }

  @Since("2.2.0")
  override def write: MLWriter = new FPGrowthModel.FPGrowthModelWriter(this)
}
|
|
||
/** Companion providing [[MLReadable]] persistence for [[FPGrowthModel]]. */
@Since("2.2.0")
object FPGrowthModel extends MLReadable[FPGrowthModel] {

  @Since("2.2.0")
  override def read: MLReader[FPGrowthModel] = new FPGrowthModelReader

  @Since("2.2.0")
  override def load(path: String): FPGrowthModel = super.load(path)

  /** [[MLWriter]] instance for [[FPGrowthModel]] */
  private[FPGrowthModel]
  class FPGrowthModelWriter(instance: FPGrowthModel) extends MLWriter {

    override protected def saveImpl(path: String): Unit = {
      // Save metadata and Params.
      DefaultParamsWriter.saveMetadata(instance, path, sc)
      // Delegate the model data itself to the mllib implementation's save format.
      val dataPath = new Path(path, "data").toString
      instance.parentModel.save(sc, dataPath)
    }
  }

  private class FPGrowthModelReader extends MLReader[FPGrowthModel] {

    /** Checked against metadata when loading model */
    private val className = classOf[FPGrowthModel].getName

    override def load(path: String): FPGrowthModel = {
      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
      val dataPath = new Path(path, "data").toString
      val mllibModel = MLlibFPGrowthModel.load(sc, dataPath)
      val model = new FPGrowthModel(metadata.uid, mllibModel)
      DefaultParamsReader.getAndSetParams(model, metadata)
      model
    }
  }
}
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,46 @@ | ||
| /* | ||
| * Licensed to the Apache Software Foundation (ASF) under one or more | ||
| * contributor license agreements. See the NOTICE file distributed with | ||
| * this work for additional information regarding copyright ownership. | ||
| * The ASF licenses this file to You under the Apache License, Version 2.0 | ||
| * (the "License"); you may not use this file except in compliance with | ||
| * the License. You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
| package org.apache.spark.ml.fpm | ||
|
|
||
| import org.apache.spark.SparkFunSuite | ||
| import org.apache.spark.ml.util.DefaultReadWriteTest | ||
| import org.apache.spark.mllib.util.MLlibTestSparkContext | ||
|
|
||
class AssociationRulesSuite
  extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {

  test("association rules using String type") {
    // Three frequent itemsets in which "a" and "b" always co-occur.
    val inputItemsets = spark.createDataFrame(Seq(
      (Array("a", "b"), 3L),
      (Array("a"), 3L),
      (Array("b"), 3L))
    ).toDF("items", "freq")

    val actual = new AssociationRules()
      .setItemsCol("items")
      .setFreqCol("freq")
      .setMinConfidence(0.8)
      .run(inputItemsets)

    // Perfect co-occurrence yields confidence 1.0 in both directions.
    val expected = spark.createDataFrame(Seq(
      (Array("a"), Array("b"), 1.0),
      (Array("b"), Array("a"), 1.0))
    ).toDF("antecedent", "consequent", "confidence")
    assert(actual.sort("antecedent").rdd.collect() ===
      expected.sort("antecedent").rdd.collect())
  }
}
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
should be 2.2.0