File tree Expand file tree Collapse file tree 4 files changed +16
-3
lines changed
catalyst/src/main/scala/org/apache/spark/sql/catalyst
main/scala/org/apache/spark/sql
test/scala/org/apache/spark/sql Expand file tree Collapse file tree 4 files changed +16
-3
lines changed Original file line number Diff line number Diff line change @@ -72,6 +72,11 @@ trait ScalaReflection {
7272 case (d : BigDecimal , _) => Decimal (d)
7373 case (d : java.math.BigDecimal , _) => Decimal (d)
7474 case (d : java.sql.Date , _) => DateUtils .fromJavaDate(d)
75+ case (r : Row , structType : StructType ) =>
76+ new GenericRow (
77+ r.toSeq.zip(structType.fields).map { case (elem, field) =>
78+ convertToCatalyst(elem, field.dataType)
79+ }.toArray)
7580 case (other, _) => other
7681 }
7782
Original file line number Diff line number Diff line change @@ -394,7 +394,8 @@ class SQLContext(@transient val sparkContext: SparkContext)
394394 def createDataFrame (rowRDD : RDD [Row ], schema : StructType ): DataFrame = {
395395 // TODO: use MutableProjection when rowRDD is another DataFrame and the applied
396396 // schema differs from the existing schema on any field data type.
397- val logicalPlan = LogicalRDD (schema.toAttributes, rowRDD)(self)
397+ val catalystRows = rowRDD.map(ScalaReflection .convertToCatalyst(_, schema).asInstanceOf [Row ])
398+ val logicalPlan = LogicalRDD (schema.toAttributes, catalystRows)(self)
398399 DataFrame (this , logicalPlan)
399400 }
400401
Original file line number Diff line number Diff line change @@ -28,7 +28,7 @@ import org.apache.spark.sql.types._
2828 * @param y y coordinate
2929 */
3030@ SQLUserDefinedType (udt = classOf [ExamplePointUDT ])
31- private [sql] class ExamplePoint (val x : Double , val y : Double )
31+ private [sql] class ExamplePoint (val x : Double , val y : Double ) extends Serializable
3232
3333/**
3434 * User-defined type for [[ExamplePoint ]].
Original file line number Diff line number Diff line change @@ -21,7 +21,7 @@ import scala.language.postfixOps
2121
2222import org .apache .spark .sql .functions ._
2323import org .apache .spark .sql .types ._
24- import org .apache .spark .sql .test .TestSQLContext
24+ import org .apache .spark .sql .test .{ ExamplePointUDT , ExamplePoint , TestSQLContext }
2525import org .apache .spark .sql .test .TestSQLContext .logicalPlanToSparkQuery
2626import org .apache .spark .sql .test .TestSQLContext .implicits ._
2727import org .apache .spark .sql .test .TestSQLContext .sql
@@ -506,4 +506,11 @@ class DataFrameSuite extends QueryTest {
506506 testData.select($" *" ).show()
507507 testData.select($" *" ).show(1000 )
508508 }
509+
510+ test(" createDataFrame(RDD[Row], StructType) should convert UDTs (SPARK-6672)" ) {
511+ val rowRDD = TestSQLContext .sparkContext.parallelize(Seq (Row (new ExamplePoint (1.0 , 2.0 ))))
512+ val schema = StructType (Array (StructField (" point" , new ExamplePointUDT (), false )))
513+ val df = TestSQLContext .createDataFrame(rowRDD, schema)
514+ df.rdd.collect()
515+ }
509516}
You can’t perform that action at this time.
0 commit comments