diff --git a/mllib/src/main/scala/org/apache/spark/ml/linalg/MatrixUDT.scala b/mllib/src/main/scala/org/apache/spark/ml/linalg/MatrixUDT.scala
new file mode 100644
index 000000000000..53f4d5597146
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/linalg/MatrixUDT.scala
@@ -0,0 +1,111 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.linalg
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
+import org.apache.spark.sql.catalyst.util.GenericArrayData
+import org.apache.spark.sql.types._
+
+/**
+ * User-defined type for [[Matrix]] in `mllib-local` which allows easy interaction with SQL
+ * via [[org.apache.spark.sql.Dataset]].
+ */
+private[ml] class MatrixUDT extends UserDefinedType[Matrix] {
+
+  override def sqlType: StructType = {
+    // type: 0 = sparse, 1 = dense
+    // The dense matrix is built from numRows, numCols, values and isTransposed. All of these
+    // fields are non-nullable except values: support for binary matrices, which carry no
+    // values, might be added in the future.
+    // The sparse matrix additionally needs colPtrs and rowIndices; both are left null when
+    // building a dense matrix.
+    StructType(Seq(
+      StructField("type", ByteType, nullable = false),
+      StructField("numRows", IntegerType, nullable = false),
+      StructField("numCols", IntegerType, nullable = false),
+      StructField("colPtrs", ArrayType(IntegerType, containsNull = false), nullable = true),
+      StructField("rowIndices", ArrayType(IntegerType, containsNull = false), nullable = true),
+      StructField("values", ArrayType(DoubleType, containsNull = false), nullable = true),
+      StructField("isTransposed", BooleanType, nullable = false)
+    ))
+  }
+
+  override def serialize(obj: Matrix): InternalRow = {
+    val row = new GenericMutableRow(7)
+    obj match {
+      case sm: SparseMatrix =>
+        row.setByte(0, 0)
+        row.setInt(1, sm.numRows)
+        row.setInt(2, sm.numCols)
+        row.update(3, new GenericArrayData(sm.colPtrs.map(_.asInstanceOf[Any])))
+        row.update(4, new GenericArrayData(sm.rowIndices.map(_.asInstanceOf[Any])))
+        row.update(5, new GenericArrayData(sm.values.map(_.asInstanceOf[Any])))
+        row.setBoolean(6, sm.isTransposed)
+
+      case dm: DenseMatrix =>
+        row.setByte(0, 1)
+        row.setInt(1, dm.numRows)
+        row.setInt(2, dm.numCols)
+        row.setNullAt(3)
+        row.setNullAt(4)
+        row.update(5, new GenericArrayData(dm.values.map(_.asInstanceOf[Any])))
+        row.setBoolean(6, dm.isTransposed)
+    }
+    row
+  }
+
+  override def deserialize(datum: Any): Matrix = {
+    datum match {
+      case row: InternalRow =>
+        require(row.numFields == 7,
+          s"MatrixUDT.deserialize given row with length ${row.numFields} but requires length == 7")
+        val tpe = row.getByte(0)
+        val numRows = row.getInt(1)
+        val numCols = row.getInt(2)
+        val values = row.getArray(5).toDoubleArray()
+        val isTransposed = row.getBoolean(6)
+        tpe match {
+          case 0 =>
+            val colPtrs = row.getArray(3).toIntArray()
+            val rowIndices = row.getArray(4).toIntArray()
+            new SparseMatrix(numRows, numCols, colPtrs, rowIndices, values, isTransposed)
+          case 1 =>
+            new DenseMatrix(numRows, numCols, values, isTransposed)
+        }
+    }
+  }
+
+  override def userClass: Class[Matrix] = classOf[Matrix]
+
+  override def equals(o: Any): Boolean = {
+    o match {
+      case v: MatrixUDT => true
+      case _ => false
+    }
+  }
+
+  // see [SPARK-8647]: this gives a constant hash code without hard-coding a magic number.
+  override def hashCode(): Int = classOf[MatrixUDT].getName.hashCode()
+
+  override def typeName: String = "matrix"
+
+  override def pyUDT: String = "pyspark.ml.linalg.MatrixUDT"
+
+  private[spark] override def asNullable: MatrixUDT = this
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/linalg/VectorUDT.scala b/mllib/src/main/scala/org/apache/spark/ml/linalg/VectorUDT.scala
new file mode 100644
index 000000000000..fe93a12d065e
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/linalg/VectorUDT.scala
@@ -0,0 +1,98 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.linalg
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
+import org.apache.spark.sql.catalyst.util.GenericArrayData
+import org.apache.spark.sql.types._
+
+/**
+ * User-defined type for [[Vector]] in `mllib-local` which allows easy interaction with SQL
+ * via [[org.apache.spark.sql.Dataset]].
+ */
+private[ml] class VectorUDT extends UserDefinedType[Vector] {
+
+  override def sqlType: StructType = {
+    // type: 0 = sparse, 1 = dense
+    // We only use "values" for dense vectors, and "size", "indices", and "values" for sparse
+    // vectors. The "values" field is nullable because we might want to add binary vectors later,
+    // which use "size" and "indices", but not "values".
+    StructType(Seq(
+      StructField("type", ByteType, nullable = false),
+      StructField("size", IntegerType, nullable = true),
+      StructField("indices", ArrayType(IntegerType, containsNull = false), nullable = true),
+      StructField("values", ArrayType(DoubleType, containsNull = false), nullable = true)))
+  }
+
+  override def serialize(obj: Vector): InternalRow = {
+    obj match {
+      case SparseVector(size, indices, values) =>
+        val row = new GenericMutableRow(4)
+        row.setByte(0, 0)
+        row.setInt(1, size)
+        row.update(2, new GenericArrayData(indices.map(_.asInstanceOf[Any])))
+        row.update(3, new GenericArrayData(values.map(_.asInstanceOf[Any])))
+        row
+      case DenseVector(values) =>
+        val row = new GenericMutableRow(4)
+        row.setByte(0, 1)
+        row.setNullAt(1)
+        row.setNullAt(2)
+        row.update(3, new GenericArrayData(values.map(_.asInstanceOf[Any])))
+        row
+    }
+  }
+
+  override def deserialize(datum: Any): Vector = {
+    datum match {
+      case row: InternalRow =>
+        require(row.numFields == 4,
+          s"VectorUDT.deserialize given row with length ${row.numFields} but requires length == 4")
+        val tpe = row.getByte(0)
+        tpe match {
+          case 0 =>
+            val size = row.getInt(1)
+            val indices = row.getArray(2).toIntArray()
+            val values = row.getArray(3).toDoubleArray()
+            new SparseVector(size, indices, values)
+          case 1 =>
+            val values = row.getArray(3).toDoubleArray()
+            new DenseVector(values)
+        }
+    }
+  }
+
+  override def pyUDT: String = "pyspark.ml.linalg.VectorUDT"
+
+  override def userClass: Class[Vector] = classOf[Vector]
+
+  override def equals(o: Any): Boolean = {
+    o match {
+      case v: VectorUDT => true
+      case _ => false
+    }
+  }
+
+  // see [SPARK-8647]: this gives a constant hash code without hard-coding a magic number.
+  override def hashCode(): Int = classOf[VectorUDT].getName.hashCode()
+
+  override def typeName: String = "vector"
+
+  private[spark] override def asNullable: VectorUDT = this
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/linalg/MatrixUDTSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/linalg/MatrixUDTSuite.scala
new file mode 100644
index 000000000000..bdceba7887ca
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/linalg/MatrixUDTSuite.scala
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.linalg
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.types._
+
+class MatrixUDTSuite extends SparkFunSuite {
+
+  test("preloaded MatrixUDT") {
+    val dm1 = new DenseMatrix(2, 2, Array(0.9, 1.2, 2.3, 9.8))
+    val dm2 = new DenseMatrix(3, 2, Array(0.0, 1.21, 2.3, 9.8, 9.0, 0.0))
+    val dm3 = new DenseMatrix(0, 0, Array())
+    val sm1 = dm1.toSparse
+    val sm2 = dm2.toSparse
+    val sm3 = dm3.toSparse
+
+    for (m <- Seq(dm1, dm2, dm3, sm1, sm2, sm3)) {
+      val udt = UDTRegistration.getUDTFor(m.getClass.getName).get.newInstance()
+        .asInstanceOf[MatrixUDT]
+      assert(m === udt.deserialize(udt.serialize(m)))
+      assert(udt.typeName == "matrix")
+      assert(udt.simpleString == "matrix")
+    }
+  }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/linalg/VectorUDTSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/linalg/VectorUDTSuite.scala
new file mode 100644
index 000000000000..6d01d8f2828e
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/linalg/VectorUDTSuite.scala
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.linalg
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.types._
+
+class VectorUDTSuite extends SparkFunSuite {
+
+  test("preloaded VectorUDT") {
+    val dv1 = Vectors.dense(Array.empty[Double])
+    val dv2 = Vectors.dense(1.0, 2.0)
+    val sv1 = Vectors.sparse(2, Array.empty, Array.empty)
+    val sv2 = Vectors.sparse(2, Array(1), Array(2.0))
+
+    for (v <- Seq(dv1, dv2, sv1, sv2)) {
+      val udt = UDTRegistration.getUDTFor(v.getClass.getName).get.newInstance()
+        .asInstanceOf[VectorUDT]
+      assert(v === udt.deserialize(udt.serialize(v)))
+      assert(udt.typeName == "vector")
+      assert(udt.simpleString == "vector")
+    }
+  }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index bd723135b510..5e400bfc8db0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.catalyst
 
+import org.apache.spark.SparkException
 import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedExtractValue}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData}
@@ -381,6 +382,15 @@
         Nil,
         dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt()))
       Invoke(obj, "deserialize", ObjectType(udt.userClass), getPath :: Nil)
+
+      case t if UDTRegistration.exists(getClassNameFromType(t)) =>
+        val udt = UDTRegistration.getUDTFor(getClassNameFromType(t)).get.newInstance()
+          .asInstanceOf[UserDefinedType[_]]
+        val obj = NewInstance(
+          udt.getClass,
+          Nil,
+          dataType = ObjectType(udt.getClass))
+        Invoke(obj, "deserialize", ObjectType(udt.userClass), getPath :: Nil)
     }
   }
 
@@ -595,6 +605,15 @@
           dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt()))
         Invoke(obj, "serialize", udt.sqlType, inputObject :: Nil)
 
+      case t if UDTRegistration.exists(getClassNameFromType(t)) =>
+        val udt = UDTRegistration.getUDTFor(getClassNameFromType(t)).get.newInstance()
+          .asInstanceOf[UserDefinedType[_]]
+        val obj = NewInstance(
+          udt.getClass,
+          Nil,
+          dataType = ObjectType(udt.getClass))
+        Invoke(obj, "serialize", udt.sqlType, inputObject :: Nil)
+
       case other =>
         throw new UnsupportedOperationException(
           s"No Encoder found for $tpe\n" + walkedTypePath.mkString("\n"))
@@ -663,6 +682,10 @@
     case t if t.typeSymbol.annotations.exists(_.tpe =:= typeOf[SQLUserDefinedType]) =>
       val udt = getClassFromType(t).getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance()
       Schema(udt, nullable = true)
+    case t if UDTRegistration.exists(getClassNameFromType(t)) =>
+      val udt = UDTRegistration.getUDTFor(getClassNameFromType(t)).get.newInstance()
+        .asInstanceOf[UserDefinedType[_]]
+      Schema(udt, nullable = true)
     case t if t <:< localTypeOf[Option[_]] =>
       val TypeRef(_, _, Seq(optType)) = t
       Schema(schemaFor(optType).dataType, nullable = true)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
index a8397aa5e5c2..44e135cbf835
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.encoders
 import scala.collection.Map
 import scala.reflect.ClassTag
 
+import org.apache.spark.SparkException
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData}
@@ -55,10 +56,19 @@
     case p: PythonUserDefinedType => serializerFor(inputObject, p.sqlType)
 
     case udt: UserDefinedType[_] =>
+      val annotation = udt.userClass.getAnnotation(classOf[SQLUserDefinedType])
+      val udtClass: Class[_] = if (annotation != null) {
+        annotation.udt()
+      } else {
+        UDTRegistration.getUDTFor(udt.userClass.getName).getOrElse {
+          throw new SparkException(s"${udt.userClass.getName} is not annotated with " +
+            "SQLUserDefinedType nor registered with UDTRegistration.")
+        }
+      }
       val obj = NewInstance(
-        udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(),
+        udtClass,
         Nil,
-        dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt()))
+        dataType = ObjectType(udtClass), false)
       Invoke(obj, "serialize", udt.sqlType, inputObject :: Nil)
 
     case TimestampType =>
@@ -187,10 +197,19 @@
          FloatType | DoubleType | BinaryType | CalendarIntervalType => input
 
     case udt: UserDefinedType[_] =>
+      val annotation = udt.userClass.getAnnotation(classOf[SQLUserDefinedType])
+      val udtClass: Class[_] = if (annotation != null) {
+        annotation.udt()
+      } else {
+        UDTRegistration.getUDTFor(udt.userClass.getName).getOrElse {
+          throw new SparkException(s"${udt.userClass.getName} is not annotated with " +
+            "SQLUserDefinedType nor registered with UDTRegistration.")
+        }
+      }
       val obj = NewInstance(
-        udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(),
+        udtClass,
         Nil,
-        dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt()))
+        dataType = ObjectType(udtClass))
       Invoke(obj, "deserialize", ObjectType(udt.userClass), input :: Nil)
 
     case TimestampType =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UDTRegistration.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UDTRegistration.scala
new file mode 100644
index 000000000000..0f24e51ed2b7
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UDTRegistration.scala
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.types
+
+import scala.collection.mutable
+
+import org.apache.spark.SparkException
+import org.apache.spark.internal.Logging
+import org.apache.spark.util.Utils
+
+/**
+ * This object keeps the mappings between user classes and their User Defined Types (UDTs).
+ * Previously, we used the annotation `SQLUserDefinedType` to register UDTs for user classes.
+ * However, that adds a Spark SQL dependency on user classes. This object provides an
+ * alternative approach to register UDTs for user classes.
+ */
+private[spark]
+object UDTRegistration extends Serializable with Logging {
+
+  /** The mapping from user class names to the class names of their UserDefinedTypes. */
+  private lazy val udtMap: mutable.Map[String, String] = mutable.Map(
+    ("org.apache.spark.ml.linalg.Vector", "org.apache.spark.ml.linalg.VectorUDT"),
+    ("org.apache.spark.ml.linalg.DenseVector", "org.apache.spark.ml.linalg.VectorUDT"),
+    ("org.apache.spark.ml.linalg.SparseVector", "org.apache.spark.ml.linalg.VectorUDT"),
+    ("org.apache.spark.ml.linalg.Matrix", "org.apache.spark.ml.linalg.MatrixUDT"),
+    ("org.apache.spark.ml.linalg.DenseMatrix", "org.apache.spark.ml.linalg.MatrixUDT"),
+    ("org.apache.spark.ml.linalg.SparseMatrix", "org.apache.spark.ml.linalg.MatrixUDT"))
+
+  /**
+   * Queries whether a given user class is already registered.
+   * @param userClassName the name of the user class
+   * @return true if the given user class is registered
+   */
+  def exists(userClassName: String): Boolean = udtMap.contains(userClassName)
+
+  /**
+   * Registers a UserDefinedType for a user class. If the user class is already registered
+   * with another UserDefinedType, a warning is logged and the existing mapping is kept.
+   * @param userClass the name of the user class
+   * @param udtClass the name of the UserDefinedType class for the given userClass
+   */
+  def register(userClass: String, udtClass: String): Unit = {
+    if (udtMap.contains(userClass)) {
+      logWarning(s"Cannot register UDT for ${userClass}, which is already registered.")
+    } else {
+      // Since the UDT is registered by class name only, we can't check here whether the class
+      // is actually a UserDefinedType; that check is deferred to getUDTFor.
+      udtMap += ((userClass, udtClass))
+    }
+  }
+
+  /**
+   * Returns the Class of the UserDefinedType registered for a given user class name.
+   * @param userClass the class name of the user class
+   * @return an Option holding the Class object of the UserDefinedType, if registered
+   */
+  def getUDTFor(userClass: String): Option[Class[_]] = {
+    udtMap.get(userClass).map { udtClassName =>
+      if (Utils.classIsLoadable(udtClassName)) {
+        val udtClass = Utils.classForName(udtClassName)
+        if (classOf[UserDefinedType[_]].isAssignableFrom(udtClass)) {
+          udtClass
+        } else {
+          throw new SparkException(
            s"${udtClass.getName} is not a UserDefinedType. Please make sure you register " +
+              s"a UserDefinedType for ${userClass}.")
+        }
+      } else {
+        throw new SparkException(
+          s"Cannot load UserDefinedType ${udtClassName} for user class ${userClass}.")
+      }
+    }
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDTRegistrationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDTRegistrationSuite.scala
new file mode 100644
index 000000000000..d61ede780a74
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UDTRegistrationSuite.scala
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.{SparkException, SparkFunSuite}
+import org.apache.spark.sql.types._
+
+private[sql] class TestUserClass {
+}
+
+private[sql] class TestUserClass2 {
+}
+
+private[sql] class TestUserClass3 {
+}
+
+private[sql] class NonUserDefinedType {
+}
+
+private[sql] class TestUserClassUDT extends UserDefinedType[TestUserClass] {
+
+  override def sqlType: DataType = IntegerType
+  override def serialize(input: TestUserClass): Int = 1
+
+  override def deserialize(datum: Any): TestUserClass = new TestUserClass
+
+  override def userClass: Class[TestUserClass] = classOf[TestUserClass]
+
+  private[spark] override def asNullable: TestUserClassUDT = this
+
+  override def hashCode(): Int = classOf[TestUserClassUDT].getName.hashCode()
+
+  override def equals(other: Any): Boolean = other match {
+    case _: TestUserClassUDT => true
+    case _ => false
+  }
+}
+
+class UDTRegistrationSuite extends SparkFunSuite {
+
+  test("register non-UserDefinedType") {
+    UDTRegistration.register(classOf[TestUserClass].getName,
+      "org.apache.spark.sql.NonUserDefinedType")
+    intercept[SparkException] {
+      UDTRegistration.getUDTFor(classOf[TestUserClass].getName)
+    }
+  }
+
+  test("default UDTs") {
+    val userClasses = Seq(
+      "org.apache.spark.ml.linalg.Vector",
+      "org.apache.spark.ml.linalg.DenseVector",
+      "org.apache.spark.ml.linalg.SparseVector",
+      "org.apache.spark.ml.linalg.Matrix",
+      "org.apache.spark.ml.linalg.DenseMatrix",
+      "org.apache.spark.ml.linalg.SparseMatrix")
+    userClasses.foreach { c =>
+      assert(UDTRegistration.exists(c))
+    }
+  }
+
+  test("query registered user class") {
+    UDTRegistration.register(classOf[TestUserClass2].getName, classOf[TestUserClassUDT].getName)
+    assert(UDTRegistration.exists(classOf[TestUserClass2].getName))
+    assert(
+      classOf[UserDefinedType[_]].isAssignableFrom(
+        UDTRegistration.getUDTFor(classOf[TestUserClass2].getName).get))
+  }
+
+  test("query unregistered user class") {
+    assert(!UDTRegistration.exists(classOf[TestUserClass3].getName))
+    assert(UDTRegistration.getUDTFor(classOf[TestUserClass3].getName).isEmpty)
+  }
+}
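
For reviewers, a minimal end-to-end sketch of how the new registration path is meant to be consumed, mirroring the UDTs above. The `Point2D`, `Point2DUDT`, and `Point2DExample` names are hypothetical and not part of this patch; and since `UDTRegistration` and `asNullable` are `private[spark]` here, the sketch assumes the code lives under an `org.apache.spark` package:

```scala
package org.apache.spark.example

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
import org.apache.spark.sql.types._

// A user class that stays free of Spark SQL annotations.
class Point2D(val x: Double, val y: Double)

// A UDT for Point2D, mirroring the structure of VectorUDT above.
class Point2DUDT extends UserDefinedType[Point2D] {

  override def sqlType: StructType = StructType(Seq(
    StructField("x", DoubleType, nullable = false),
    StructField("y", DoubleType, nullable = false)))

  override def serialize(obj: Point2D): InternalRow = {
    val row = new GenericMutableRow(2)
    row.setDouble(0, obj.x)
    row.setDouble(1, obj.y)
    row
  }

  override def deserialize(datum: Any): Point2D = datum match {
    case row: InternalRow => new Point2D(row.getDouble(0), row.getDouble(1))
  }

  override def userClass: Class[Point2D] = classOf[Point2D]

  private[spark] override def asNullable: Point2DUDT = this
}

object Point2DExample {
  def main(args: Array[String]): Unit = {
    // Register by class name; UDTRegistration only validates that Point2DUDT
    // really is a UserDefinedType lazily, inside getUDTFor.
    UDTRegistration.register(classOf[Point2D].getName, classOf[Point2DUDT].getName)

    // This is the lookup that ScalaReflection and RowEncoder now perform when a
    // class carries no @SQLUserDefinedType annotation.
    val udt = UDTRegistration.getUDTFor(classOf[Point2D].getName)
      .get.newInstance().asInstanceOf[Point2DUDT]

    val p = new Point2D(1.0, 2.0)
    val restored = udt.deserialize(udt.serialize(p))
    assert(restored.x == p.x && restored.y == p.y)
  }
}
```

The registry stores class names on both sides, so neither the user class nor the UDT has to be loadable at registration time; the trade-off is that a bad mapping only fails later, in `getUDTFor`, which is why the "register non-UserDefinedType" test above intercepts the `SparkException` at lookup rather than at `register`.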