@@ -30,24 +30,27 @@ import org.apache.spark.ml.util._
3030import org .apache .spark .sql .{DataFrame , Dataset }
3131import org .apache .spark .sql .expressions .UserDefinedFunction
3232import org .apache .spark .sql .functions .{col , lit , udf }
33- import org .apache .spark .sql .types .{DoubleType , NumericType , StructField , StructType }
33+ import org .apache .spark .sql .types .{DoubleType , StructField , StructType }
3434
3535/** Private trait for params and common methods for OneHotEncoderEstimator and OneHotEncoderModel */
3636private [ml] trait OneHotEncoderBase extends Params with HasHandleInvalid
3737 with HasInputCols with HasOutputCols {
3838
/**
 * Param for how to handle invalid data during transform().
 * Options are 'keep' (invalid data presented as an extra categorical feature) or
 * 'error' (throw an error).
 * Note that this Param is only used during transform; during fitting, invalid data
 * will result in an error.
 * Default: "error"
 * @group param
 */
@Since("2.3.0")
override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid",
  // Both option names are quoted so the Param doc string matches the Scaladoc above.
  "How to handle invalid data during transform(). " +
  "Options are 'keep' (invalid data presented as an extra categorical feature) " +
  "or 'error' (throw an error). Note that this Param is only used during transform; " +
  "during fitting, invalid data will result in an error.",
  ParamValidators.inArray(OneHotEncoderEstimator.supportedHandleInvalids))

setDefault(handleInvalid, OneHotEncoderEstimator.ERROR_INVALID)
@@ -66,10 +69,11 @@ private[ml] trait OneHotEncoderBase extends Params with HasHandleInvalid
6669 def getDropLast : Boolean = $(dropLast)
6770
6871 protected def validateAndTransformSchema (
69- schema : StructType , dropLast : Boolean , keepInvalid : Boolean ): StructType = {
72+ schema : StructType ,
73+ dropLast : Boolean ,
74+ keepInvalid : Boolean ): StructType = {
7075 val inputColNames = $(inputCols)
7176 val outputColNames = $(outputCols)
72- val existingFields = schema.fields
7377
7478 require(inputColNames.length == outputColNames.length,
7579 s " The number of input columns ${inputColNames.length} must be the same as the number of " +
@@ -197,6 +201,10 @@ object OneHotEncoderEstimator extends DefaultParamsReadable[OneHotEncoderEstimat
197201 override def load (path : String ): OneHotEncoderEstimator = super .load(path)
198202}
199203
204+ /**
205+ * @param categorySizes Original number of categories for each feature being encoded.
206+ * The array contains one value for each input column, in order.
207+ */
200208@ Since (" 2.3.0" )
201209class OneHotEncoderModel private [ml] (
202210 @ Since (" 2.3.0" ) override val uid : String ,
@@ -205,60 +213,58 @@ class OneHotEncoderModel private[ml] (
205213
206214 import OneHotEncoderModel ._
207215
// Computes, for every input column, the length of the encoded output vector once the
// `dropLast` and `handleInvalid` settings have been applied to the original category counts.
private def getConfigedCategorySizes: Array[Int] = {
  val dropLast = getDropLast
  val keepInvalid = getHandleInvalid == OneHotEncoderEstimator.KEEP_INVALID

  (dropLast, keepInvalid) match {
    case (false, true) =>
      // "keep" appends one extra category at the end to hold invalid data.
      categorySizes.map(_ + 1)
    case (true, false) =>
      // `dropLast` removes the last category.
      categorySizes.map(_ - 1)
    case _ =>
      // Either neither option applies, or both do: with `dropLast` and "keep" together the
      // extra invalid-data category is the one that gets dropped. In both situations the
      // sizes equal the plain number of categories.
      categorySizes
  }
}
227235
// Builds the UDF that one-hot encodes a single value. The UDF takes the input value and the
// index in inputCols of the column being encoded, and produces a sparse vector.
private def encoder: UserDefinedFunction = {
  // Read everything the UDF needs up front, so its body only works with local values.
  val invalidKept = getHandleInvalid == OneHotEncoderEstimator.KEEP_INVALID
  val outputSizes = getConfigedCategorySizes
  val originalSizes = categorySizes

  udf { (label: Double, colIdx: Int) =>
    val numCategories = originalSizes(colIdx)

    // Position of the single 1-valued element. A negative or unseen value is mapped to the
    // slot right after the known categories when invalid data is kept; otherwise it is
    // rejected with an exception.
    val hotIndex: Int =
      if (label >= 0 && label < numCategories) {
        label.toInt
      } else if (invalidKept) {
        numCategories
      } else if (label < 0) {
        throw new SparkException(s"Negative value: $label. Input can't be negative. " +
          s"To handle invalid values, set Param handleInvalid to " +
          s"${OneHotEncoderEstimator.KEEP_INVALID}")
      } else {
        throw new SparkException(s"Unseen value: $label. To handle unseen values, " +
          s"set Param handleInvalid to ${OneHotEncoderEstimator.KEEP_INVALID}.")
      }

    // A position that falls outside the configured output size yields an all-zeros vector.
    val vectorSize = outputSizes(colIdx)
    if (hotIndex < vectorSize) {
      Vectors.sparse(vectorSize, Array(hotIndex), Array(1.0))
    } else {
      Vectors.sparse(vectorSize, Array.empty[Int], Array.empty[Double])
    }
  }
}
@@ -282,7 +288,6 @@ class OneHotEncoderModel private[ml] (
282288 @ Since (" 2.3.0" )
283289 override def transformSchema (schema : StructType ): StructType = {
284290 val inputColNames = $(inputCols)
285- val outputColNames = $(outputCols)
286291
287292 require(inputColNames.length == categorySizes.length,
288293 s " The number of input columns ${inputColNames.length} must be the same as the number of " +
@@ -300,6 +305,7 @@ class OneHotEncoderModel private[ml] (
300305 * account. Mismatched numbers will cause exception.
301306 */
302307 private def verifyNumOfValues (schema : StructType ): StructType = {
308+ val configedSizes = getConfigedCategorySizes
303309 $(outputCols).zipWithIndex.foreach { case (outputColName, idx) =>
304310 val inputColName = $(inputCols)(idx)
305311 val attrGroup = AttributeGroup .fromStructField(schema(outputColName))
@@ -308,9 +314,9 @@ class OneHotEncoderModel private[ml] (
308314 // comparing with expected category number with `handleInvalid` and
309315 // `dropLast` taken into account.
310316 if (attrGroup.attributes.nonEmpty) {
311- val numCategories = configedCategorySize(categorySizes(idx), idx)
317+ val numCategories = configedSizes( idx)
312318 require(attrGroup.size == numCategories, " OneHotEncoderModel expected " +
313- s " $numCategories categorical values for input column ${ inputColName} , " +
319+ s " $numCategories categorical values for input column $inputColName, " +
314320 s " but the input column had metadata specifying ${attrGroup.size} values. " )
315321 }
316322 }
@@ -322,7 +328,7 @@ class OneHotEncoderModel private[ml] (
322328 val transformedSchema = transformSchema(dataset.schema, logging = true )
323329 val keepInvalid = $(handleInvalid) == OneHotEncoderEstimator .KEEP_INVALID
324330
325- val encodedColumns = ( 0 until $(inputCols).length) .map { idx =>
331+ val encodedColumns = $(inputCols).indices .map { idx =>
326332 val inputColName = $(inputCols)(idx)
327333 val outputColName = $(outputCols)(idx)
328334
0 commit comments