[SPARK-14487][SQL] User Defined Type registration without SQLUserDefinedType annotation #12259
Status: Closed
Changes from 28 of 29 commits:
All commits authored by viirya:

- b270bb9: init import.
- b3acdb9: Revert change to Vectors.
- 1f58662: Both test UDT through UDTRegistration and SQLUserDefinedType.
- a84d9dd: Check Utils.classIsLoadable.
- e59f1d5: Separate two test suites to prevent interference with other tests.
- eafdd58: Remove SQLUserDefinedType for MyDVector.
- 9fa56d1: Merge remote-tracking branch 'upstream/master' into improve-sql-usertype
- 8000a4a: Copy VectorUDT and MatrixUDT into sql.
- 1d9b87b: Preload built-in UDTs.
- ba9e910: Add test for pre-loaded VectorUDT.
- 2f9472c: Fix scala style.
- 5ecfd63: Show log info when trying to register an already registered user class.
- 909c01f: Add test for MatrixUDT.
- 6cdc446: Only use mllib-local as test dependency of sql-core. Use string-based…
- 85d1df0: Move VectorUDT and MatrixUDT tests to mllib.
- 744b659: Address comments.
- ace8723: Merge remote-tracking branch 'upstream/master' into improve-sql-usertype
- 5aaf9db: Fix test.
- 42ff175: Merge remote-tracking branch 'upstream/master' into improve-sql-usertype
- 5caa0aa: Add condition check back.
- 434b1b4: Merge remote-tracking branch 'upstream/master' into improve-sql-usertype
- 023d281: Address comments.
- 3917f6b: Fix scala style.
- 06bdbc5: Fix scala style.
- de7ea5d: Update pyUDT.
- 5d4ba3d: Simplify UDTRegistration.
- 45e87d9: Simplify two UDT test suites.
- 9ed0f30: Remove unnecessary change.
- 1c230ae: Remove empty line.
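The core of this change is a new string-based `UDTRegistration` (in `org.apache.spark.sql.types`) that maps a user class name to its UDT class name, so a user-defined type can be attached to a class without the `SQLUserDefinedType` annotation. A rough sketch of that registry pattern, in Python for illustration only (the real implementation is Scala; `MiniUDTRegistration` and its method names are hypothetical):

```python
# Minimal model of the string-based UDT registry this PR introduces.
# Names and behavior are illustrative, not the actual Spark implementation.
import logging

logger = logging.getLogger(__name__)


class MiniUDTRegistration:
    """Maps a user class name to the name of its UserDefinedType class."""

    def __init__(self):
        self._udt_map = {}

    def register(self, user_class: str, udt_class: str) -> None:
        if user_class in self._udt_map:
            # Mirrors commit 5ecfd63: log, rather than fail, on re-registration.
            logger.info("Cannot register UDT for %s, which is already registered.",
                        user_class)
        else:
            self._udt_map[user_class] = udt_class

    def exists(self, user_class: str) -> bool:
        return user_class in self._udt_map

    def get_udt_for(self, user_class: str):
        """Return the UDT class name for user_class, or None if unregistered."""
        return self._udt_map.get(user_class)


registry = MiniUDTRegistration()
registry.register("org.example.MyPoint", "org.example.MyPointUDT")
print(registry.get_udt_for("org.example.MyPoint"))  # org.example.MyPointUDT
```

The string-based lookup (see commit 6cdc446) avoids a hard compile-time dependency between sql-core and mllib-local: classes are resolved by name only when actually needed.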
mllib/src/main/scala/org/apache/spark/ml/linalg/MatrixUDT.scala (111 additions, 0 deletions)

```scala
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.ml.linalg

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
import org.apache.spark.sql.catalyst.util.GenericArrayData
import org.apache.spark.sql.types._

/**
 * User-defined type for [[Matrix]] in [[mllib-local]] which allows easy interaction with SQL
 * via [[org.apache.spark.sql.Dataset]].
 */
private[ml] class MatrixUDT extends UserDefinedType[Matrix] {

  override def sqlType: StructType = {
    // type: 0 = sparse, 1 = dense
    // the dense matrix is built by numRows, numCols, values and isTransposed, all of which are
    // set as not nullable, except values since in the future, support for binary matrices might
    // be added for which values are not needed.
    // the sparse matrix needs colPtrs and rowIndices, which are set as
    // null, while building the dense matrix.
    StructType(Seq(
      StructField("type", ByteType, nullable = false),
      StructField("numRows", IntegerType, nullable = false),
      StructField("numCols", IntegerType, nullable = false),
      StructField("colPtrs", ArrayType(IntegerType, containsNull = false), nullable = true),
      StructField("rowIndices", ArrayType(IntegerType, containsNull = false), nullable = true),
      StructField("values", ArrayType(DoubleType, containsNull = false), nullable = true),
      StructField("isTransposed", BooleanType, nullable = false)
    ))
  }

  override def serialize(obj: Matrix): InternalRow = {
    val row = new GenericMutableRow(7)
    obj match {
      case sm: SparseMatrix =>
        row.setByte(0, 0)
        row.setInt(1, sm.numRows)
        row.setInt(2, sm.numCols)
        row.update(3, new GenericArrayData(sm.colPtrs.map(_.asInstanceOf[Any])))
        row.update(4, new GenericArrayData(sm.rowIndices.map(_.asInstanceOf[Any])))
        row.update(5, new GenericArrayData(sm.values.map(_.asInstanceOf[Any])))
        row.setBoolean(6, sm.isTransposed)

      case dm: DenseMatrix =>
        row.setByte(0, 1)
        row.setInt(1, dm.numRows)
        row.setInt(2, dm.numCols)
        row.setNullAt(3)
        row.setNullAt(4)
        row.update(5, new GenericArrayData(dm.values.map(_.asInstanceOf[Any])))
        row.setBoolean(6, dm.isTransposed)
    }
    row
  }

  override def deserialize(datum: Any): Matrix = {
    datum match {
      case row: InternalRow =>
        require(row.numFields == 7,
          s"MatrixUDT.deserialize given row with length ${row.numFields} but requires length == 7")
        val tpe = row.getByte(0)
        val numRows = row.getInt(1)
        val numCols = row.getInt(2)
        val values = row.getArray(5).toDoubleArray()
        val isTransposed = row.getBoolean(6)
        tpe match {
          case 0 =>
            val colPtrs = row.getArray(3).toIntArray()
            val rowIndices = row.getArray(4).toIntArray()
            new SparseMatrix(numRows, numCols, colPtrs, rowIndices, values, isTransposed)
          case 1 =>
            new DenseMatrix(numRows, numCols, values, isTransposed)
        }
    }
  }

  override def userClass: Class[Matrix] = classOf[Matrix]

  override def equals(o: Any): Boolean = {
    o match {
      case v: MatrixUDT => true
      case _ => false
    }
  }

  // see [SPARK-8647], this achieves the needed constant hash code without constant no.
  override def hashCode(): Int = classOf[MatrixUDT].getName.hashCode()

  override def typeName: String = "matrix"

  override def pyUDT: String = "pyspark.ml.linalg.MatrixUDT"

  private[spark] override def asNullable: MatrixUDT = this
}
```
mllib/src/main/scala/org/apache/spark/ml/linalg/VectorUDT.scala (98 additions, 0 deletions)

```scala
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.ml.linalg

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
import org.apache.spark.sql.catalyst.util.GenericArrayData
import org.apache.spark.sql.types._

/**
 * User-defined type for [[Vector]] in [[mllib-local]] which allows easy interaction with SQL
 * via [[org.apache.spark.sql.Dataset]].
 */
private[ml] class VectorUDT extends UserDefinedType[Vector] {

  override def sqlType: StructType = {
    // type: 0 = sparse, 1 = dense
    // We only use "values" for dense vectors, and "size", "indices", and "values" for sparse
    // vectors. The "values" field is nullable because we might want to add binary vectors later,
    // which uses "size" and "indices", but not "values".
    StructType(Seq(
      StructField("type", ByteType, nullable = false),
      StructField("size", IntegerType, nullable = true),
      StructField("indices", ArrayType(IntegerType, containsNull = false), nullable = true),
      StructField("values", ArrayType(DoubleType, containsNull = false), nullable = true)))
  }

  override def serialize(obj: Vector): InternalRow = {
    obj match {
      case SparseVector(size, indices, values) =>
        val row = new GenericMutableRow(4)
        row.setByte(0, 0)
        row.setInt(1, size)
        row.update(2, new GenericArrayData(indices.map(_.asInstanceOf[Any])))
        row.update(3, new GenericArrayData(values.map(_.asInstanceOf[Any])))
        row
      case DenseVector(values) =>
        val row = new GenericMutableRow(4)
        row.setByte(0, 1)
        row.setNullAt(1)
        row.setNullAt(2)
        row.update(3, new GenericArrayData(values.map(_.asInstanceOf[Any])))
        row
    }
  }

  override def deserialize(datum: Any): Vector = {
    datum match {
      case row: InternalRow =>
        require(row.numFields == 4,
          s"VectorUDT.deserialize given row with length ${row.numFields} but requires length == 4")
        val tpe = row.getByte(0)
        tpe match {
          case 0 =>
            val size = row.getInt(1)
            val indices = row.getArray(2).toIntArray()
            val values = row.getArray(3).toDoubleArray()
            new SparseVector(size, indices, values)
          case 1 =>
            val values = row.getArray(3).toDoubleArray()
            new DenseVector(values)
        }
    }
  }

  override def pyUDT: String = "pyspark.ml.linalg.VectorUDT"

  override def userClass: Class[Vector] = classOf[Vector]

  override def equals(o: Any): Boolean = {
    o match {
      case v: VectorUDT => true
      case _ => false
    }
  }

  // see [SPARK-8647], this achieves the needed constant hash code without constant no.
  override def hashCode(): Int = classOf[VectorUDT].getName.hashCode()

  override def typeName: String = "vector"

  private[spark] override def asNullable: VectorUDT = this
}
```
mllib/src/test/scala/org/apache/spark/ml/linalg/MatrixUDTSuite.scala (41 additions, 0 deletions)

```scala
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.ml.linalg

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.types._

class MatrixUDTSuite extends SparkFunSuite {

  test("preloaded MatrixUDT") {
    val dm1 = new DenseMatrix(2, 2, Array(0.9, 1.2, 2.3, 9.8))
    val dm2 = new DenseMatrix(3, 2, Array(0.0, 1.21, 2.3, 9.8, 9.0, 0.0))
    val dm3 = new DenseMatrix(0, 0, Array())
    val sm1 = dm1.toSparse
    val sm2 = dm2.toSparse
    val sm3 = dm3.toSparse

    for (m <- Seq(dm1, dm2, dm3, sm1, sm2, sm3)) {
      val udt = UDTRegistration.getUDTFor(m.getClass.getName).get.newInstance()
        .asInstanceOf[MatrixUDT]
      assert(m === udt.deserialize(udt.serialize(m)))
      assert(udt.typeName == "matrix")
      assert(udt.simpleString == "matrix")
    }
  }
}
```
mllib/src/test/scala/org/apache/spark/ml/linalg/VectorUDTSuite.scala (39 additions, 0 deletions)

```scala
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.ml.linalg

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.types._

class VectorUDTSuite extends SparkFunSuite {

  test("preloaded VectorUDT") {
    val dv1 = Vectors.dense(Array.empty[Double])
    val dv2 = Vectors.dense(1.0, 2.0)
    val sv1 = Vectors.sparse(2, Array.empty, Array.empty)
    val sv2 = Vectors.sparse(2, Array(1), Array(2.0))

    for (v <- Seq(dv1, dv2, sv1, sv2)) {
      val udt = UDTRegistration.getUDTFor(v.getClass.getName).get.newInstance()
        .asInstanceOf[VectorUDT]
      assert(v === udt.deserialize(udt.serialize(v)))
      assert(udt.typeName == "vector")
      assert(udt.simpleString == "vector")
    }
  }
}
```
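Both UDTs define `equals` to treat every instance of the same UDT class as equal, and derive `hashCode` from the class name (see SPARK-8647), since a UDT carries no per-instance state. The same idea in Python, as a hypothetical illustration (`MatrixUDTLike` is not a Spark class):

```python
class MatrixUDTLike:
    """Stateless type descriptor: all instances compare equal and share one hash,
    analogous to the equals/hashCode overrides in the Scala UDTs above."""

    def __eq__(self, other):
        # Any instance of this class is "the same type".
        return isinstance(other, MatrixUDTLike)

    def __hash__(self):
        # Constant per class, like hashing the class name in the Scala code.
        return hash(type(self).__qualname__)


assert MatrixUDTLike() == MatrixUDTLike()
assert hash(MatrixUDTLike()) == hash(MatrixUDTLike())
```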
Review comment:
Not part of this PR, but we might need to fix this. It seems that it boxes primitive arrays. I created https://issues.apache.org/jira/browse/SPARK-14850 to track the issue. cc: @cloud-fan