Commit bc3f9b4

Uses projection to separate partition columns and data columns while inserting rows
1 parent 795920a commit bc3f9b4
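
In short, instead of materializing two RDDs (one holding the partition values, one holding the data columns) and zipping them, the change runs the query plan once, partitions its output attributes by name, and applies two projections to every row: one yields the partition values that select the output directory, the other yields the data that is actually written. The snippet below is a minimal, dependency-free sketch of that splitting idea; the column names, values, and helper names are invented for illustration and are not the committed Spark code.

// Minimal sketch (plain Scala, no Spark dependency): split a schema's columns into
// partition columns and data columns, then build two "projections" that pick the
// corresponding values out of each row. Column names and values are made up.
object ProjectionSplitSketch {
  type Row = Seq[Any]

  def main(args: Array[String]): Unit = {
    val schema = Seq("year", "month", "name", "score")   // hypothetical output columns
    val partitionColumns = Set("year", "month")          // hypothetical partition columns

    // Split the output columns, remembering each column's index in the row.
    val (partitionCols, dataCols) =
      schema.zipWithIndex.partition { case (name, _) => partitionColumns.contains(name) }

    // A "projection" here is simply a function that extracts the selected indices.
    def projection(cols: Seq[(String, Int)]): Row => Row =
      row => cols.map { case (_, i) => row(i) }

    val partitionProj = projection(partitionCols)
    val dataProj = projection(dataCols)

    val row: Row = Seq(2015, 5, "spark", 42)
    println(partitionProj(row))   // partition values (2015, 5) -> choose the output directory
    println(dataProj(row))        // data values ("spark", 42)  -> what gets written to the file
  }
}

Projecting per row from a single plan execution also avoids relying on two separately derived RDDs staying aligned row-for-row, which the previous partitionRDD.zip(dataRDD) approach required.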

2 files changed: +61 lines, -41 lines

sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala

Lines changed: 60 additions & 36 deletions

@@ -32,6 +32,7 @@ import org.apache.spark.mapred.SparkHadoopMapRedUtil
 import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil
 import org.apache.spark.sql.catalyst.CatalystTypeConverters
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.codegen.GenerateProjection
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.execution.RunnableCommand
 import org.apache.spark.sql.{DataFrame, SQLContext, SaveMode}
@@ -102,14 +103,18 @@ private[sql] case class InsertIntoFSBasedRelation(
       } else {
         val writerContainer = new DynamicPartitionWriterContainer(
           relation, job, partitionColumns, "__HIVE_DEFAULT_PARTITION__")
-        insertWithDynamicPartitions(writerContainer, df, partitionColumns)
+        insertWithDynamicPartitions(sqlContext, writerContainer, df, partitionColumns)
       }
     }

     Seq.empty[Row]
   }

   private def insert(writerContainer: BaseWriterContainer, df: DataFrame): Unit = {
+    // Uses local vals for serialization
+    val needsConversion = relation.needConversion
+    val dataSchema = relation.dataSchema
+
     try {
       writerContainer.driverSideSetup()
       df.sqlContext.sparkContext.runJob(df.queryExecution.executedPlan.execute(), writeRows _)
@@ -124,8 +129,8 @@ private[sql] case class InsertIntoFSBasedRelation(
       writerContainer.executorSideSetup(taskContext)

       try {
-        if (relation.needConversion) {
-          val converter = CatalystTypeConverters.createToScalaConverter(relation.dataSchema)
+        if (needsConversion) {
+          val converter = CatalystTypeConverters.createToScalaConverter(dataSchema)
           while (iterator.hasNext) {
             val row = converter(iterator.next()).asInstanceOf[Row]
             writerContainer.outputWriterForRow(row).write(row)
@@ -145,9 +150,13 @@ private[sql] case class InsertIntoFSBasedRelation(
   }

   private def insertWithDynamicPartitions(
+      sqlContext: SQLContext,
       writerContainer: BaseWriterContainer,
       df: DataFrame,
       partitionColumns: Array[String]): Unit = {
+    // Uses a local val for serialization
+    val needsConversion = relation.needConversion
+    val dataSchema = relation.dataSchema

     require(
       df.schema == relation.schema,
@@ -156,34 +165,21 @@ private[sql] case class InsertIntoFSBasedRelation(
          |Relation schema: ${relation.schema}
        """.stripMargin)

-    val sqlContext = df.sqlContext
-
-    val (partitionRDD, dataRDD) = {
-      val fieldNames = relation.schema.fieldNames
-      val dataCols = fieldNames.filterNot(partitionColumns.contains)
-      val df = sqlContext.createDataFrame(
-        DataFrame(sqlContext, query).queryExecution.toRdd,
-        relation.schema,
-        needsConversion = false)
-
-      val partitionColumnsInSpec = relation.partitionSpec.partitionColumns.map(_.name)
-      require(
-        partitionColumnsInSpec.sameElements(partitionColumns),
-        s"""Partition columns mismatch.
-           |Expected: ${partitionColumnsInSpec.mkString(", ")}
-           |Actual: ${partitionColumns.mkString(", ")}
-         """.stripMargin)
-
-      val partitionDF = df.select(partitionColumns.head, partitionColumns.tail: _*)
-      val dataDF = df.select(dataCols.head, dataCols.tail: _*)
+    val partitionColumnsInSpec = relation.partitionColumns.fieldNames
+    require(
+      partitionColumnsInSpec.sameElements(partitionColumns),
+      s"""Partition columns mismatch.
+         |Expected: ${partitionColumnsInSpec.mkString(", ")}
+         |Actual: ${partitionColumns.mkString(", ")}
+       """.stripMargin)

-      partitionDF.queryExecution.executedPlan.execute() ->
-        dataDF.queryExecution.executedPlan.execute()
-    }
+    val output = df.queryExecution.executedPlan.output
+    val (partitionOutput, dataOutput) = output.partition(a => partitionColumns.contains(a.name))
+    val codegenEnabled = df.sqlContext.conf.codegenEnabled

     try {
       writerContainer.driverSideSetup()
-      sqlContext.sparkContext.runJob(partitionRDD.zip(dataRDD), writeRows _)
+      df.sqlContext.sparkContext.runJob(df.queryExecution.executedPlan.execute(), writeRows _)
       writerContainer.commitJob()
       relation.refresh()
     } catch { case cause: Throwable =>
@@ -192,20 +188,44 @@ private[sql] case class InsertIntoFSBasedRelation(
       throw new SparkException("Job aborted.", cause)
     }

-    def writeRows(taskContext: TaskContext, iterator: Iterator[(Row, Row)]): Unit = {
+    def writeRows(taskContext: TaskContext, iterator: Iterator[Row]): Unit = {
       writerContainer.executorSideSetup(taskContext)

-      try {
+      val partitionProj = newProjection(codegenEnabled, partitionOutput, output)
+      val dataProj = newProjection(codegenEnabled, dataOutput, output)
+
+      if (needsConversion) {
+        val converter = CatalystTypeConverters.createToScalaConverter(dataSchema)
         while (iterator.hasNext) {
-          val (partitionValues, data) = iterator.next()
-          writerContainer.outputWriterForRow(partitionValues).write(data)
+          val row = converter(iterator.next()).asInstanceOf[Row]
+          val partitionPart = partitionProj(row)
+          val dataPart = dataProj(row)
+          writerContainer.outputWriterForRow(partitionPart).write(dataPart)
+        }
+      } else {
+        while (iterator.hasNext) {
+          val row = iterator.next()
+          val partitionPart = partitionProj(row)
+          val dataPart = dataProj(row)
+          writerContainer.outputWriterForRow(partitionPart).write(dataPart)
         }
-
-        writerContainer.commitTask()
-      } catch { case cause: Throwable =>
-        writerContainer.abortTask()
-        throw new SparkException("Task failed while writing rows.", cause)
       }
+
+      writerContainer.commitTask()
+    }
+  }
+
+  // This is copied from SparkPlan, probably should move this to a more general place.
+  private def newProjection(
+      codegenEnabled: Boolean,
+      expressions: Seq[Expression],
+      inputSchema: Seq[Attribute]): Projection = {
+    log.debug(
+      s"Creating Projection: $expressions, inputSchema: $inputSchema, codegen:$codegenEnabled")
+    if (codegenEnabled) {
+      GenerateProjection.generate(expressions, inputSchema)
+    } else {
+      new InterpretedProjection(expressions, inputSchema)
     }
   }
 }
@@ -379,6 +399,10 @@ private[sql] class DynamicPartitionWriterContainer(
 }

 private[sql] object DynamicPartitionWriterContainer {
+  //////////////////////////////////////////////////////////////////////////////////////////////////
+  // The following string escaping code is mainly copied from Hive (o.a.h.h.common.FileUtils).
+  //////////////////////////////////////////////////////////////////////////////////////////////////
+
   val charToEscape = {
     val bitSet = new java.util.BitSet(128)

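As background for what the writer container does with the projected partition values: a dynamic-partition writer resolves each combination of partition values to a Hive-style directory such as year=2015/month=5, escaping characters that are unsafe in paths, which is why the hunk above adds escaping code copied from Hive's FileUtils. The sketch below is a simplified, hypothetical illustration of that path construction; the escape set and helper names are invented here, and this is not the actual DynamicPartitionWriterContainer logic.

// Hypothetical illustration of how projected partition values could map to a
// Hive-style partition path. The escape set and helper names are invented; the
// real escaping table added in this commit is copied from Hive's FileUtils.
object PartitionPathSketch {
  // A small, illustrative set of characters that are unsafe in directory names.
  private val unsafeChars = Set('/', '\\', ':', '=', '%')

  // Percent-escape unsafe and control characters, e.g. '/' becomes "%2F".
  private def escape(s: String): String =
    s.flatMap { c =>
      if (unsafeChars.contains(c) || c < ' ') f"%%${c.toInt}%02X" else c.toString
    }

  // Build "col1=value1/col2=value2" from partition column names and projected values.
  def partitionPath(columnNames: Seq[String], values: Seq[Any]): String =
    columnNames.zip(values)
      .map { case (name, value) => s"${escape(name)}=${escape(value.toString)}" }
      .mkString("/")

  def main(args: Array[String]): Unit = {
    // Prints: year=2015/month=5%2F12
    println(partitionPath(Seq("year", "month"), Seq(2015, "5/12")))
  }
}
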
sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala

Lines changed: 1 addition & 5 deletions

@@ -61,23 +61,19 @@ class AppendingTextOutputFormat(outputFile: Path) extends TextOutputFormat[NullW
 }

 class SimpleTextOutputWriter extends OutputWriter {
-  private var converter: Any => Any = _
   private var recordWriter: RecordWriter[NullWritable, Text] = _
   private var taskAttemptContext: TaskAttemptContext = _

   override def init(
       path: String,
       dataSchema: StructType,
       context: TaskAttemptContext): Unit = {
-    converter = CatalystTypeConverters.createToScalaConverter(dataSchema)
     recordWriter = new AppendingTextOutputFormat(new Path(path)).getRecordWriter(context)
     taskAttemptContext = context
   }

   override def write(row: Row): Unit = {
-    // Serializes values in `row` into a comma separated string
-    val convertedRow = converter(row).asInstanceOf[Row]
-    val serialized = convertedRow.toSeq.map(_.toString).mkString(",")
+    val serialized = row.toSeq.map(_.toString).mkString(",")
     recordWriter.write(null, new Text(serialized))
   }
