Commits
b270bb9
init import.
viirya Apr 8, 2016
b3acdb9
Revert change to Vectors.
viirya Apr 8, 2016
1f58662
Both test UDT through UDTRegistration and SQLUserDefinedType.
viirya Apr 8, 2016
a84d9dd
Check Utils.classIsLoadable.
viirya Apr 8, 2016
e59f1d5
Separate two test suites to prevent interference with other tests.
viirya Apr 9, 2016
eafdd58
Remove SQLUserDefinedType for MyDVector.
viirya Apr 11, 2016
9fa56d1
Merge remote-tracking branch 'upstream/master' into improve-sql-usertype
viirya Apr 15, 2016
8000a4a
Copy VectorUDT and MatrixUDT into sql.
viirya Apr 15, 2016
1d9b87b
Preload built-in UDTs.
viirya Apr 15, 2016
ba9e910
Add test for pre-loaded VectorUDT.
viirya Apr 15, 2016
2f9472c
Fix scala style.
viirya Apr 15, 2016
5ecfd63
Show log info when trying to register an already registered user class.
viirya Apr 16, 2016
909c01f
Add test for MatrixUDT.
viirya Apr 17, 2016
6cdc446
Only use mllib-local as test dependency of sql-core. Use string-based…
viirya Apr 19, 2016
85d1df0
Move VectorUDT and MatrixUDT tests to mllib.
viirya Apr 19, 2016
744b659
Address comments.
viirya Apr 19, 2016
ace8723
Merge remote-tracking branch 'upstream/master' into improve-sql-usertype
viirya Apr 20, 2016
5aaf9db
Fix test.
viirya Apr 20, 2016
42ff175
Merge remote-tracking branch 'upstream/master' into improve-sql-usertype
viirya Apr 20, 2016
5caa0aa
Add condition check back.
viirya Apr 21, 2016
434b1b4
Merge remote-tracking branch 'upstream/master' into improve-sql-usertype
viirya Apr 22, 2016
023d281
Address comments.
viirya Apr 24, 2016
3917f6b
Fix scala style.
viirya Apr 24, 2016
06bdbc5
Fix scala style.
viirya Apr 24, 2016
de7ea5d
Update pyUDT.
viirya Apr 26, 2016
5d4ba3d
Simplify UDTRegistration.
viirya Apr 28, 2016
45e87d9
Simplify two UDT test suites.
viirya Apr 28, 2016
9ed0f30
Remove unnecessary change.
viirya Apr 28, 2016
1c230ae
Remove empty line.
viirya Apr 28, 2016
111 changes: 111 additions & 0 deletions mllib/src/main/scala/org/apache/spark/ml/linalg/MatrixUDT.scala
@@ -0,0 +1,111 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.ml.linalg

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
import org.apache.spark.sql.catalyst.util.GenericArrayData
import org.apache.spark.sql.types._

/**
* User-defined type for [[Matrix]] in `mllib-local` which allows easy interaction with SQL
* via [[org.apache.spark.sql.Dataset]].
*/
private[ml] class MatrixUDT extends UserDefinedType[Matrix] {

override def sqlType: StructType = {
// type: 0 = sparse, 1 = dense
// A dense matrix is built from numRows, numCols, values and isTransposed. All of these
// fields are non-nullable except values: support for binary matrices, which would not
// need values, might be added in the future.
// A sparse matrix additionally needs colPtrs and rowIndices; both fields are nullable
// and are set to null when building a dense matrix.
StructType(Seq(
StructField("type", ByteType, nullable = false),
StructField("numRows", IntegerType, nullable = false),
StructField("numCols", IntegerType, nullable = false),
StructField("colPtrs", ArrayType(IntegerType, containsNull = false), nullable = true),
StructField("rowIndices", ArrayType(IntegerType, containsNull = false), nullable = true),
StructField("values", ArrayType(DoubleType, containsNull = false), nullable = true),
StructField("isTransposed", BooleanType, nullable = false)
))
}
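  // Hedged worked example (not in the original diff): the 2x2 dense matrix
  //   [1.0 2.0]
  //   [3.0 4.0]
  // stored column-major serializes to the row
  //   (1, 2, 2, null, null, [1.0, 3.0, 2.0, 4.0], false)
  // while its CSC sparse counterpart serializes to
  //   (0, 2, 2, [0, 2, 4], [0, 1, 0, 1], [1.0, 3.0, 2.0, 4.0], false)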

override def serialize(obj: Matrix): InternalRow = {
val row = new GenericMutableRow(7)
obj match {
case sm: SparseMatrix =>
row.setByte(0, 0)
row.setInt(1, sm.numRows)
row.setInt(2, sm.numCols)
row.update(3, new GenericArrayData(sm.colPtrs.map(_.asInstanceOf[Any])))
Review comment (Contributor):
Not part of this PR, but we might need to fix this. It seems that it boxes primitive arrays. I created https://issues.apache.org/jira/browse/SPARK-14850 to track the issue. cc: @cloud-fan
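A minimal sketch of the boxing in question, as a hedged illustration (variable names are made up; not part of the diff):

  // Mapping a primitive Array[Int] to Array[Any] allocates one java.lang.Integer
  // per element before GenericArrayData ever sees the data.
  val primitive: Array[Int] = Array(0, 2, 4)
  val boxed: Array[Any] = primitive.map(_.asInstanceOf[Any]) // boxes every element
  val arrayData = new GenericArrayData(boxed)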

row.update(4, new GenericArrayData(sm.rowIndices.map(_.asInstanceOf[Any])))
row.update(5, new GenericArrayData(sm.values.map(_.asInstanceOf[Any])))
row.setBoolean(6, sm.isTransposed)

case dm: DenseMatrix =>
row.setByte(0, 1)
row.setInt(1, dm.numRows)
row.setInt(2, dm.numCols)
row.setNullAt(3)
row.setNullAt(4)
row.update(5, new GenericArrayData(dm.values.map(_.asInstanceOf[Any])))
row.setBoolean(6, dm.isTransposed)
}
row
}

override def deserialize(datum: Any): Matrix = {
datum match {
case row: InternalRow =>
require(row.numFields == 7,
s"MatrixUDT.deserialize given row with length ${row.numFields} but requires length == 7")
val tpe = row.getByte(0)
val numRows = row.getInt(1)
val numCols = row.getInt(2)
val values = row.getArray(5).toDoubleArray()
val isTransposed = row.getBoolean(6)
tpe match {
case 0 =>
val colPtrs = row.getArray(3).toIntArray()
val rowIndices = row.getArray(4).toIntArray()
new SparseMatrix(numRows, numCols, colPtrs, rowIndices, values, isTransposed)
case 1 =>
new DenseMatrix(numRows, numCols, values, isTransposed)
}
}
}

override def userClass: Class[Matrix] = classOf[Matrix]

override def equals(o: Any): Boolean = {
o match {
case _: MatrixUDT => true
case _ => false
}
}

// see [SPARK-8647]: this achieves the needed constant hash code without hardcoding a constant.
override def hashCode(): Int = classOf[MatrixUDT].getName.hashCode()

override def typeName: String = "matrix"

override def pyUDT: String = "pyspark.ml.linalg.MatrixUDT"

private[spark] override def asNullable: MatrixUDT = this
}
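UDTRegistration itself is not part of this diff; as a hedged sketch, the "Preload built-in UDTs" commit could seed a string-based registry along these lines (the field name and exact contents are assumptions):

  import scala.collection.mutable

  // Assumed shape of the preloaded user-class -> UDT-class registry inside
  // UDTRegistration; exists and getUDTFor would consult this map.
  private lazy val udtMap: mutable.Map[String, String] = mutable.Map(
    ("org.apache.spark.ml.linalg.Vector", "org.apache.spark.ml.linalg.VectorUDT"),
    ("org.apache.spark.ml.linalg.Matrix", "org.apache.spark.ml.linalg.MatrixUDT"))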
98 changes: 98 additions & 0 deletions mllib/src/main/scala/org/apache/spark/ml/linalg/VectorUDT.scala
@@ -0,0 +1,98 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.ml.linalg

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
import org.apache.spark.sql.catalyst.util.GenericArrayData
import org.apache.spark.sql.types._

/**
* User-defined type for [[Vector]] in `mllib-local` which allows easy interaction with SQL
* via [[org.apache.spark.sql.Dataset]].
*/
private[ml] class VectorUDT extends UserDefinedType[Vector] {

override def sqlType: StructType = {
// type: 0 = sparse, 1 = dense
// We only use "values" for dense vectors, and "size", "indices", and "values" for sparse
// vectors. The "values" field is nullable because we might want to add binary vectors later,
// which uses "size" and "indices", but not "values".
StructType(Seq(
StructField("type", ByteType, nullable = false),
StructField("size", IntegerType, nullable = true),
StructField("indices", ArrayType(IntegerType, containsNull = false), nullable = true),
StructField("values", ArrayType(DoubleType, containsNull = false), nullable = true)))
}
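  // Hedged worked example (not in the original diff): the dense vector [1.0, 2.0]
  // serializes to (1, null, null, [1.0, 2.0]); the sparse vector of size 2 with a
  // single entry 2.0 at index 1 serializes to (0, 2, [1], [2.0]).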

override def serialize(obj: Vector): InternalRow = {
obj match {
case SparseVector(size, indices, values) =>
val row = new GenericMutableRow(4)
row.setByte(0, 0)
row.setInt(1, size)
row.update(2, new GenericArrayData(indices.map(_.asInstanceOf[Any])))
row.update(3, new GenericArrayData(values.map(_.asInstanceOf[Any])))
row
case DenseVector(values) =>
val row = new GenericMutableRow(4)
row.setByte(0, 1)
row.setNullAt(1)
row.setNullAt(2)
row.update(3, new GenericArrayData(values.map(_.asInstanceOf[Any])))
row
}
}

override def deserialize(datum: Any): Vector = {
datum match {
case row: InternalRow =>
require(row.numFields == 4,
s"VectorUDT.deserialize given row with length ${row.numFields} but requires length == 4")
val tpe = row.getByte(0)
tpe match {
case 0 =>
val size = row.getInt(1)
val indices = row.getArray(2).toIntArray()
val values = row.getArray(3).toDoubleArray()
new SparseVector(size, indices, values)
case 1 =>
val values = row.getArray(3).toDoubleArray()
new DenseVector(values)
}
}
}

override def pyUDT: String = "pyspark.ml.linalg.VectorUDT"

override def userClass: Class[Vector] = classOf[Vector]

override def equals(o: Any): Boolean = {
o match {
case _: VectorUDT => true
case _ => false
}
}

// see [SPARK-8647]: this achieves the needed constant hash code without hardcoding a constant.
override def hashCode(): Int = classOf[VectorUDT].getName.hashCode()

override def typeName: String = "vector"

private[spark] override def asNullable: VectorUDT = this
}
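For classes that cannot carry the @SQLUserDefinedType annotation, the registration path this PR adds could be used roughly as follows; MyPoint and MyPointUDT are hypothetical names, so this is a sketch rather than the PR's own example:

  // Register the UDT for a third-party class by class-name strings, then use it
  // in a DataFrame like any supported type (spark is an active SparkSession).
  UDTRegistration.register(classOf[MyPoint].getName, classOf[MyPointUDT].getName)
  import spark.implicits._
  val df = Seq((1, new MyPoint(1.0, 2.0))).toDF("id", "point")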
41 changes: 41 additions & 0 deletions mllib/src/test/scala/org/apache/spark/ml/linalg/MatrixUDTSuite.scala
@@ -0,0 +1,41 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.ml.linalg

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.types._

class MatrixUDTSuite extends SparkFunSuite {

test("preloaded MatrixUDT") {
val dm1 = new DenseMatrix(2, 2, Array(0.9, 1.2, 2.3, 9.8))
val dm2 = new DenseMatrix(3, 2, Array(0.0, 1.21, 2.3, 9.8, 9.0, 0.0))
val dm3 = new DenseMatrix(0, 0, Array())
val sm1 = dm1.toSparse
val sm2 = dm2.toSparse
val sm3 = dm3.toSparse

for (m <- Seq(dm1, dm2, dm3, sm1, sm2, sm3)) {
val udt = UDTRegistration.getUDTFor(m.getClass.getName).get.newInstance()
.asInstanceOf[MatrixUDT]
assert(m === udt.deserialize(udt.serialize(m)))
assert(udt.typeName == "matrix")
assert(udt.simpleString == "matrix")
}
}
}
39 changes: 39 additions & 0 deletions mllib/src/test/scala/org/apache/spark/ml/linalg/VectorUDTSuite.scala
@@ -0,0 +1,39 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.ml.linalg

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.types._

class VectorUDTSuite extends SparkFunSuite {

test("preloaded VectorUDT") {
val dv1 = Vectors.dense(Array.empty[Double])
val dv2 = Vectors.dense(1.0, 2.0)
val sv1 = Vectors.sparse(2, Array.empty, Array.empty)
val sv2 = Vectors.sparse(2, Array(1), Array(2.0))

for (v <- Seq(dv1, dv2, sv1, sv2)) {
val udt = UDTRegistration.getUDTFor(v.getClass.getName).get.newInstance()
.asInstanceOf[VectorUDT]
assert(v === udt.deserialize(udt.serialize(v)))
assert(udt.typeName == "vector")
assert(udt.simpleString == "vector")
}
}
}
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -17,6 +17,7 @@

package org.apache.spark.sql.catalyst

import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedExtractValue}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData}
@@ -381,6 +382,15 @@ object ScalaReflection extends ScalaReflection {
Nil,
dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt()))
Invoke(obj, "deserialize", ObjectType(udt.userClass), getPath :: Nil)

case t if UDTRegistration.exists(getClassNameFromType(t)) =>
val udt = UDTRegistration.getUDTFor(getClassNameFromType(t)).get.newInstance()
.asInstanceOf[UserDefinedType[_]]
val obj = NewInstance(
udt.getClass,
Nil,
dataType = ObjectType(udt.getClass))
Invoke(obj, "deserialize", ObjectType(udt.userClass), getPath :: Nil)
}
}

@@ -595,6 +605,15 @@ object ScalaReflection extends ScalaReflection {
dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt()))
Invoke(obj, "serialize", udt.sqlType, inputObject :: Nil)

case t if UDTRegistration.exists(getClassNameFromType(t)) =>
val udt = UDTRegistration.getUDTFor(getClassNameFromType(t)).get.newInstance()
.asInstanceOf[UserDefinedType[_]]
val obj = NewInstance(
udt.getClass,
Nil,
dataType = ObjectType(udt.getClass))
Invoke(obj, "serialize", udt.sqlType, inputObject :: Nil)

case other =>
throw new UnsupportedOperationException(
s"No Encoder found for $tpe\n" + walkedTypePath.mkString("\n"))
@@ -663,6 +682,10 @@ object ScalaReflection extends ScalaReflection {
case t if t.typeSymbol.annotations.exists(_.tpe =:= typeOf[SQLUserDefinedType]) =>
val udt = getClassFromType(t).getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance()
Schema(udt, nullable = true)
case t if UDTRegistration.exists(getClassNameFromType(t)) =>
val udt = UDTRegistration.getUDTFor(getClassNameFromType(t)).get.newInstance()
.asInstanceOf[UserDefinedType[_]]
Schema(udt, nullable = true)
case t if t <:< localTypeOf[Option[_]] =>
val TypeRef(_, _, Seq(optType)) = t
Schema(schemaFor(optType).dataType, nullable = true)
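A hedged sketch of what the new schemaFor branch yields for a registered class, assuming the built-in ml.linalg registrations sketched earlier:

  import org.apache.spark.sql.catalyst.ScalaReflection

  // For a class registered via UDTRegistration, schemaFor now returns the UDT
  // instance itself as the Catalyst data type, with nullable = true.
  val schema = ScalaReflection.schemaFor[org.apache.spark.ml.linalg.Vector]
  // schema.dataType is a VectorUDT; schema.nullable is true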
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.encoders
import scala.collection.Map
import scala.reflect.ClassTag

import org.apache.spark.SparkException
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData}
@@ -55,10 +56,19 @@ object RowEncoder {
case p: PythonUserDefinedType => serializerFor(inputObject, p.sqlType)

      case udt: UserDefinedType[_] =>
+       val annotation = udt.userClass.getAnnotation(classOf[SQLUserDefinedType])
+       val udtClass: Class[_] = if (annotation != null) {
+         annotation.udt()
+       } else {
+         UDTRegistration.getUDTFor(udt.userClass.getName).getOrElse {
+           throw new SparkException(s"${udt.userClass.getName} is not annotated with " +
+             "SQLUserDefinedType nor registered with UDTRegistration.")
+         }
+       }
        val obj = NewInstance(
-         udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(),
+         udtClass,
          Nil,
-         dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt()))
+         dataType = ObjectType(udtClass), false)
        Invoke(obj, "serialize", udt.sqlType, inputObject :: Nil)

case TimestampType =>
@@ -187,10 +197,19 @@ object RowEncoder {
FloatType | DoubleType | BinaryType | CalendarIntervalType => input

      case udt: UserDefinedType[_] =>
+       val annotation = udt.userClass.getAnnotation(classOf[SQLUserDefinedType])
+       val udtClass: Class[_] = if (annotation != null) {
+         annotation.udt()
+       } else {
+         UDTRegistration.getUDTFor(udt.userClass.getName).getOrElse {
+           throw new SparkException(s"${udt.userClass.getName} is not annotated with " +
+             "SQLUserDefinedType nor registered with UDTRegistration.")
+         }
+       }
        val obj = NewInstance(
-         udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(),
+         udtClass,
          Nil,
-         dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt()))
+         dataType = ObjectType(udtClass))
        Invoke(obj, "deserialize", ObjectType(udt.userClass), input :: Nil)

case TimestampType =>
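Both RowEncoder hunks repeat the same annotation-first, registration-second lookup; as a design note, it could be factored into a helper along these lines (a sketch; the helper name is made up):

  // Resolve the concrete UDT class for a user class: prefer the
  // @SQLUserDefinedType annotation, then fall back to UDTRegistration.
  private def udtClassFor(udt: UserDefinedType[_]): Class[_] = {
    val annotation = udt.userClass.getAnnotation(classOf[SQLUserDefinedType])
    if (annotation != null) {
      annotation.udt()
    } else {
      UDTRegistration.getUDTFor(udt.userClass.getName).getOrElse {
        throw new SparkException(s"${udt.userClass.getName} is not annotated with " +
          "SQLUserDefinedType nor registered with UDTRegistration.")
      }
    }
  }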