# [SPARK-14487][SQL] User Defined Type registration without SQLUserDefinedType annotation #12259
Changes from 25 commits
**`MatrixUDT` — new file (`@@ -0,0 +1,111 @@`):**

```scala
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.ml.linalg

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
import org.apache.spark.sql.catalyst.util.GenericArrayData
import org.apache.spark.sql.types._

/**
 * User-defined type for [[Matrix]] in [[mllib-local]] which allows easy interaction with SQL
 * via [[org.apache.spark.sql.Dataset]].
 */
private[ml] class MatrixUDT extends UserDefinedType[Matrix] {

  override def sqlType: StructType = {
    // type: 0 = sparse, 1 = dense
    // A dense matrix is built from numRows, numCols, values and isTransposed. All of these
    // fields are non-nullable except values: binary matrices, which need no values, might
    // be supported in the future.
    // A sparse matrix additionally needs colPtrs and rowIndices; both are nullable and are
    // left null when a dense matrix is stored.
    StructType(Seq(
      StructField("type", ByteType, nullable = false),
      StructField("numRows", IntegerType, nullable = false),
      StructField("numCols", IntegerType, nullable = false),
      StructField("colPtrs", ArrayType(IntegerType, containsNull = false), nullable = true),
      StructField("rowIndices", ArrayType(IntegerType, containsNull = false), nullable = true),
      StructField("values", ArrayType(DoubleType, containsNull = false), nullable = true),
      StructField("isTransposed", BooleanType, nullable = false)
    ))
  }

  override def serialize(obj: Matrix): InternalRow = {
    val row = new GenericMutableRow(7)
    obj match {
      case sm: SparseMatrix =>
        row.setByte(0, 0)
        row.setInt(1, sm.numRows)
        row.setInt(2, sm.numCols)
        row.update(3, new GenericArrayData(sm.colPtrs.map(_.asInstanceOf[Any])))
        row.update(4, new GenericArrayData(sm.rowIndices.map(_.asInstanceOf[Any])))
        row.update(5, new GenericArrayData(sm.values.map(_.asInstanceOf[Any])))
        row.setBoolean(6, sm.isTransposed)

      case dm: DenseMatrix =>
        row.setByte(0, 1)
        row.setInt(1, dm.numRows)
        row.setInt(2, dm.numCols)
        row.setNullAt(3)
        row.setNullAt(4)
        row.update(5, new GenericArrayData(dm.values.map(_.asInstanceOf[Any])))
        row.setBoolean(6, dm.isTransposed)
    }
    row
  }

  override def deserialize(datum: Any): Matrix = {
    datum match {
      case row: InternalRow =>
        require(row.numFields == 7,
          s"MatrixUDT.deserialize given row with length ${row.numFields} but requires length == 7")
        val tpe = row.getByte(0)
        val numRows = row.getInt(1)
        val numCols = row.getInt(2)
        val values = row.getArray(5).toDoubleArray()
        val isTransposed = row.getBoolean(6)
        tpe match {
          case 0 =>
            val colPtrs = row.getArray(3).toIntArray()
            val rowIndices = row.getArray(4).toIntArray()
            new SparseMatrix(numRows, numCols, colPtrs, rowIndices, values, isTransposed)
          case 1 =>
            new DenseMatrix(numRows, numCols, values, isTransposed)
        }
    }
  }

  override def userClass: Class[Matrix] = classOf[Matrix]

  override def equals(o: Any): Boolean = {
    o match {
      case v: MatrixUDT => true
      case _ => false
    }
  }

  // See [SPARK-8647]: all MatrixUDT instances must share one constant hash code; deriving
  // it from the class name avoids hard-coding a magic number.
  override def hashCode(): Int = classOf[MatrixUDT].getName.hashCode()

  override def typeName: String = "matrix"

  override def pyUDT: String = "pyspark.ml.linalg.MatrixUDT"

  private[spark] override def asNullable: MatrixUDT = this
}
```
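To make the seven-field layout above concrete, here is roughly what `serialize` produces for a small dense matrix. This is a sketch, not part of the PR; it would have to live inside the `org.apache.spark.ml.linalg` package, since `MatrixUDT` is `private[ml]`, and it only uses `InternalRow` accessors already exercised by the diff.

```scala
val dm = new DenseMatrix(2, 2, Array(1.0, 2.0, 3.0, 4.0))
val row = new MatrixUDT().serialize(dm)
// Expected layout: (type=1, numRows=2, numCols=2, colPtrs=null, rowIndices=null,
//                   values=[1.0, 2.0, 3.0, 4.0], isTransposed=false)
assert(row.getByte(0) == 1)                 // dense marker
assert(row.isNullAt(3) && row.isNullAt(4))  // sparse-only fields left null
assert(row.getArray(5).toDoubleArray().sameElements(Array(1.0, 2.0, 3.0, 4.0)))
```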
**`VectorUDT` — new file (`@@ -0,0 +1,98 @@`):**

```scala
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.ml.linalg

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
import org.apache.spark.sql.catalyst.util.GenericArrayData
import org.apache.spark.sql.types._

/**
 * User-defined type for [[Vector]] in [[mllib-local]] which allows easy interaction with SQL
 * via [[org.apache.spark.sql.Dataset]].
 */
private[ml] class VectorUDT extends UserDefinedType[Vector] {

  override def sqlType: StructType = {
    // type: 0 = sparse, 1 = dense
    // Only "values" is used for dense vectors, while "size", "indices", and "values" are
    // used for sparse vectors. The "values" field is nullable because binary vectors,
    // which would use only "size" and "indices", might be added later.
    StructType(Seq(
      StructField("type", ByteType, nullable = false),
      StructField("size", IntegerType, nullable = true),
      StructField("indices", ArrayType(IntegerType, containsNull = false), nullable = true),
      StructField("values", ArrayType(DoubleType, containsNull = false), nullable = true)))
  }

  override def serialize(obj: Vector): InternalRow = {
    obj match {
      case SparseVector(size, indices, values) =>
        val row = new GenericMutableRow(4)
        row.setByte(0, 0)
        row.setInt(1, size)
        row.update(2, new GenericArrayData(indices.map(_.asInstanceOf[Any])))
        row.update(3, new GenericArrayData(values.map(_.asInstanceOf[Any])))
        row
      case DenseVector(values) =>
        val row = new GenericMutableRow(4)
        row.setByte(0, 1)
        row.setNullAt(1)
        row.setNullAt(2)
        row.update(3, new GenericArrayData(values.map(_.asInstanceOf[Any])))
        row
    }
  }

  override def deserialize(datum: Any): Vector = {
    datum match {
      case row: InternalRow =>
        require(row.numFields == 4,
          s"VectorUDT.deserialize given row with length ${row.numFields} but requires length == 4")
        val tpe = row.getByte(0)
        tpe match {
          case 0 =>
            val size = row.getInt(1)
            val indices = row.getArray(2).toIntArray()
            val values = row.getArray(3).toDoubleArray()
            new SparseVector(size, indices, values)
          case 1 =>
            val values = row.getArray(3).toDoubleArray()
            new DenseVector(values)
        }
    }
  }

  override def pyUDT: String = "pyspark.ml.linalg.VectorUDT"

  override def userClass: Class[Vector] = classOf[Vector]

  override def equals(o: Any): Boolean = {
    o match {
      case v: VectorUDT => true
      case _ => false
    }
  }

  // See [SPARK-8647]: all VectorUDT instances must share one constant hash code; deriving
  // it from the class name avoids hard-coding a magic number.
  override def hashCode(): Int = classOf[VectorUDT].getName.hashCode()

  override def typeName: String = "vector"

  private[spark] override def asNullable: VectorUDT = this
}
```
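For context on the PR title: neither `Matrix` nor `Vector` carries a `@SQLUserDefinedType` annotation, so these UDTs must be attached through the registration mechanism SPARK-14487 introduces. A minimal sketch of that wiring is below; it assumes a `register(userClassName, udtClassName)` entry point keyed by fully qualified class names, and since the registry is Spark-internal, calls like these would sit inside Spark rather than in user code.

```scala
import org.apache.spark.sql.types.UDTRegistration

// Map user class -> UDT class by name, so Catalyst can look the UDT up lazily
// when it encounters a Matrix or Vector in a Dataset schema.
UDTRegistration.register("org.apache.spark.ml.linalg.Vector",
  "org.apache.spark.ml.linalg.VectorUDT")
UDTRegistration.register("org.apache.spark.ml.linalg.Matrix",
  "org.apache.spark.ml.linalg.MatrixUDT")
```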
**`MatrixUDTSuite` — new file (`@@ -0,0 +1,65 @@`):**

```scala
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.ml.linalg

import scala.beans.{BeanInfo, BeanProperty}

import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.ml.linalg._
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Row
import org.apache.spark.sql.types._

@BeanInfo
private[ml] case class MyMatrixPoint(
    @BeanProperty label: Double,
    @BeanProperty matrix: Matrix)

class MatrixUDTSuite extends SparkFunSuite with MLlibTestSparkContext {
  import testImplicits._

  test("preloaded MatrixUDT") {
    val dm1 = new DenseMatrix(2, 2, Array(0.9, 1.2, 2.3, 9.8))
    val dm2 = new DenseMatrix(3, 2, Array(0.0, 1.21, 2.3, 9.8, 9.0, 0.0))
    val dm3 = new DenseMatrix(0, 0, Array())
    val sm1 = dm1.toSparse
    val sm2 = dm2.toSparse
    val sm3 = dm3.toSparse

    val matrixDF = Seq(
      MyMatrixPoint(1.0, dm1),
      MyMatrixPoint(2.0, dm2),
      MyMatrixPoint(3.0, dm3),
      MyMatrixPoint(4.0, sm1),
      MyMatrixPoint(5.0, sm2),
      MyMatrixPoint(6.0, sm3)).toDF()

    val labels = matrixDF.select('label).as[Double].collect()
    assert(labels.size === 6)
    assert(labels.sorted === Array(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))

    val matrices = matrixDF.select('matrix).rdd.map { case Row(m: Matrix) => m }.collect()
    assert(matrices.contains(dm1))
    assert(matrices.contains(dm2))
    assert(matrices.contains(dm3))
    assert(matrices.contains(sm1))
    assert(matrices.contains(sm2))
    assert(matrices.contains(sm3))
  }
}
```
**`VectorUDTSuite` — new file (`@@ -0,0 +1,59 @@`):**

```scala
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.ml.linalg

import scala.beans.{BeanInfo, BeanProperty}

import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.ml.linalg._
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Row
import org.apache.spark.sql.types._

@BeanInfo
private[ml] case class MyVectorPoint(
    @BeanProperty label: Double,
    @BeanProperty vector: Vector)

class VectorUDTSuite extends SparkFunSuite with MLlibTestSparkContext {
  import testImplicits._

  test("preloaded VectorUDT") {
    val dv0 = Vectors.dense(Array.empty[Double])
    val dv1 = Vectors.dense(1.0, 2.0)
    val sv0 = Vectors.sparse(2, Array.empty, Array.empty)
    val sv1 = Vectors.sparse(2, Array(1), Array(2.0))

    val vectorDF = Seq(
      MyVectorPoint(1.0, dv0),
      MyVectorPoint(2.0, dv1),
      MyVectorPoint(3.0, sv0),
      MyVectorPoint(4.0, sv1)).toDF()

    val labels = vectorDF.select('label).as[Double].collect()
    assert(labels.size === 4)
    assert(labels.sorted === Array(1.0, 2.0, 3.0, 4.0))

    val vectors = vectorDF.select('vector).rdd.map { case Row(v: Vector) => v }.collect()
    assert(vectors.contains(dv0))
    assert(vectors.contains(dv1))
    assert(vectors.contains(sv0))
    assert(vectors.contains(sv1))
  }
}
```
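One property neither suite asserts directly is that `toDF()` gives the UDT column the UDT itself as its Catalyst data type, rather than the underlying struct from `sqlType`. A small sketch of such a check, assuming `StructType`'s lookup-by-field-name and that the encoder resolves the registered UDT:

```scala
// Hypothetical extra assertion inside the "preloaded VectorUDT" test:
val field = vectorDF.schema("vector")
assert(field.dataType.isInstanceOf[VectorUDT]) // the UDT, not its struct encoding
assert(field.dataType.typeName === "vector")   // from VectorUDT.typeName
```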
**`MLlibTestSparkContext` — modified (`@@ -24,14 +24,18 @@`):**

```diff
@@ -24,14 +24,18 @@ import org.scalatest.Suite

 import org.apache.spark.{SparkConf, SparkContext}
 import org.apache.spark.ml.util.TempDirectory
-import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.{SQLContext, SQLImplicits}
 import org.apache.spark.util.Utils

 trait MLlibTestSparkContext extends TempDirectory { self: Suite =>
   @transient var sc: SparkContext = _
   @transient var sqlContext: SQLContext = _
   @transient var checkpointDir: String = _

+  protected object testImplicits extends SQLImplicits {
+    protected override def _sqlContext: SQLContext = self.sqlContext
+  }
+
   override def beforeAll() {
     super.beforeAll()
     val conf = new SparkConf()
```
**Review comment:** Not part of this PR, but we might need to fix this. It seems that it boxes primitive arrays. I created https://issues.apache.org/jira/browse/SPARK-14850 to track the issue. cc: @cloud-fan
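To illustrate the reviewer's point: both `serialize` implementations call `values.map(_.asInstanceOf[Any])` on a primitive array, which allocates a boxed `java.lang.Double` per element before wrapping the result in `GenericArrayData`. A minimal demonstration of the boxing itself (variable names here are illustrative only; the eventual fix is what SPARK-14850 tracks):

```scala
val primitive: Array[Double] = Array(1.0, 2.0, 3.0)        // packed double[] on the JVM
val boxed: Array[Any] = primitive.map(_.asInstanceOf[Any])  // Object[] of java.lang.Double

// Each element now carries object-header overhead plus a pointer indirection,
// which adds up for large vectors and matrices serialized this way.
assert(boxed.length == primitive.length)
```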