Commit 930b90a

[SPARK-13030][ML] Follow-up cleanups for OneHotEncoderEstimator

## What changes were proposed in this pull request?

Follow-up cleanups for the OneHotEncoderEstimator PR. See some discussion in the original PR: #19527 or read below for what this PR includes:

* configedCategorySize: I reverted this to return an Array. I realized the original setup (which I had recommended in the original PR) caused the whole model to be serialized in the UDF.
* encoder: I reorganized the logic to show what I meant in the comment in the previous PR. I think it's simpler but am open to suggestions.

I also made some small style cleanups based on IntelliJ warnings.

## How was this patch tested?

Existing unit tests

Author: Joseph K. Bradley <[email protected]>

Closes #20132 from jkbradley/viirya-SPARK-13030.

1 parent c0b7424 commit 930b90a
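
The serialization concern in the first bullet above comes from closure capture: a UDF whose body calls an instance method captures `this`, so the whole enclosing object is serialized with every task. A minimal sketch of the pattern and the fix, using a made-up `SomeModel` class that is not part of the patch:

```scala
import org.apache.spark.sql.functions.udf

// Hypothetical model class, for illustration only.
class SomeModel(val sizes: Array[Int]) extends Serializable {

  private def adjustedSize(idx: Int): Int = sizes(idx) + 1

  // Anti-pattern: the lambda calls an instance method, so it captures `this`
  // and the entire model is serialized into every task.
  def badEncoder = udf { (idx: Int) => adjustedSize(idx) }

  // Pattern adopted by this PR: copy what the UDF needs into local vals first,
  // so only the small Array is captured and shipped to executors.
  def goodEncoder = {
    val localSizes = sizes.map(_ + 1)
    udf { (idx: Int) => localSizes(idx) }
  }
}
```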

mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala

Lines changed: 49 additions & 43 deletions

@@ -30,24 +30,27 @@ import org.apache.spark.ml.util._
 import org.apache.spark.sql.{DataFrame, Dataset}
 import org.apache.spark.sql.expressions.UserDefinedFunction
 import org.apache.spark.sql.functions.{col, lit, udf}
-import org.apache.spark.sql.types.{DoubleType, NumericType, StructField, StructType}
+import org.apache.spark.sql.types.{DoubleType, StructField, StructType}
 
 /** Private trait for params and common methods for OneHotEncoderEstimator and OneHotEncoderModel */
 private[ml] trait OneHotEncoderBase extends Params with HasHandleInvalid
     with HasInputCols with HasOutputCols {
 
   /**
-   * Param for how to handle invalid data.
+   * Param for how to handle invalid data during transform().
    * Options are 'keep' (invalid data presented as an extra categorical feature) or
    * 'error' (throw an error).
+   * Note that this Param is only used during transform; during fitting, invalid data
+   * will result in an error.
    * Default: "error"
   * @group param
   */
   @Since("2.3.0")
   override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid",
-    "How to handle invalid data " +
+    "How to handle invalid data during transform(). " +
     "Options are 'keep' (invalid data presented as an extra categorical feature) " +
-    "or error (throw an error).",
+    "or error (throw an error). Note that this Param is only used during transform; " +
+    "during fitting, invalid data will result in an error.",
     ParamValidators.inArray(OneHotEncoderEstimator.supportedHandleInvalids))
 
   setDefault(handleInvalid, OneHotEncoderEstimator.ERROR_INVALID)

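A minimal sketch of setting the reworded Param from user code (column names are illustrative):

```scala
import org.apache.spark.ml.feature.OneHotEncoderEstimator

// "keep" reserves an extra output category for values unseen during fitting;
// the default "error" makes transform() fail on them. Either way, fit() still
// rejects malformed input such as negative indices.
val encoder = new OneHotEncoderEstimator()
  .setInputCols(Array("categoryIndex"))
  .setOutputCols(Array("categoryVec"))
  .setHandleInvalid("keep")
```
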
@@ -66,10 +69,11 @@ private[ml] trait OneHotEncoderBase extends Params with HasHandleInvalid
   def getDropLast: Boolean = $(dropLast)
 
   protected def validateAndTransformSchema(
-      schema: StructType, dropLast: Boolean, keepInvalid: Boolean): StructType = {
+      schema: StructType,
+      dropLast: Boolean,
+      keepInvalid: Boolean): StructType = {
     val inputColNames = $(inputCols)
     val outputColNames = $(outputCols)
-    val existingFields = schema.fields
 
     require(inputColNames.length == outputColNames.length,
       s"The number of input columns ${inputColNames.length} must be the same as the number of " +

@@ -197,6 +201,10 @@ object OneHotEncoderEstimator extends DefaultParamsReadable[OneHotEncoderEstimat
   override def load(path: String): OneHotEncoderEstimator = super.load(path)
 }
 
+/**
+ * @param categorySizes Original number of categories for each feature being encoded.
+ *                      The array contains one value for each input column, in order.
+ */
 @Since("2.3.0")
 class OneHotEncoderModel private[ml] (
     @Since("2.3.0") override val uid: String,

@@ -205,60 +213,58 @@ class OneHotEncoderModel private[ml] (
 
   import OneHotEncoderModel._
 
-  // Returns the category size for a given index with `dropLast` and `handleInvalid`
+  // Returns the category size for each index with `dropLast` and `handleInvalid`
   // taken into account.
-  private def configedCategorySize(orgCategorySize: Int, idx: Int): Int = {
+  private def getConfigedCategorySizes: Array[Int] = {
     val dropLast = getDropLast
     val keepInvalid = getHandleInvalid == OneHotEncoderEstimator.KEEP_INVALID
 
     if (!dropLast && keepInvalid) {
       // When `handleInvalid` is "keep", an extra category is added as last category
       // for invalid data.
-      orgCategorySize + 1
+      categorySizes.map(_ + 1)
     } else if (dropLast && !keepInvalid) {
       // When `dropLast` is true, the last category is removed.
-      orgCategorySize - 1
+      categorySizes.map(_ - 1)
     } else {
       // When `dropLast` is true and `handleInvalid` is "keep", the extra category for invalid
       // data is removed. Thus, it is the same as the plain number of categories.
-      orgCategorySize
+      categorySizes
     }
   }
 
   private def encoder: UserDefinedFunction = {
-    val oneValue = Array(1.0)
-    val emptyValues = Array.empty[Double]
-    val emptyIndices = Array.empty[Int]
-    val dropLast = getDropLast
-    val handleInvalid = getHandleInvalid
-    val keepInvalid = handleInvalid == OneHotEncoderEstimator.KEEP_INVALID
+    val keepInvalid = getHandleInvalid == OneHotEncoderEstimator.KEEP_INVALID
+    val configedSizes = getConfigedCategorySizes
+    val localCategorySizes = categorySizes
 
     // The udf performed on input data. The first parameter is the input value. The second
-    // parameter is the index of input.
-    udf { (label: Double, idx: Int) =>
-      val plainNumCategories = categorySizes(idx)
-      val size = configedCategorySize(plainNumCategories, idx)
-
-      if (label < 0) {
-        throw new SparkException(s"Negative value: $label. Input can't be negative.")
-      } else if (label == size && dropLast && !keepInvalid) {
-        // When `dropLast` is true and `handleInvalid` is not "keep",
-        // the last category is removed.
-        Vectors.sparse(size, emptyIndices, emptyValues)
-      } else if (label >= plainNumCategories && keepInvalid) {
-        // When `handleInvalid` is "keep", encodes invalid data to last category (and removed
-        // if `dropLast` is true)
-        if (dropLast) {
-          Vectors.sparse(size, emptyIndices, emptyValues)
+    // parameter is the index in inputCols of the column being encoded.
+    udf { (label: Double, colIdx: Int) =>
+      val origCategorySize = localCategorySizes(colIdx)
+      // idx: index in vector of the single 1-valued element
+      val idx = if (label >= 0 && label < origCategorySize) {
+        label
+      } else {
+        if (keepInvalid) {
+          origCategorySize
         } else {
-          Vectors.sparse(size, Array(size - 1), oneValue)
+          if (label < 0) {
+            throw new SparkException(s"Negative value: $label. Input can't be negative. " +
+              s"To handle invalid values, set Param handleInvalid to " +
+              s"${OneHotEncoderEstimator.KEEP_INVALID}")
+          } else {
+            throw new SparkException(s"Unseen value: $label. To handle unseen values, " +
+              s"set Param handleInvalid to ${OneHotEncoderEstimator.KEEP_INVALID}.")
+          }
         }
-      } else if (label < plainNumCategories) {
-        Vectors.sparse(size, Array(label.toInt), oneValue)
+      }
+
+      val size = configedSizes(colIdx)
+      if (idx < size) {
+        Vectors.sparse(size, Array(idx.toInt), Array(1.0))
       } else {
-        assert(handleInvalid == OneHotEncoderEstimator.ERROR_INVALID)
-        throw new SparkException(s"Unseen value: $label. To handle unseen values, " +
-          s"set Param handleInvalid to ${OneHotEncoderEstimator.KEEP_INVALID}.")
+        Vectors.sparse(size, Array.empty[Int], Array.empty[Double])
       }
     }
   }

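To make the reorganized control flow concrete, here is a small standalone reproduction of the same decisions for a single column; `sizeFor` and `encodeOne` are illustrative helpers, not part of the patch:

```scala
import org.apache.spark.ml.linalg.{Vector, Vectors}

// Mirrors getConfigedCategorySizes for one column with `orig` categories.
def sizeFor(orig: Int, dropLast: Boolean, keepInvalid: Boolean): Int =
  if (!dropLast && keepInvalid) orig + 1       // extra slot for invalid values
  else if (dropLast && !keepInvalid) orig - 1  // last category dropped
  else orig                                    // no net change (neither set, or both cancel out)

// Mirrors the rewritten udf body: first pick the index of the 1-valued
// element, then decide whether it survives in the output vector.
def encodeOne(label: Double, orig: Int, dropLast: Boolean, keepInvalid: Boolean): Vector = {
  val idx =
    if (label >= 0 && label < orig) label
    else if (keepInvalid) orig.toDouble        // invalid values map to the extra last slot
    else sys.error(s"invalid label $label")    // stands in for the SparkException paths
  val size = sizeFor(orig, dropLast, keepInvalid)
  if (idx < size) Vectors.sparse(size, Array(idx.toInt), Array(1.0))
  else Vectors.sparse(size, Array.empty[Int], Array.empty[Double]) // dropped slot -> all zeros
}

// 3 original categories, dropLast = true, handleInvalid = "keep": vectors stay 3-dimensional.
assert(encodeOne(1.0, 3, dropLast = true, keepInvalid = true) ==
  Vectors.sparse(3, Array(1), Array(1.0)))
// An unseen value lands in the (dropped) invalid slot and becomes the all-zeros vector.
assert(encodeOne(5.0, 3, dropLast = true, keepInvalid = true) ==
  Vectors.sparse(3, Array.empty[Int], Array.empty[Double]))
```
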
@@ -282,7 +288,6 @@ class OneHotEncoderModel private[ml] (
   @Since("2.3.0")
   override def transformSchema(schema: StructType): StructType = {
     val inputColNames = $(inputCols)
-    val outputColNames = $(outputCols)
 
     require(inputColNames.length == categorySizes.length,
       s"The number of input columns ${inputColNames.length} must be the same as the number of " +

@@ -300,6 +305,7 @@ class OneHotEncoderModel private[ml] (
    * account. Mismatched numbers will cause exception.
    */
   private def verifyNumOfValues(schema: StructType): StructType = {
+    val configedSizes = getConfigedCategorySizes
     $(outputCols).zipWithIndex.foreach { case (outputColName, idx) =>
       val inputColName = $(inputCols)(idx)
       val attrGroup = AttributeGroup.fromStructField(schema(outputColName))

@@ -308,9 +314,9 @@ class OneHotEncoderModel private[ml] (
       // comparing with expected category number with `handleInvalid` and
       // `dropLast` taken into account.
       if (attrGroup.attributes.nonEmpty) {
-        val numCategories = configedCategorySize(categorySizes(idx), idx)
+        val numCategories = configedSizes(idx)
         require(attrGroup.size == numCategories, "OneHotEncoderModel expected " +
-          s"$numCategories categorical values for input column ${inputColName}, " +
+          s"$numCategories categorical values for input column $inputColName, " +
           s"but the input column had metadata specifying ${attrGroup.size} values.")
       }
     }

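For context on what verifyNumOfValues reads, output-column metadata is carried as an ML AttributeGroup on the StructField. A hedged sketch of constructing and inspecting one (column name and attribute names are illustrative):

```scala
import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, BinaryAttribute}

// Metadata declaring two one-hot slots for an output column named "categoryVec",
// roughly what an upstream stage might have attached to the schema.
val attrs: Array[Attribute] = Array(
  BinaryAttribute.defaultAttr.withName("a"),
  BinaryAttribute.defaultAttr.withName("b"))
val group = new AttributeGroup("categoryVec", attrs)

// verifyNumOfValues compares this declared size against the configured
// category count (configedSizes(idx)) and fails on a mismatch.
assert(group.attributes.nonEmpty && group.size == 2)

// Round-trips through a StructField the same way the model reads it back.
val restored = AttributeGroup.fromStructField(group.toStructField())
assert(restored.size == 2)
```
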
@@ -322,7 +328,7 @@ class OneHotEncoderModel private[ml] (
     val transformedSchema = transformSchema(dataset.schema, logging = true)
     val keepInvalid = $(handleInvalid) == OneHotEncoderEstimator.KEEP_INVALID
 
-    val encodedColumns = (0 until $(inputCols).length).map { idx =>
+    val encodedColumns = $(inputCols).indices.map { idx =>
       val inputColName = $(inputCols)(idx)
       val outputColName = $(outputCols)(idx)
 

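Finally, a hedged end-to-end sketch of the behavior the patch preserves (local-mode session, app and column names are made up):

```scala
import org.apache.spark.ml.feature.OneHotEncoderEstimator
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[1]").appName("ohe-transform").getOrCreate()
import spark.implicits._

// One already-indexed column with 3 categories seen during fitting.
val train = Seq(0.0, 1.0, 2.0).toDF("categoryIndex")

val model = new OneHotEncoderEstimator()
  .setInputCols(Array("categoryIndex"))
  .setOutputCols(Array("categoryVec"))
  .setHandleInvalid("keep")   // unseen values at transform time map to an extra slot...
  .setDropLast(true)          // ...which is then dropped, so vectors stay 3-dimensional
  .fit(train)

// 4.0 was never seen during fitting; with "keep" it becomes the all-zeros vector
// instead of the "Unseen value" error raised under the default "error" setting.
model.transform(Seq(0.0, 4.0).toDF("categoryIndex")).show(truncate = false)

spark.stop()
```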