Skip to content

Commit 47c1d56

Browse files
dusenberrymwjkbradley
authored andcommitted
[SPARK-7426] [MLLIB] [ML] Updated Attribute.fromStructField to allow any NumericType.
Updated `Attribute.fromStructField` to allow any `NumericType`, rather than just `DoubleType`, and added unit tests for a few of the other NumericTypes. Author: Mike Dusenberry <[email protected]> Closes apache#6540 from dusenberrymw/SPARK-7426_AttributeFactory.fromStructField_Should_Allow_NumericTypes and squashes the following commits: 87fecb3 [Mike Dusenberry] Updated Attribute.fromStructField to allow any NumericType, rather than just DoubleType, and added unit tests for a few of the other NumericTypes.
1 parent a189442 commit 47c1d56

File tree

2 files changed

+7
-2
lines changed

2 files changed

+7
-2
lines changed

mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ package org.apache.spark.ml.attribute
2020
import scala.annotation.varargs
2121

2222
import org.apache.spark.annotation.DeveloperApi
23-
import org.apache.spark.sql.types.{DoubleType, Metadata, MetadataBuilder, StructField}
23+
import org.apache.spark.sql.types.{DoubleType, NumericType, Metadata, MetadataBuilder, StructField}
2424

2525
/**
2626
* :: DeveloperApi ::
@@ -127,7 +127,7 @@ private[attribute] trait AttributeFactory {
127127
* Creates an [[Attribute]] from a [[StructField]] instance.
128128
*/
129129
def fromStructField(field: StructField): Attribute = {
130-
require(field.dataType == DoubleType)
130+
require(field.dataType.isInstanceOf[NumericType])
131131
val metadata = field.metadata
132132
val mlAttr = AttributeKeys.ML_ATTR
133133
if (metadata.contains(mlAttr)) {

mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,5 +215,10 @@ class AttributeSuite extends SparkFunSuite {
215215
assert(Attribute.fromStructField(fldWithoutMeta) == UnresolvedAttribute)
216216
val fldWithMeta = new StructField("x", DoubleType, false, metadata)
217217
assert(Attribute.fromStructField(fldWithMeta).isNumeric)
218+
// Attribute.fromStructField should accept any NumericType, not just DoubleType
219+
val longFldWithMeta = new StructField("x", LongType, false, metadata)
220+
assert(Attribute.fromStructField(longFldWithMeta).isNumeric)
221+
val decimalFldWithMeta = new StructField("x", DecimalType(None), false, metadata)
222+
assert(Attribute.fromStructField(decimalFldWithMeta).isNumeric)
218223
}
219224
}

0 commit comments

Comments
 (0)