[SPARK-12438][SQL] Add SQLUserDefinedType support for encoder #10390
Changes from 3 commits. Commits: 089ab18, 51b76b1, d01c661, 42303d2, 72446f1, c76ea47, 62fa738, 79e6ec9, 80a3f7b.
The diff is against `UserDefinedTypeSuite`. The first hunk updates the imports:

```scala
// @@ -17,12 +17,21 @@
package org.apache.spark.sql

import java.util.concurrent.ConcurrentMap

import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayData}

import scala.beans.{BeanInfo, BeanProperty}
import scala.reflect.runtime.universe.TypeTag

import com.google.common.collect.MapMaker

import org.apache.spark.rdd.RDD
// import org.apache.spark.sql.catalyst.CatalystTypeConverters  -- replaced by the combined import below
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection}
import org.apache.spark.sql.catalyst.encoders._
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.execution.datasources.parquet.ParquetTest
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSQLContext
```
The second hunk adds an outer-scope registry and a new test that round-trips a UDT-typed case class through the encoder:

```scala
// @@ -89,6 +98,30 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetTest
    assert(featuresArrays.contains(new MyDenseVector(Array(0.2, 2.0))))
  }

  // Outer-scope registry handed to ExpressionEncoder.resolve below.
  private val outers: ConcurrentMap[String, AnyRef] = new MapMaker().weakValues().makeMap()
  outers.put(getClass.getName, this)

  test("user type with ScalaReflection") {
    val points = Seq(
      MyLabeledPoint(1.0, new MyDenseVector(Array(0.1, 1.0))),
      MyLabeledPoint(0.0, new MyDenseVector(Array(0.2, 2.0))))

    val schema = ScalaReflection.schemaFor[MyLabeledPoint].dataType.asInstanceOf[StructType]
    val attributeSeq = schema.toAttributes

    val pointEncoder = encoderFor[MyLabeledPoint]
    val unsafeRows = points.map(pointEncoder.toRow(_).copy())

    val df = DataFrame(sqlContext, LocalRelation(attributeSeq, unsafeRows))
    val decodedPoints = df.collect()
    points.zip(decodedPoints).foreach { case (p, p2) =>
      assert(p.label == p2(0) && p.features == p2(1))
    }

    val boundEncoder = pointEncoder.resolve(attributeSeq, outers).bind(attributeSeq)
    val point = MyLabeledPoint(1.0, new MyDenseVector(Array(0.1, 1.0)))
    assert(boundEncoder.fromRow(boundEncoder.toRow(point)) === point)
  }

  test("UDTs and UDFs") {
    sqlContext.udf.register("testType", (d: MyDenseVector) => d.isInstanceOf[MyDenseVector])
    pointsRDD.registerTempTable("points")
```
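For context, `MyDenseVector` and `MyLabeledPoint` are defined elsewhere in the suite and are not part of this hunk; what matters for this PR is that they are backed by a `UserDefinedType` registered through the `@SQLUserDefinedType` annotation, which is what the encoder path now has to recognize. The sketch below shows the rough shape of such definitions, not the actual ones: the `ExampleVector`, `ExampleVectorUDT`, and `ExampleLabeledPoint` names are illustrative, and since the UDT API was a developer-facing/internal API at the time, code like this is assumed to sit inside Spark's own source tree, as the suite itself does.

```scala
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
import org.apache.spark.sql.types._

// Hypothetical user class; the annotation tells Catalyst which UDT describes it.
@SQLUserDefinedType(udt = classOf[ExampleVectorUDT])
class ExampleVector(val data: Array[Double]) extends Serializable {
  // Value equality so assertions comparing decoded objects behave as expected.
  override def equals(other: Any): Boolean = other match {
    case v: ExampleVector => java.util.Arrays.equals(data, v.data)
    case _ => false
  }
  override def hashCode(): Int = java.util.Arrays.hashCode(data)
}

// The UDT maps the user class onto a Catalyst type (here an array of doubles).
class ExampleVectorUDT extends UserDefinedType[ExampleVector] {
  override def sqlType: DataType = ArrayType(DoubleType, containsNull = false)

  override def serialize(obj: Any): ArrayData = obj match {
    case v: ExampleVector => new GenericArrayData(v.data.map(_.asInstanceOf[Any]))
  }

  override def deserialize(datum: Any): ExampleVector = datum match {
    case values: ArrayData => new ExampleVector(values.toDoubleArray())
  }

  override def userClass: Class[ExampleVector] = classOf[ExampleVector]
}

// A Product whose field is a UDT-backed class: the shape that
// ScalaReflection.schemaFor and encoderFor have to handle in the test above.
case class ExampleLabeledPoint(label: Double, features: ExampleVector)
```

With a definition of this shape, `ScalaReflection.schemaFor` can map the `features` field onto the UDT's `sqlType`, which is the behavior the new test exercises.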
Review discussion:

- I think this is worth another JIRA. Is `ScalaReflection` the only place that may use UDT in `Cast`?
- I think so, as it is a special use case. I will open another JIRA for it.
- This fix is submitted as PR #10410.