Skip to content

Commit ea34dc6

Browse files
committed
Merge pull request #4 from mengxr/ml-package-docs
replace TypeTag with explicit datatype
2 parents 41ad9b1 + 3b83ec0 commit ea34dc6

File tree

3 files changed

+17
-10
lines changed

3 files changed

+17
-10
lines changed

mllib/src/main/scala/org/apache/spark/ml/Transformer.scala

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,16 +18,14 @@
1818
package org.apache.spark.ml
1919

2020
import scala.annotation.varargs
21-
import scala.reflect.runtime.universe.TypeTag
2221

2322
import org.apache.spark.Logging
2423
import org.apache.spark.annotation.AlphaComponent
2524
import org.apache.spark.ml.param._
2625
import org.apache.spark.sql.SchemaRDD
2726
import org.apache.spark.sql.api.java.JavaSchemaRDD
28-
import org.apache.spark.sql.catalyst.ScalaReflection
2927
import org.apache.spark.sql.catalyst.analysis.Star
30-
import org.apache.spark.sql.catalyst.dsl._
28+
import org.apache.spark.sql.catalyst.expressions.ScalaUdf
3129
import org.apache.spark.sql.catalyst.types._
3230

3331
/**
@@ -86,7 +84,7 @@ abstract class Transformer extends PipelineStage with Params {
8684
* Abstract class for transformers that take one input column, apply transformation, and output the
8785
* result as a new column.
8886
*/
89-
private[ml] abstract class UnaryTransformer[IN, OUT: TypeTag, T <: UnaryTransformer[IN, OUT, T]]
87+
private[ml] abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, OUT, T]]
9088
extends Transformer with HasInputCol with HasOutputCol with Logging {
9189

9290
def setInputCol(value: String): T = set(inputCol, value).asInstanceOf[T]
@@ -99,6 +97,11 @@ private[ml] abstract class UnaryTransformer[IN, OUT: TypeTag, T <: UnaryTransfor
9997
*/
10098
protected def createTransformFunc(paramMap: ParamMap): IN => OUT
10199

100+
/**
101+
* Returns the data type of the output column.
102+
*/
103+
protected def outputDataType: DataType
104+
102105
/**
103106
* Validates the input type. Throw an exception if it is invalid.
104107
*/
@@ -111,17 +114,16 @@ private[ml] abstract class UnaryTransformer[IN, OUT: TypeTag, T <: UnaryTransfor
111114
if (schema.fieldNames.contains(map(outputCol))) {
112115
throw new IllegalArgumentException(s"Output column ${map(outputCol)} already exists.")
113116
}
114-
val output = ScalaReflection.schemaFor[OUT]
115117
val outputFields = schema.fields :+
116-
StructField(map(outputCol), output.dataType, output.nullable)
118+
StructField(map(outputCol), outputDataType, !outputDataType.isPrimitive)
117119
StructType(outputFields)
118120
}
119121

120122
override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = {
121123
transformSchema(dataset.schema, paramMap, logging = true)
122124
import dataset.sqlContext._
123125
val map = this.paramMap ++ paramMap
124-
val udf = this.createTransformFunc(map)
125-
dataset.select(Star(None), udf.call(map(inputCol).attr) as map(outputCol))
126+
val udf = ScalaUdf(this.createTransformFunc(map), outputDataType, Seq(map(inputCol).attr))
127+
dataset.select(Star(None), udf as map(outputCol))
126128
}
127129
}

mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@ import org.apache.spark.annotation.AlphaComponent
2121
import org.apache.spark.ml.UnaryTransformer
2222
import org.apache.spark.ml.param.{IntParam, ParamMap}
2323
import org.apache.spark.mllib.feature
24-
import org.apache.spark.mllib.linalg.Vector
24+
import org.apache.spark.mllib.linalg.{VectorUDT, Vector}
25+
import org.apache.spark.sql.catalyst.types.DataType
2526

2627
/**
2728
* :: AlphaComponent ::
@@ -39,4 +40,6 @@ class HashingTF extends UnaryTransformer[Iterable[_], Vector, HashingTF] {
3940
val hashingTF = new feature.HashingTF(paramMap(numFeatures))
4041
hashingTF.transform
4142
}
43+
44+
override protected def outputDataType: DataType = new VectorUDT()
4245
}

mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ package org.apache.spark.ml.feature
2020
import org.apache.spark.annotation.AlphaComponent
2121
import org.apache.spark.ml.UnaryTransformer
2222
import org.apache.spark.ml.param.ParamMap
23-
import org.apache.spark.sql.{DataType, StringType}
23+
import org.apache.spark.sql.{DataType, StringType, ArrayType}
2424

2525
/**
2626
* :: AlphaComponent ::
@@ -36,4 +36,6 @@ class Tokenizer extends UnaryTransformer[String, Seq[String], Tokenizer] {
3636
protected override def validateInputType(inputType: DataType): Unit = {
3737
require(inputType == StringType, s"Input type must be string type but got $inputType.")
3838
}
39+
40+
override protected def outputDataType: DataType = new ArrayType(StringType, false)
3941
}

0 commit comments

Comments (0)