@@ -32,6 +32,7 @@ import org.apache.spark.mapred.SparkHadoopMapRedUtil
 import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil
 import org.apache.spark.sql.catalyst.CatalystTypeConverters
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.codegen.GenerateProjection
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.execution.RunnableCommand
 import org.apache.spark.sql.{DataFrame, SQLContext, SaveMode}
@@ -102,14 +103,18 @@ private[sql] case class InsertIntoFSBasedRelation(
       } else {
         val writerContainer = new DynamicPartitionWriterContainer(
           relation, job, partitionColumns, "__HIVE_DEFAULT_PARTITION__")
-        insertWithDynamicPartitions(writerContainer, df, partitionColumns)
+        insertWithDynamicPartitions(sqlContext, writerContainer, df, partitionColumns)
       }
     }
 
     Seq.empty[Row]
   }
 
   private def insert(writerContainer: BaseWriterContainer, df: DataFrame): Unit = {
+    // Uses local vals for serialization
+    val needsConversion = relation.needConversion
+    val dataSchema = relation.dataSchema
+
     try {
       writerContainer.driverSideSetup()
       df.sqlContext.sparkContext.runJob(df.queryExecution.executedPlan.execute(), writeRows _)
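
The local vals introduced above keep the `writeRows` closure from capturing the whole command (and its relation) just to read `needConversion` and `dataSchema`. A minimal sketch of the general Spark pattern, with a hypothetical `Outer` class standing in for the relation:

```scala
import org.apache.spark.rdd.RDD

// Hypothetical stand-in for the relation: not Serializable.
class Outer(val suffix: String) {
  def bad(rdd: RDD[Int]): RDD[String] =
    rdd.map(i => suffix + i)        // `suffix` is really `this.suffix`, so the whole
                                    // Outer instance must be shipped with the task
  def good(rdd: RDD[Int]): RDD[String] = {
    val localSuffix = suffix        // copy the value into a local val first
    rdd.map(i => localSuffix + i)   // the closure now captures only a String
  }
}
```
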
@@ -124,8 +129,8 @@ private[sql] case class InsertIntoFSBasedRelation(
       writerContainer.executorSideSetup(taskContext)
 
       try {
-        if (relation.needConversion) {
-          val converter = CatalystTypeConverters.createToScalaConverter(relation.dataSchema)
+        if (needsConversion) {
+          val converter = CatalystTypeConverters.createToScalaConverter(dataSchema)
           while (iterator.hasNext) {
             val row = converter(iterator.next()).asInstanceOf[Row]
             writerContainer.outputWriterForRow(row).write(row)
@@ -145,9 +150,13 @@ private[sql] case class InsertIntoFSBasedRelation(
   }
 
   private def insertWithDynamicPartitions(
+      sqlContext: SQLContext,
       writerContainer: BaseWriterContainer,
       df: DataFrame,
       partitionColumns: Array[String]): Unit = {
+    // Uses a local val for serialization
+    val needsConversion = relation.needConversion
+    val dataSchema = relation.dataSchema
 
     require(
       df.schema == relation.schema,
@@ -156,34 +165,21 @@ private[sql] case class InsertIntoFSBasedRelation(
         |Relation schema: ${relation.schema}
       """.stripMargin)
 
-    val sqlContext = df.sqlContext
-
-    val (partitionRDD, dataRDD) = {
-      val fieldNames = relation.schema.fieldNames
-      val dataCols = fieldNames.filterNot(partitionColumns.contains)
-      val df = sqlContext.createDataFrame(
-        DataFrame(sqlContext, query).queryExecution.toRdd,
-        relation.schema,
-        needsConversion = false)
-
-      val partitionColumnsInSpec = relation.partitionSpec.partitionColumns.map(_.name)
-      require(
-        partitionColumnsInSpec.sameElements(partitionColumns),
-        s"""Partition columns mismatch.
-          |Expected: ${partitionColumnsInSpec.mkString(", ")}
-          |Actual: ${partitionColumns.mkString(", ")}
-        """.stripMargin)
-
-      val partitionDF = df.select(partitionColumns.head, partitionColumns.tail: _*)
-      val dataDF = df.select(dataCols.head, dataCols.tail: _*)
+    val partitionColumnsInSpec = relation.partitionColumns.fieldNames
+    require(
+      partitionColumnsInSpec.sameElements(partitionColumns),
+      s"""Partition columns mismatch.
+        |Expected: ${partitionColumnsInSpec.mkString(", ")}
+        |Actual: ${partitionColumns.mkString(", ")}
+      """.stripMargin)
 
-      partitionDF.queryExecution.executedPlan.execute() ->
-        dataDF.queryExecution.executedPlan.execute()
-    }
+    val output = df.queryExecution.executedPlan.output
+    val (partitionOutput, dataOutput) = output.partition(a => partitionColumns.contains(a.name))
+    val codegenEnabled = df.sqlContext.conf.codegenEnabled
 
     try {
       writerContainer.driverSideSetup()
-      sqlContext.sparkContext.runJob(partitionRDD.zip(dataRDD), writeRows _)
+      df.sqlContext.sparkContext.runJob(df.queryExecution.executedPlan.execute(), writeRows _)
       writerContainer.commitJob()
       relation.refresh()
     } catch { case cause: Throwable =>
@@ -192,20 +188,44 @@ private[sql] case class InsertIntoFSBasedRelation(
       throw new SparkException("Job aborted.", cause)
     }
 
-    def writeRows(taskContext: TaskContext, iterator: Iterator[(Row, Row)]): Unit = {
+    def writeRows(taskContext: TaskContext, iterator: Iterator[Row]): Unit = {
       writerContainer.executorSideSetup(taskContext)
 
-      try {
+      val partitionProj = newProjection(codegenEnabled, partitionOutput, output)
+      val dataProj = newProjection(codegenEnabled, dataOutput, output)
+
+      if (needsConversion) {
+        val converter = CatalystTypeConverters.createToScalaConverter(dataSchema)
         while (iterator.hasNext) {
-          val (partitionValues, data) = iterator.next()
-          writerContainer.outputWriterForRow(partitionValues).write(data)
+          val row = converter(iterator.next()).asInstanceOf[Row]
+          val partitionPart = partitionProj(row)
+          val dataPart = dataProj(row)
+          writerContainer.outputWriterForRow(partitionPart).write(dataPart)
+        }
+      } else {
+        while (iterator.hasNext) {
+          val row = iterator.next()
+          val partitionPart = partitionProj(row)
+          val dataPart = dataProj(row)
+          writerContainer.outputWriterForRow(partitionPart).write(dataPart)
         }
-
-        writerContainer.commitTask()
-      } catch { case cause: Throwable =>
-        writerContainer.abortTask()
-        throw new SparkException("Task failed while writing rows.", cause)
       }
+
+      writerContainer.commitTask()
+    }
+  }
+
+  // This is copied from SparkPlan, probably should move this to a more general place.
+  private def newProjection(
+      codegenEnabled: Boolean,
+      expressions: Seq[Expression],
+      inputSchema: Seq[Attribute]): Projection = {
+    log.debug(
+      s"Creating Projection: $expressions, inputSchema: $inputSchema, codegen: $codegenEnabled")
+    if (codegenEnabled) {
+      GenerateProjection.generate(expressions, inputSchema)
+    } else {
+      new InterpretedProjection(expressions, inputSchema)
+    }
   }
 }
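
With a single input RDD, each row is now split into its partition part and its data part by two projections built from the plan's output attributes, instead of zipping two RDDs. A rough, hedged sketch of the idea against the 1.4-era catalyst internals this patch targets (attribute names and values are made up, and `InterpretedProjection` is used for simplicity; the patch picks `GenerateProjection` when codegen is enabled):

```scala
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, InterpretedProjection}
import org.apache.spark.sql.types.{IntegerType, StringType}

// Plan output: one data column `a`, one partition column `p`.
val output = Seq(
  AttributeReference("a", IntegerType)(),
  AttributeReference("p", StringType)())
val partitionColumns = Array("p")

val (partitionOutput, dataOutput) = output.partition(a => partitionColumns.contains(a.name))

val partitionProj = new InterpretedProjection(partitionOutput, output)
val dataProj = new InterpretedProjection(dataOutput, output)

val row = Row(1, "2015-05-13")
val partitionPart = partitionProj(row)   // partition values -> selects the output writer
val dataPart = dataProj(row)             // data values      -> what actually gets written
```
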
@@ -379,6 +399,10 @@ private[sql] class DynamicPartitionWriterContainer(
 }
 
 private[sql] object DynamicPartitionWriterContainer {
+  //////////////////////////////////////////////////////////////////////////////////////////////
+  // The following string escaping code is mainly copied from Hive (o.a.h.h.common.FileUtils).
+  //////////////////////////////////////////////////////////////////////////////////////////////
+
   val charToEscape = {
     val bitSet = new java.util.BitSet(128)
 
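
The `charToEscape` bit set drives Hive-style percent-encoding of partition directory names. A hedged sketch of how such a bit set is typically consumed (helper names and the character subset are illustrative, not this patch's exact code):

```scala
import java.util.BitSet

object PathEscapingSketch {
  // Illustrative subset of characters that are unsafe in partition path names.
  val charToEscape: BitSet = {
    val bitSet = new BitSet(128)
    "\"#%'*/:=?\\{}[]^".foreach(c => bitSet.set(c))
    bitSet
  }

  def needsEscaping(c: Char): Boolean = c < 128 && charToEscape.get(c)

  // e.g. "a/b" -> "a%2Fb", mirroring Hive's FileUtils-style escaping.
  def escapePathName(path: String): String =
    path.map(c => if (needsEscaping(c)) "%%%02X".format(c.toInt) else c.toString).mkString
}
```
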