[SPARK-13969][ML] Add FeatureHasher transformer #18513
New file: FeatureHasher.scala (+196 lines)

```scala
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.ml.feature

import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.attribute.AttributeGroup
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.param.{IntParam, ParamMap, ParamValidators}
import org.apache.spark.ml.param.shared.{HasInputCols, HasOutputCol}
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, SchemaUtils}
import org.apache.spark.mllib.feature.{HashingTF => OldHashingTF}
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
import org.apache.spark.util.collection.OpenHashMap

/**
 * Feature hashing projects a set of categorical or numerical features into a feature vector of
 * specified dimension (typically substantially smaller than that of the original feature
 * space). This is done using the hashing trick (https://en.wikipedia.org/wiki/Feature_hashing)
 * to map features to indices in the feature vector.
 *
 * The [[FeatureHasher]] transformer operates on multiple columns. Each column may contain either
 * numeric or categorical features. Behavior and handling of column data types is as follows:
 *  - Numeric columns: For numeric features, the hash value of the column name is used to map the
 *                     feature value to its index in the feature vector. Numeric features are
 *                     never treated as categorical, even when they are integers. You must
 *                     explicitly convert numeric columns containing categorical features to
 *                     strings first.
 *  - String columns: For categorical features, the hash value of the string "column_name=value"
 *                    is used to map to the vector index, with an indicator value of `1.0`.
 *                    Thus, categorical features are "one-hot" encoded
 *                    (similarly to using [[OneHotEncoder]] with `dropLast=false`).
 *  - Boolean columns: Boolean values are treated in the same way as string columns. That is,
 *                     boolean features are represented as "column_name=true" or
 *                     "column_name=false", with an indicator value of `1.0`.
 *
 * Null (missing) values are ignored (implicitly zero in the resulting feature vector).
 *
 * Since a simple modulo is used to transform the hash function to a vector index,
 * it is advisable to use a power of two as the numFeatures parameter;
 * otherwise the features will not be mapped evenly to the vector indices.
 *
 * {{{
 *   val df = Seq(
 *     (2.0, true, "1", "foo"),
 *     (3.0, false, "2", "bar")
 *   ).toDF("real", "bool", "stringNum", "string")
 *
 *   val hasher = new FeatureHasher()
 *     .setInputCols("real", "bool", "stringNum", "string")
 *     .setOutputCol("features")
 *
 *   hasher.transform(df).show()
 *
 *   +----+-----+---------+------+--------------------+
 *   |real| bool|stringNum|string|            features|
 *   +----+-----+---------+------+--------------------+
 *   | 2.0| true|        1|   foo|(262144,[51871,63...|
 *   | 3.0|false|        2|   bar|(262144,[6031,806...|
 *   +----+-----+---------+------+--------------------+
 * }}}
 */
@Experimental
@Since("2.3.0")
class FeatureHasher(@Since("2.3.0") override val uid: String) extends Transformer
  with HasInputCols with HasOutputCol with DefaultParamsWritable {

  @Since("2.3.0")
  def this() = this(Identifiable.randomUID("featureHasher"))

  /**
   * Number of features. Should be greater than 0.
   * (default = 2^18^)
   * @group param
   */
  @Since("2.3.0")
  val numFeatures = new IntParam(this, "numFeatures", "number of features (> 0)",
    ParamValidators.gt(0))

  setDefault(numFeatures -> (1 << 18))

  /** @group getParam */
  @Since("2.3.0")
  def getNumFeatures: Int = $(numFeatures)

  /** @group setParam */
  @Since("2.3.0")
  def setNumFeatures(value: Int): this.type = set(numFeatures, value)

  /** @group setParam */
  @Since("2.3.0")
  def setInputCols(values: String*): this.type = setInputCols(values.toArray)

  /** @group setParam */
  @Since("2.3.0")
  def setInputCols(value: Array[String]): this.type = set(inputCols, value)

  /** @group setParam */
  @Since("2.3.0")
  def setOutputCol(value: String): this.type = set(outputCol, value)

  @Since("2.3.0")
  override def transform(dataset: Dataset[_]): DataFrame = {
    val hashFunc: Any => Int = OldHashingTF.murmur3Hash
    val n = $(numFeatures)
    val localInputCols = $(inputCols)

    val outputSchema = transformSchema(dataset.schema)
    val realFields = outputSchema.fields.filter { f =>
      f.dataType.isInstanceOf[NumericType]
    }.map(_.name).toSet

    def getDouble(x: Any): Double = {
      x match {
        case n: java.lang.Number =>
          n.doubleValue()
        case other =>
          // will throw ClassCastException if it cannot be cast, as would row.getDouble
          other.asInstanceOf[Double]
      }
    }

    val hashFeatures = udf { row: Row =>
      val map = new OpenHashMap[Int, Double]()
      localInputCols.foreach { colName =>
        val fieldIndex = row.fieldIndex(colName)
        if (!row.isNullAt(fieldIndex)) {
          val (rawIdx, value) = if (realFields(colName)) {
            // numeric values are kept as is, with vector index based on hash of "column_name"
            val value = getDouble(row.get(fieldIndex))
            val hash = hashFunc(colName)
            (hash, value)
          } else {
            // string and boolean values are treated as categorical, with an indicator value of
            // 1.0 and vector index based on hash of "column_name=value"
            val value = row.get(fieldIndex).toString
            val fieldName = s"$colName=$value"
            val hash = hashFunc(fieldName)
            (hash, 1.0)
          }
          val idx = Utils.nonNegativeMod(rawIdx, n)
          map.changeValue(idx, value, v => v + value)
        }
      }
      Vectors.sparse(n, map.toSeq)
    }

    val metadata = outputSchema($(outputCol)).metadata
    dataset.select(
      col("*"),
      hashFeatures(struct($(inputCols).map(col): _*)).as($(outputCol), metadata))
  }

  @Since("2.3.0")
  override def copy(extra: ParamMap): FeatureHasher = defaultCopy(extra)

  @Since("2.3.0")
  override def transformSchema(schema: StructType): StructType = {
    val fields = schema($(inputCols).toSet)
    fields.foreach { fieldSchema =>
      val dataType = fieldSchema.dataType
      val fieldName = fieldSchema.name
      require(dataType.isInstanceOf[NumericType] ||
        dataType.isInstanceOf[StringType] ||
        dataType.isInstanceOf[BooleanType],
        s"FeatureHasher requires columns to be of NumericType, BooleanType or StringType. " +
          s"Column $fieldName was $dataType")
    }
    val attrGroup = new AttributeGroup($(outputCol), $(numFeatures))
```
Inline review thread on the line above:

Contributor: It seems that we didn't store the attribute info.

Author (MLnick): Feature hashing doesn't keep the feature -> idx mapping for memory efficiency, so by extension it won't keep attribute info. This is by design, and the tradeoff is speed & efficiency vs. not being able to do the reverse mapping (or knowing the cardinality of each feature, for example). If users want to keep the mapping & attribute info, then of course they can just use one-hot encoding and vector assembler.

Contributor: @MLnick Thanks for clarifying.
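The alternative MLnick mentions might look roughly like this (a minimal sketch, not part of this PR; the DataFrame `df` and its columns `catIdx`/`real` are hypothetical, and `catIdx` is assumed to already be an indexed categorical column):

```scala
// Sketch: keeping the feature -> index mapping via OneHotEncoder + VectorAssembler,
// at the cost of the memory FeatureHasher is designed to save.
import org.apache.spark.ml.feature.{OneHotEncoder, VectorAssembler}

val encoder = new OneHotEncoder()
  .setInputCol("catIdx")
  .setOutputCol("catVec")
  .setDropLast(false)  // mirror FeatureHasher's one-hot style

val assembler = new VectorAssembler()
  .setInputCols(Array("catVec", "real"))
  .setOutputCol("features")

// Unlike FeatureHasher, the output column's ML attribute metadata now records
// which vector index corresponds to which original feature.
val assembled = assembler.transform(encoder.transform(df))
```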
```scala
    SchemaUtils.appendColumn(schema, attrGroup.toStructField())
  }
}

@Since("2.3.0")
object FeatureHasher extends DefaultParamsReadable[FeatureHasher] {

  @Since("2.3.0")
  override def load(path: String): FeatureHasher = super.load(path)
}
```
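The test suite below checks expected indices with `HashingTFSuite.murmur3FeatureIdx`. As a rough sketch of what that helper computes, mirroring the `transform()` logic above (note that `murmur3Hash` and `Utils` are package-private to Spark, so code like this only compiles inside the Spark source tree):

```scala
// Sketch: deriving a feature's vector index by hand, as the tests do.
// transform() takes the murmur3 hash of the term, then a non-negative modulo.
import org.apache.spark.mllib.feature.{HashingTF => OldHashingTF}
import org.apache.spark.util.Utils

val n = 1 << 8  // a power of two, as the scaladoc advises

// Categorical feature: column "string" with value "foo" hashes as "string=foo",
// contributing an indicator value of 1.0 at this index.
val catIdx = Utils.nonNegativeMod(OldHashingTF.murmur3Hash("string=foo"), n)

// Numeric feature: column "real" hashes on the column name alone,
// and the raw value is stored at this index.
val numIdx = Utils.nonNegativeMod(OldHashingTF.murmur3Hash("real"), n)
```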
New file: FeatureHasherSuite.scala (+193 lines)

```scala
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.ml.feature

import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.attribute.AttributeGroup
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.ml.util.TestingUtils._
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types._

class FeatureHasherSuite extends SparkFunSuite
  with MLlibTestSparkContext
  with DefaultReadWriteTest {

  import testImplicits._

  import HashingTFSuite.murmur3FeatureIdx

  implicit private val vectorEncoder = ExpressionEncoder[Vector]()

  test("params") {
    ParamsSuite.checkParams(new FeatureHasher)
  }

  test("specify input cols using varargs or array") {
    val featureHasher1 = new FeatureHasher()
      .setInputCols("int", "double", "float", "stringNum", "string")
    val featureHasher2 = new FeatureHasher()
      .setInputCols(Array("int", "double", "float", "stringNum", "string"))
    assert(featureHasher1.getInputCols === featureHasher2.getInputCols)
  }

  test("feature hashing") {
    val df = Seq(
      (2.0, true, "1", "foo"),
      (3.0, false, "2", "bar")
    ).toDF("real", "bool", "stringNum", "string")

    val n = 100
    val hasher = new FeatureHasher()
      .setInputCols("real", "bool", "stringNum", "string")
      .setOutputCol("features")
      .setNumFeatures(n)
    val output = hasher.transform(df)
    val attrGroup = AttributeGroup.fromStructField(output.schema("features"))
    assert(attrGroup.numAttributes === Some(n))

    val features = output.select("features").as[Vector].collect()
    // Assume perfect hash on field names
    def idx: Any => Int = murmur3FeatureIdx(n)
    // check expected indices
    val expected = Seq(
      Vectors.sparse(n, Seq((idx("real"), 2.0), (idx("bool=true"), 1.0),
        (idx("stringNum=1"), 1.0), (idx("string=foo"), 1.0))),
      Vectors.sparse(n, Seq((idx("real"), 3.0), (idx("bool=false"), 1.0),
        (idx("stringNum=2"), 1.0), (idx("string=bar"), 1.0)))
    )
    assert(features.zip(expected).forall { case (e, a) => e ~== a absTol 1e-14 })
  }

  test("hashing works for all numeric types") {
    val df = Seq(5.0, 10.0, 15.0).toDF("real")

    val hasher = new FeatureHasher()
      .setInputCols("real")
      .setOutputCol("features")

    val expectedResult = hasher.transform(df).select("features").as[Vector].collect()
    // check all numeric types work as expected. String & boolean types are tested in default case
    val types =
      Seq(ShortType, LongType, IntegerType, FloatType, ByteType, DoubleType, DecimalType(10, 0))
    types.foreach { t =>
      val castDF = df.select(col("real").cast(t))
      val castResult = hasher.transform(castDF).select("features").as[Vector].collect()
      withClue(s"FeatureHasher works for all numeric types (testing $t): ") {
        assert(castResult.zip(expectedResult).forall { case (actual, expected) =>
          actual ~== expected absTol 1e-14
        })
      }
    }
  }

  test("invalid input type should fail") {
    val df = Seq(
      Vectors.dense(1),
      Vectors.dense(2)
    ).toDF("vec")

    intercept[IllegalArgumentException] {
      new FeatureHasher().setInputCols("vec").transform(df)
    }
  }

  test("hash collisions sum feature values") {
    val df = Seq(
      (1.0, "foo", "foo"),
      (2.0, "bar", "baz")
    ).toDF("real", "string1", "string2")

    val n = 1
    val hasher = new FeatureHasher()
      .setInputCols("real", "string1", "string2")
      .setOutputCol("features")
      .setNumFeatures(n)

    val features = hasher.transform(df).select("features").as[Vector].collect()
    def idx: Any => Int = murmur3FeatureIdx(n)
    // everything should hash into one field
    assert(idx("real") === idx("string1=foo"))
    assert(idx("string1=foo") === idx("string2=foo"))
    assert(idx("string2=foo") === idx("string1=bar"))
    assert(idx("string1=bar") === idx("string2=baz"))
    val expected = Seq(
      Vectors.sparse(n, Seq((idx("string1=foo"), 3.0))),
      Vectors.sparse(n, Seq((idx("string2=bar"), 4.0)))
    )
    assert(features.zip(expected).forall { case (e, a) => e ~== a absTol 1e-14 })
  }

  test("ignores null values in feature hashing") {
    import org.apache.spark.sql.functions._

    val df = Seq(
      (2.0, "foo", null),
      (3.0, "bar", "baz")
    ).toDF("real", "string1", "string2").select(
      when(col("real") === 3.0, null).otherwise(col("real")).alias("real"),
      col("string1"),
      col("string2")
    )

    val n = 100
    val hasher = new FeatureHasher()
      .setInputCols("real", "string1", "string2")
      .setOutputCol("features")
      .setNumFeatures(n)

    val features = hasher.transform(df).select("features").as[Vector].collect()
    def idx: Any => Int = murmur3FeatureIdx(n)
    val expected = Seq(
      Vectors.sparse(n, Seq((idx("real"), 2.0), (idx("string1=foo"), 1.0))),
      Vectors.sparse(n, Seq((idx("string1=bar"), 1.0), (idx("string2=baz"), 1.0)))
    )
    assert(features.zip(expected).forall { case (e, a) => e ~== a absTol 1e-14 })
  }

  test("unicode column names and values") {
    // scalastyle:off nonascii
    val df = Seq((2.0, "中文")).toDF("中文", "unicode")

    val n = 100
    val hasher = new FeatureHasher()
      .setInputCols("中文", "unicode")
      .setOutputCol("features")
      .setNumFeatures(n)

    val features = hasher.transform(df).select("features").as[Vector].collect()
    def idx: Any => Int = murmur3FeatureIdx(n)
    val expected = Seq(
      Vectors.sparse(n, Seq((idx("中文"), 2.0), (idx("unicode=中文"), 1.0)))
    )
    assert(features.zip(expected).forall { case (e, a) => e ~== a absTol 1e-14 })
    // scalastyle:on nonascii
  }

  test("read/write") {
    val t = new FeatureHasher()
      .setInputCols(Array("myCol1", "myCol2", "myCol3"))
      .setOutputCol("myOutputCol")
      .setNumFeatures(10)
    testDefaultReadWrite(t)
  }
}
```
Inline review thread on the `getDouble` helper in transform():

Reviewer: maybe `val getDouble...`

Reply: why?

Reviewer: I read it from here, but never tested it:
https://stackoverflow.com/questions/18887264/what-is-the-difference-between-def-and-val-to-define-a-function

Reply: Hmm, this is a method not a function - so I don't think it will be faster to do `val` in this case?
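For reference, a standalone sketch (not part of the PR) of the distinction the thread settles on: `getDouble` is a local method (`def`), which call sites invoke directly, whereas a `val` of function type allocates a Function1 object and routes every call through its `apply()` method, so it would not be faster here:

```scala
// Standalone sketch of the def-vs-val distinction discussed above.
object DefVsVal {

  // A method: call sites invoke it directly; no function object is allocated.
  def getDoubleDef(x: Any): Double = x match {
    case n: java.lang.Number => n.doubleValue()
    case other => other.asInstanceOf[Double]
  }

  // A function value: one Function1 instance is allocated when the val is
  // initialized, and every call goes through its apply() method.
  val getDoubleVal: Any => Double = {
    case n: java.lang.Number => n.doubleValue()
    case other => other.asInstanceOf[Double]
  }

  def main(args: Array[String]): Unit = {
    // Both are invoked the same way at the call site.
    println(getDoubleDef(42))    // 42.0 (Int is boxed to java.lang.Integer)
    println(getDoubleVal(3.5f))  // 3.5  (Float widened via doubleValue)
  }
}
```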