
Commit 424e987

mengxr authored and liancheng committed
[SPARK-6672][SQL] convert row to catalyst in createDataFrame(RDD[Row], ...)
We assume that `RDD[Row]` contains Scala types, so we need to convert them into Catalyst types in `createDataFrame`.

liancheng

Author: Xiangrui Meng <[email protected]>

Closes #5329 from mengxr/SPARK-6672 and squashes the following commits:

2d52644 [Xiangrui Meng] set needsConversion = false in jsonRDD
06896e4 [Xiangrui Meng] add createDataFrame without conversion
4a3767b [Xiangrui Meng] convert Row to catalyst
1 parent 6562787 commit 424e987

7 files changed: +37 -8 lines
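For context, a minimal sketch of the scenario this commit fixes, modeled on the regression test added to DataFrameSuite below. ExamplePoint and ExamplePointUDT are test-only helpers in org.apache.spark.sql.test; before this change, the UDT instance stored in the Row reached execution unconverted and the job failed.

import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{StructField, StructType}
import org.apache.spark.sql.test.{ExamplePoint, ExamplePointUDT, TestSQLContext}

// A user-built RDD[Row] whose cell holds a Scala-side object (a UDT instance),
// not its Catalyst representation.
val rowRDD = TestSQLContext.sparkContext.parallelize(
  Seq(Row(new ExamplePoint(1.0, 2.0))))
val schema = StructType(Array(StructField("point", new ExamplePointUDT(), false)))

// With this patch, createDataFrame runs each Row through
// ScalaReflection.convertToCatalyst before building the logical plan,
// so collecting no longer fails on the unconverted UDT value.
val df = TestSQLContext.createDataFrame(rowRDD, schema)
df.rdd.collect()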

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala

Lines changed: 5 additions & 0 deletions
@@ -72,6 +72,11 @@ trait ScalaReflection {
     case (d: BigDecimal, _) => Decimal(d)
     case (d: java.math.BigDecimal, _) => Decimal(d)
     case (d: java.sql.Date, _) => DateUtils.fromJavaDate(d)
+    case (r: Row, structType: StructType) =>
+      new GenericRow(
+        r.toSeq.zip(structType.fields).map { case (elem, field) =>
+          convertToCatalyst(elem, field.dataType)
+        }.toArray)
     case (other, _) => other
   }
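The new case recurses into nested structs: each element of an incoming Row is converted according to its field's declared DataType. A hedged sketch of the effect, using the 1.3-era types shown in this diff (the nested schema here is illustrative, not from the patch):

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.types._

// Illustrative schema: an outer struct containing an inner struct.
val inner = StructType(Seq(StructField("amount", DecimalType.Unlimited, nullable = false)))
val outer = StructType(Seq(StructField("payment", inner, nullable = false)))

// A Scala-side row: BigDecimal is a Scala type, not a Catalyst Decimal.
val scalaRow = Row(Row(BigDecimal("12.50")))

// The new case matches (r: Row, structType: StructType), zips the row with the
// struct's fields, and converts each element by its field's dataType, turning
// the nested BigDecimal into a Catalyst Decimal inside a GenericRow.
val catalystRow = ScalaReflection.convertToCatalyst(scalaRow, outer)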

sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala

Lines changed: 2 additions & 1 deletion
@@ -904,7 +904,8 @@ class DataFrame private[sql](
    */
   override def repartition(numPartitions: Int): DataFrame = {
     sqlContext.createDataFrame(
-      queryExecution.toRdd.map(_.copy()).repartition(numPartitions), schema)
+      queryExecution.toRdd.map(_.copy()).repartition(numPartitions),
+      schema, needsConversion = false)
   }

   /**

sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala

Lines changed: 17 additions & 3 deletions
@@ -392,9 +392,23 @@ class SQLContext(@transient val sparkContext: SparkContext)
    */
   @DeveloperApi
   def createDataFrame(rowRDD: RDD[Row], schema: StructType): DataFrame = {
+    createDataFrame(rowRDD, schema, needsConversion = true)
+  }
+
+  /**
+   * Creates a DataFrame from an RDD[Row]. User can specify whether the input rows should be
+   * converted to Catalyst rows.
+   */
+  private[sql]
+  def createDataFrame(rowRDD: RDD[Row], schema: StructType, needsConversion: Boolean) = {
     // TODO: use MutableProjection when rowRDD is another DataFrame and the applied
     // schema differs from the existing schema on any field data type.
-    val logicalPlan = LogicalRDD(schema.toAttributes, rowRDD)(self)
+    val catalystRows = if (needsConversion) {
+      rowRDD.map(ScalaReflection.convertToCatalyst(_, schema).asInstanceOf[Row])
+    } else {
+      rowRDD
+    }
+    val logicalPlan = LogicalRDD(schema.toAttributes, catalystRows)(self)
     DataFrame(this, logicalPlan)
   }

@@ -604,7 +618,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
       JsonRDD.nullTypeToStringType(
         JsonRDD.inferSchema(json, 1.0, columnNameOfCorruptJsonRecord)))
     val rowRDD = JsonRDD.jsonStringToRow(json, appliedSchema, columnNameOfCorruptJsonRecord)
-    createDataFrame(rowRDD, appliedSchema)
+    createDataFrame(rowRDD, appliedSchema, needsConversion = false)
   }

   /**
@@ -633,7 +647,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
       JsonRDD.nullTypeToStringType(
         JsonRDD.inferSchema(json, samplingRatio, columnNameOfCorruptJsonRecord))
     val rowRDD = JsonRDD.jsonStringToRow(json, appliedSchema, columnNameOfCorruptJsonRecord)
-    createDataFrame(rowRDD, appliedSchema)
+    createDataFrame(rowRDD, appliedSchema, needsConversion = false)
   }

   /**
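The split keeps the public API safe while letting internal call sites skip the per-row map when their rows are already Catalyst rows. A sketch of the two paths (the RDD names are hypothetical, and the three-argument overload is private[sql], so only Spark SQL's own code can take the second path):

// Public path: rows may hold external Scala types (BigDecimal, java.sql.Date,
// UDT instances, nested Rows), so each row is converted.
val df1 = sqlContext.createDataFrame(userRowRDD, schema)  // needsConversion = true

// Internal path, as used by repartition, jsonRDD, and the data-source insert
// commands in this diff: the rows already came out of query execution or a
// parser that emits Catalyst values, so conversion is skipped.
val df2 = sqlContext.createDataFrame(catalystRowRDD, schema, needsConversion = false)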

sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala

Lines changed: 2 additions & 1 deletion
@@ -122,7 +122,8 @@ private[sql] class DefaultSource
         val df =
           sqlContext.createDataFrame(
             data.queryExecution.toRdd,
-            data.schema.asNullable)
+            data.schema.asNullable,
+            needsConversion = false)
         val createdRelation =
           createRelation(sqlContext, parameters, df.schema).asInstanceOf[ParquetRelation2]
         createdRelation.insert(df, overwrite = mode == SaveMode.Overwrite)

sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala

Lines changed: 2 additions & 1 deletion
@@ -31,7 +31,8 @@ private[sql] case class InsertIntoDataSource(
     val relation = logicalRelation.relation.asInstanceOf[InsertableRelation]
     val data = DataFrame(sqlContext, query)
     // Apply the schema of the existing table to the new data.
-    val df = sqlContext.createDataFrame(data.queryExecution.toRdd, logicalRelation.schema)
+    val df = sqlContext.createDataFrame(
+      data.queryExecution.toRdd, logicalRelation.schema, needsConversion = false)
     relation.insert(df, overwrite)

     // Invalidate the cache.

sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala

Lines changed: 1 addition & 1 deletion
@@ -28,7 +28,7 @@ import org.apache.spark.sql.types._
  * @param y y coordinate
  */
 @SQLUserDefinedType(udt = classOf[ExamplePointUDT])
-private[sql] class ExamplePoint(val x: Double, val y: Double)
+private[sql] class ExamplePoint(val x: Double, val y: Double) extends Serializable

 /**
  * User-defined type for [[ExamplePoint]].

sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala

Lines changed: 8 additions & 1 deletion
@@ -21,7 +21,7 @@ import scala.language.postfixOps

 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
-import org.apache.spark.sql.test.TestSQLContext
+import org.apache.spark.sql.test.{ExamplePointUDT, ExamplePoint, TestSQLContext}
 import org.apache.spark.sql.test.TestSQLContext.logicalPlanToSparkQuery
 import org.apache.spark.sql.test.TestSQLContext.implicits._
 import org.apache.spark.sql.test.TestSQLContext.sql

@@ -506,4 +506,11 @@ class DataFrameSuite extends QueryTest {
     testData.select($"*").show()
     testData.select($"*").show(1000)
   }
+
+  test("createDataFrame(RDD[Row], StructType) should convert UDTs (SPARK-6672)") {
+    val rowRDD = TestSQLContext.sparkContext.parallelize(Seq(Row(new ExamplePoint(1.0, 2.0))))
+    val schema = StructType(Array(StructField("point", new ExamplePointUDT(), false)))
+    val df = TestSQLContext.createDataFrame(rowRDD, schema)
+    df.rdd.collect()
+  }
 }
