@@ -18,16 +18,14 @@
 package org.apache.spark.ml
 
 import scala.annotation.varargs
-import scala.reflect.runtime.universe.TypeTag
 
 import org.apache.spark.Logging
 import org.apache.spark.annotation.AlphaComponent
 import org.apache.spark.ml.param._
 import org.apache.spark.sql.SchemaRDD
 import org.apache.spark.sql.api.java.JavaSchemaRDD
-import org.apache.spark.sql.catalyst.ScalaReflection
 import org.apache.spark.sql.catalyst.analysis.Star
-import org.apache.spark.sql.catalyst.dsl._
+import org.apache.spark.sql.catalyst.expressions.ScalaUdf
 import org.apache.spark.sql.catalyst.types._
 
 /**
@@ -86,7 +84,7 @@ abstract class Transformer extends PipelineStage with Params {
  * Abstract class for transformers that take one input column, apply transformation, and output the
  * result as a new column.
  */
-private[ml] abstract class UnaryTransformer[IN, OUT: TypeTag, T <: UnaryTransformer[IN, OUT, T]]
+private[ml] abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, OUT, T]]
   extends Transformer with HasInputCol with HasOutputCol with Logging {
 
   def setInputCol(value: String): T = set(inputCol, value).asInstanceOf[T]
@@ -99,6 +97,11 @@ private[ml] abstract class UnaryTransformer[IN, OUT: TypeTag, T <: UnaryTransfor
    */
   protected def createTransformFunc(paramMap: ParamMap): IN => OUT
 
+  /**
+   * Returns the data type of the output column.
+   */
+  protected def outputDataType: DataType
+
   /**
    * Validates the input type. Throw an exception if it is invalid.
    */
@@ -111,17 +114,16 @@ private[ml] abstract class UnaryTransformer[IN, OUT: TypeTag, T <: UnaryTransfor
     if (schema.fieldNames.contains(map(outputCol))) {
       throw new IllegalArgumentException(s"Output column ${map(outputCol)} already exists.")
     }
-    val output = ScalaReflection.schemaFor[OUT]
     val outputFields = schema.fields :+
-      StructField(map(outputCol), output.dataType, output.nullable)
+      StructField(map(outputCol), outputDataType, !outputDataType.isPrimitive)
     StructType(outputFields)
   }
 
   override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = {
     transformSchema(dataset.schema, paramMap, logging = true)
     import dataset.sqlContext._
     val map = this.paramMap ++ paramMap
-    val udf = this.createTransformFunc(map)
-    dataset.select(Star(None), udf.call(map(inputCol).attr) as map(outputCol))
+    val udf = ScalaUdf(this.createTransformFunc(map), outputDataType, Seq(map(inputCol).attr))
+    dataset.select(Star(None), udf as map(outputCol))
   }
 }
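
For context, here is a minimal sketch of what a concrete transformer looks like against the revised API. The class name WhitespaceTokenizer is hypothetical (spark.ml's real Tokenizer follows the same shape); it assumes the 1.2-era catalyst types and places the class under org.apache.spark.ml because UnaryTransformer is private[ml]:

package org.apache.spark.ml.feature

import org.apache.spark.ml.UnaryTransformer
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.sql.catalyst.types.{ArrayType, DataType, StringType}

// Hypothetical example: lowercase a string column and split it on whitespace.
class WhitespaceTokenizer
  extends UnaryTransformer[String, Seq[String], WhitespaceTokenizer] {

  // The per-row function that transform() wraps in a ScalaUdf.
  protected override def createTransformFunc(paramMap: ParamMap): String => Seq[String] =
    _.toLowerCase.split("\\s+").toSeq

  // With the TypeTag bound gone, the output schema can no longer be derived
  // via ScalaReflection.schemaFor[OUT]; each subclass declares it explicitly.
  protected override def outputDataType: DataType = ArrayType(StringType, containsNull = false)
}

The trade appears to be a one-line override per subclass in exchange for dropping runtime reflection over OUT, which keeps UnaryTransformer usable for output types that ScalaReflection cannot map to a Catalyst type.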