
Commit 79ec072

JoshRosen authored and rxin committed
[SPARK-9023] [SQL] Efficiency improvements for UnsafeRows in Exchange
This pull request aims to improve the performance of SQL's Exchange operator when shuffling UnsafeRows. It also makes several general efficiency improvements to Exchange.

Key changes:

- When performing hash partitioning, the old Exchange projected the partitioning columns into a new row then passed a `(partitioningColumRow: InternalRow, row: InternalRow)` pair into the shuffle. This is very inefficient because it ends up redundantly serializing the partitioning columns only to immediately discard them after the shuffle. After this patch's changes, Exchange now shuffles `(partitionId: Int, row: InternalRow)` pairs. This still isn't optimal, since we're still shuffling extra data that we don't need, but it's significantly more efficient than the old implementation; in the future, we may be able to further optimize this once we implement a new shuffle write interface that accepts non-key-value-pair inputs.
- Exchange's `compute()` method has been significantly simplified; the new code has less duplication and thus is easier to understand.
- When the Exchange's input operator produces UnsafeRows, Exchange will use a specialized `UnsafeRowSerializer` to serialize these rows. This serializer is significantly more efficient since it simply copies the UnsafeRow's underlying bytes. Note that this approach does not work for UnsafeRows that use the ObjectPool mechanism; I did not add support for this because we are planning to remove ObjectPool in the next few weeks.

Author: Josh Rosen <[email protected]>

Closes apache#7456 from JoshRosen/unsafe-exchange and squashes the following commits:

7e75259 [Josh Rosen] Fix cast in SparkSqlSerializer2Suite
0082515 [Josh Rosen] Some additional comments + small cleanup to remove an unused parameter
a27cfc1 [Josh Rosen] Add missing newline
741973c [Josh Rosen] Add simple test of UnsafeRow shuffling in Exchange.
359c6a4 [Josh Rosen] Remove println() and add comments
93904e7 [Josh Rosen] Merge remote-tracking branch 'origin/master' into unsafe-exchange
8dd3ff2 [Josh Rosen] Exchange outputs UnsafeRows when its child outputs them
dd9c66d [Josh Rosen] Fix for copying logic
035af21 [Josh Rosen] Add logic for choosing when to use UnsafeRowSerializer
7876f31 [Josh Rosen] Merge remote-tracking branch 'origin/master' into unsafe-shuffle
cbea80b [Josh Rosen] Add UnsafeRowSerializer
0f2ac86 [Josh Rosen] Import ordering
3ca8515 [Josh Rosen] Big code simplification in Exchange
3526868 [Josh Rosen] Iniitial cut at removing shuffle on KV pairs
1 parent 972d890 commit 79ec072

8 files changed, +398 -116 lines changed

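The commit message above says the new UnsafeRowSerializer "simply copies the UnsafeRow's underlying bytes"; that file is among the eight changed but is not shown on this page. Below is a minimal, hypothetical sketch of that length-prefixed byte-copy framing, written against plain java.io streams rather than the real UnsafeRow API. The object name, writeRow, readRow, and rowBytes are all illustrative assumptions, not Spark code.

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}

// Sketch only: a row that is already a contiguous byte region can be shuffled by writing
// its length followed by its raw bytes, with no per-field encoding or decoding.
object LengthPrefixedRowFraming {
  def writeRow(out: DataOutputStream, rowBytes: Array[Byte]): Unit = {
    out.writeInt(rowBytes.length) // size prefix
    out.write(rowBytes)           // raw copy of the row's bytes
  }

  def readRow(in: DataInputStream): Array[Byte] = {
    val length = in.readInt()
    val buffer = new Array[Byte](length)
    in.readFully(buffer)
    buffer
  }

  def main(args: Array[String]): Unit = {
    val sink = new ByteArrayOutputStream()
    val out = new DataOutputStream(sink)
    writeRow(out, Array[Byte](1, 2, 3, 4))
    out.flush()
    val in = new DataInputStream(new ByteArrayInputStream(sink.toByteArray))
    println(readRow(in).mkString(", ")) // prints: 1, 2, 3, 4
  }
}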

sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala

Lines changed: 49 additions & 83 deletions
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.execution
 
 import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.rdd.{RDD, ShuffledRDD}
+import org.apache.spark.rdd.RDD
 import org.apache.spark.serializer.Serializer
 import org.apache.spark.shuffle.hash.HashShuffleManager
 import org.apache.spark.shuffle.sort.SortShuffleManager
@@ -29,7 +29,6 @@ import org.apache.spark.sql.catalyst.errors.attachTree
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.types.DataType
 import org.apache.spark.util.MutablePair
 import org.apache.spark.{HashPartitioner, Partitioner, RangePartitioner, SparkEnv}
 
@@ -44,6 +43,12 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una
 
   override def output: Seq[Attribute] = child.output
 
+  override def outputsUnsafeRows: Boolean = child.outputsUnsafeRows
+
+  override def canProcessSafeRows: Boolean = true
+
+  override def canProcessUnsafeRows: Boolean = true
+
   /**
    * Determines whether records must be defensively copied before being sent to the shuffle.
    * Several of Spark's shuffle components will buffer deserialized Java objects in memory. The
@@ -112,109 +117,70 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una
 
   @transient private lazy val sparkConf = child.sqlContext.sparkContext.getConf
 
-  private def getSerializer(
-      keySchema: Array[DataType],
-      valueSchema: Array[DataType],
-      numPartitions: Int): Serializer = {
+  private val serializer: Serializer = {
+    val rowDataTypes = child.output.map(_.dataType).toArray
     // It is true when there is no field that needs to be write out.
    // For now, we will not use SparkSqlSerializer2 when noField is true.
-    val noField =
-      (keySchema == null || keySchema.length == 0) &&
-      (valueSchema == null || valueSchema.length == 0)
+    val noField = rowDataTypes == null || rowDataTypes.length == 0
 
     val useSqlSerializer2 =
       child.sqlContext.conf.useSqlSerializer2 && // SparkSqlSerializer2 is enabled.
-      SparkSqlSerializer2.support(keySchema) && // The schema of key is supported.
-      SparkSqlSerializer2.support(valueSchema) && // The schema of value is supported.
+      SparkSqlSerializer2.support(rowDataTypes) && // The schema of row is supported.
       !noField
 
-    val serializer = if (useSqlSerializer2) {
+    if (child.outputsUnsafeRows) {
+      logInfo("Using UnsafeRowSerializer.")
+      new UnsafeRowSerializer(child.output.size)
+    } else if (useSqlSerializer2) {
      logInfo("Using SparkSqlSerializer2.")
-      new SparkSqlSerializer2(keySchema, valueSchema)
+      new SparkSqlSerializer2(rowDataTypes)
     } else {
       logInfo("Using SparkSqlSerializer.")
       new SparkSqlSerializer(sparkConf)
     }
-
-    serializer
   }
 
   protected override def doExecute(): RDD[InternalRow] = attachTree(this , "execute") {
-    newPartitioning match {
-      case HashPartitioning(expressions, numPartitions) =>
-        val keySchema = expressions.map(_.dataType).toArray
-        val valueSchema = child.output.map(_.dataType).toArray
-        val serializer = getSerializer(keySchema, valueSchema, numPartitions)
-        val part = new HashPartitioner(numPartitions)
-
-        val rdd = if (needToCopyObjectsBeforeShuffle(part, serializer)) {
-          child.execute().mapPartitions { iter =>
-            val hashExpressions = newMutableProjection(expressions, child.output)()
-            iter.map(r => (hashExpressions(r).copy(), r.copy()))
-          }
-        } else {
-          child.execute().mapPartitions { iter =>
-            val hashExpressions = newMutableProjection(expressions, child.output)()
-            val mutablePair = new MutablePair[InternalRow, InternalRow]()
-            iter.map(r => mutablePair.update(hashExpressions(r), r))
-          }
-        }
-        val shuffled = new ShuffledRDD[InternalRow, InternalRow, InternalRow](rdd, part)
-        shuffled.setSerializer(serializer)
-        shuffled.map(_._2)
-
+    val rdd = child.execute()
+    val part: Partitioner = newPartitioning match {
+      case HashPartitioning(expressions, numPartitions) => new HashPartitioner(numPartitions)
       case RangePartitioning(sortingExpressions, numPartitions) =>
-        val keySchema = child.output.map(_.dataType).toArray
-        val serializer = getSerializer(keySchema, null, numPartitions)
-
-        val childRdd = child.execute()
-        val part: Partitioner = {
-          // Internally, RangePartitioner runs a job on the RDD that samples keys to compute
-          // partition bounds. To get accurate samples, we need to copy the mutable keys.
-          val rddForSampling = childRdd.mapPartitions { iter =>
-            val mutablePair = new MutablePair[InternalRow, Null]()
-            iter.map(row => mutablePair.update(row.copy(), null))
-          }
-          // TODO: RangePartitioner should take an Ordering.
-          implicit val ordering = new RowOrdering(sortingExpressions, child.output)
-          new RangePartitioner(numPartitions, rddForSampling, ascending = true)
-        }
-
-        val rdd = if (needToCopyObjectsBeforeShuffle(part, serializer)) {
-          childRdd.mapPartitions { iter => iter.map(row => (row.copy(), null))}
-        } else {
-          childRdd.mapPartitions { iter =>
-            val mutablePair = new MutablePair[InternalRow, Null]()
-            iter.map(row => mutablePair.update(row, null))
-          }
+        // Internally, RangePartitioner runs a job on the RDD that samples keys to compute
+        // partition bounds. To get accurate samples, we need to copy the mutable keys.
+        val rddForSampling = rdd.mapPartitions { iter =>
+          val mutablePair = new MutablePair[InternalRow, Null]()
+          iter.map(row => mutablePair.update(row.copy(), null))
        }
-
-        val shuffled = new ShuffledRDD[InternalRow, Null, Null](rdd, part)
-        shuffled.setSerializer(serializer)
-        shuffled.map(_._1)
-
+        implicit val ordering = new RowOrdering(sortingExpressions, child.output)
+        new RangePartitioner(numPartitions, rddForSampling, ascending = true)
      case SinglePartition =>
-        val valueSchema = child.output.map(_.dataType).toArray
-        val serializer = getSerializer(null, valueSchema, numPartitions = 1)
-        val partitioner = new HashPartitioner(1)
-
-        val rdd = if (needToCopyObjectsBeforeShuffle(partitioner, serializer)) {
-          child.execute().mapPartitions {
-            iter => iter.map(r => (null, r.copy()))
-          }
-        } else {
-          child.execute().mapPartitions { iter =>
-            val mutablePair = new MutablePair[Null, InternalRow]()
-            iter.map(r => mutablePair.update(null, r))
-          }
+        new Partitioner {
+          override def numPartitions: Int = 1
+          override def getPartition(key: Any): Int = 0
        }
-        val shuffled = new ShuffledRDD[Null, InternalRow, InternalRow](rdd, partitioner)
-        shuffled.setSerializer(serializer)
-        shuffled.map(_._2)
-
       case _ => sys.error(s"Exchange not implemented for $newPartitioning")
       // TODO: Handle BroadcastPartitioning.
     }
+    def getPartitionKeyExtractor(): InternalRow => InternalRow = newPartitioning match {
+      case HashPartitioning(expressions, _) => newMutableProjection(expressions, child.output)()
+      case RangePartitioning(_, _) | SinglePartition => identity
+      case _ => sys.error(s"Exchange not implemented for $newPartitioning")
+    }
+    val rddWithPartitionIds: RDD[Product2[Int, InternalRow]] = {
+      if (needToCopyObjectsBeforeShuffle(part, serializer)) {
+        rdd.mapPartitions { iter =>
+          val getPartitionKey = getPartitionKeyExtractor()
+          iter.map { row => (part.getPartition(getPartitionKey(row)), row.copy()) }
+        }
+      } else {
+        rdd.mapPartitions { iter =>
+          val getPartitionKey = getPartitionKeyExtractor()
+          val mutablePair = new MutablePair[Int, InternalRow]()
+          iter.map { row => mutablePair.update(part.getPartition(getPartitionKey(row)), row) }
        }
+      }
+    }
+    new ShuffledRowRDD(rddWithPartitionIds, serializer, part.numPartitions)
   }
 }
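The needToCopyObjectsBeforeShuffle check and the row.copy() calls in the diff above exist because some shuffle paths buffer object references while the producer keeps reusing a single mutable row. The following Spark-free sketch illustrates that failure mode; all names in it are local to the example, not part of the patch.

// Sketch only: buffering references to a reused mutable object without copying makes every
// buffered entry observe the last value written, which is why Exchange defensively copies
// rows whenever the downstream shuffle components buffer deserialized objects.
object WhyCopyBeforeBuffering {
  final class MutableRecord(var value: Int)

  def main(args: Array[String]): Unit = {
    val reused = new MutableRecord(0)
    // A lazy producer that reuses one object per record, like a mutable row iterator.
    val lazyProducer = Iterator(1, 2, 3).map { v => reused.value = v; reused }
    val bufferedWithoutCopy = lazyProducer.toArray
    println(bufferedWithoutCopy.map(_.value).mkString(", ")) // 3, 3, 3 -- data is corrupted

    val reusedAgain = new MutableRecord(0)
    // Copying each record before it is buffered preserves the original values.
    val copyingProducer =
      Iterator(1, 2, 3).map { v => reusedAgain.value = v; new MutableRecord(reusedAgain.value) }
    println(copyingProducer.toArray.map(_.value).mkString(", ")) // 1, 2, 3
  }
}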

sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala

Lines changed: 80 additions & 0 deletions
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import org.apache.spark._
+import org.apache.spark.rdd.RDD
+import org.apache.spark.serializer.Serializer
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.types.DataType
+
+private class ShuffledRowRDDPartition(val idx: Int) extends Partition {
+  override val index: Int = idx
+  override def hashCode(): Int = idx
+}
+
+/**
+ * A dummy partitioner for use with records whose partition ids have been pre-computed (i.e. for
+ * use on RDDs of (Int, Row) pairs where the Int is a partition id in the expected range).
+ */
+private class PartitionIdPassthrough(override val numPartitions: Int) extends Partitioner {
+  override def getPartition(key: Any): Int = key.asInstanceOf[Int]
+}
+
+/**
+ * This is a specialized version of [[org.apache.spark.rdd.ShuffledRDD]] that is optimized for
+ * shuffling rows instead of Java key-value pairs. Note that something like this should eventually
+ * be implemented in Spark core, but that is blocked by some more general refactorings to shuffle
+ * interfaces / internals.
+ *
+ * @param prev the RDD being shuffled. Elements of this RDD are (partitionId, Row) pairs.
+ *             Partition ids should be in the range [0, numPartitions - 1].
+ * @param serializer the serializer used during the shuffle.
+ * @param numPartitions the number of post-shuffle partitions.
+ */
+class ShuffledRowRDD(
+    @transient var prev: RDD[Product2[Int, InternalRow]],
+    serializer: Serializer,
+    numPartitions: Int)
+  extends RDD[InternalRow](prev.context, Nil) {
+
+  private val part: Partitioner = new PartitionIdPassthrough(numPartitions)
+
+  override def getDependencies: Seq[Dependency[_]] = {
+    List(new ShuffleDependency[Int, InternalRow, InternalRow](prev, part, Some(serializer)))
+  }
+
+  override val partitioner = Some(part)
+
+  override def getPartitions: Array[Partition] = {
+    Array.tabulate[Partition](part.numPartitions)(i => new ShuffledRowRDDPartition(i))
+  }
+
+  override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = {
+    val dep = dependencies.head.asInstanceOf[ShuffleDependency[Int, InternalRow, InternalRow]]
+    SparkEnv.get.shuffleManager.getReader(dep.shuffleHandle, split.index, split.index + 1, context)
+      .read()
+      .asInstanceOf[Iterator[Product2[Int, InternalRow]]]
+      .map(_._2)
+  }
+
+  override def clearDependencies() {
+    super.clearDependencies()
+    prev = null
+  }
+}
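As a hedged illustration of the same precomputed-partition-id pattern outside SQL, the sketch below pairs each record with a partition id up front and routes it through a pass-through partitioner using the plain core RDD API. IdPassthrough and the hashing rule here are inventions of this example, mirroring the patch's PartitionIdPassthrough rather than reusing it.

import org.apache.spark.{Partitioner, SparkConf, SparkContext}

object PartitionIdPassthroughExample {
  // Same idea as the patch's PartitionIdPassthrough: the key already *is* the partition id.
  class IdPassthrough(override val numPartitions: Int) extends Partitioner {
    override def getPartition(key: Any): Int = key.asInstanceOf[Int]
  }

  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("passthrough"))
    val records = sc.parallelize(Seq("a", "bb", "ccc", "dddd"))
    // Compute the target partition once, before the shuffle, mirroring
    // part.getPartition(getPartitionKey(row)) in the new Exchange.doExecute.
    val withIds = records.map(r => (r.length % 2, r))
    // The shuffle then only routes on the precomputed Int; no separate key row is serialized.
    val shuffled = withIds.partitionBy(new IdPassthrough(2)).values
    shuffled.collect().foreach(println)
    sc.stop()
  }
}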

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala

Lines changed: 16 additions & 27 deletions
@@ -45,14 +45,12 @@ import org.apache.spark.unsafe.types.UTF8String
  * the comment of the `serializer` method in [[Exchange]] for more information on it.
  */
 private[sql] class Serializer2SerializationStream(
-    keySchema: Array[DataType],
-    valueSchema: Array[DataType],
+    rowSchema: Array[DataType],
     out: OutputStream)
   extends SerializationStream with Logging {
 
   private val rowOut = new DataOutputStream(new BufferedOutputStream(out))
-  private val writeKeyFunc = SparkSqlSerializer2.createSerializationFunction(keySchema, rowOut)
-  private val writeValueFunc = SparkSqlSerializer2.createSerializationFunction(valueSchema, rowOut)
+  private val writeRowFunc = SparkSqlSerializer2.createSerializationFunction(rowSchema, rowOut)
 
   override def writeObject[T: ClassTag](t: T): SerializationStream = {
     val kv = t.asInstanceOf[Product2[Row, Row]]
@@ -63,12 +61,12 @@ private[sql] class Serializer2SerializationStream(
   }
 
   override def writeKey[T: ClassTag](t: T): SerializationStream = {
-    writeKeyFunc(t.asInstanceOf[Row])
+    // No-op.
     this
   }
 
   override def writeValue[T: ClassTag](t: T): SerializationStream = {
-    writeValueFunc(t.asInstanceOf[Row])
+    writeRowFunc(t.asInstanceOf[Row])
     this
   }
 
@@ -85,8 +83,7 @@
  * The corresponding deserialization stream for [[Serializer2SerializationStream]].
  */
 private[sql] class Serializer2DeserializationStream(
-    keySchema: Array[DataType],
-    valueSchema: Array[DataType],
+    rowSchema: Array[DataType],
     in: InputStream)
   extends DeserializationStream with Logging {
 
@@ -103,22 +100,20 @@ private[sql] class Serializer2DeserializationStream(
   }
 
   // Functions used to return rows for key and value.
-  private val getKey = rowGenerator(keySchema)
-  private val getValue = rowGenerator(valueSchema)
+  private val getRow = rowGenerator(rowSchema)
   // Functions used to read a serialized row from the InputStream and deserialize it.
-  private val readKeyFunc = SparkSqlSerializer2.createDeserializationFunction(keySchema, rowIn)
-  private val readValueFunc = SparkSqlSerializer2.createDeserializationFunction(valueSchema, rowIn)
+  private val readRowFunc = SparkSqlSerializer2.createDeserializationFunction(rowSchema, rowIn)
 
   override def readObject[T: ClassTag](): T = {
-    (readKeyFunc(getKey()), readValueFunc(getValue())).asInstanceOf[T]
+    readValue()
   }
 
   override def readKey[T: ClassTag](): T = {
-    readKeyFunc(getKey()).asInstanceOf[T]
+    null.asInstanceOf[T] // intentionally left blank.
   }
 
   override def readValue[T: ClassTag](): T = {
-    readValueFunc(getValue()).asInstanceOf[T]
+    readRowFunc(getRow()).asInstanceOf[T]
   }
 
   override def close(): Unit = {
@@ -127,8 +122,7 @@
 }
 
 private[sql] class SparkSqlSerializer2Instance(
-    keySchema: Array[DataType],
-    valueSchema: Array[DataType])
+    rowSchema: Array[DataType])
   extends SerializerInstance {
 
   def serialize[T: ClassTag](t: T): ByteBuffer =
@@ -141,30 +135,25 @@ private[sql] class SparkSqlSerializer2Instance(
     throw new UnsupportedOperationException("Not supported.")
 
   def serializeStream(s: OutputStream): SerializationStream = {
-    new Serializer2SerializationStream(keySchema, valueSchema, s)
+    new Serializer2SerializationStream(rowSchema, s)
   }
 
   def deserializeStream(s: InputStream): DeserializationStream = {
-    new Serializer2DeserializationStream(keySchema, valueSchema, s)
+    new Serializer2DeserializationStream(rowSchema, s)
   }
 }
 
 /**
  * SparkSqlSerializer2 is a special serializer that creates serialization function and
  * deserialization function based on the schema of data. It assumes that values passed in
- * are key/value pairs and values returned from it are also key/value pairs.
- * The schema of keys is represented by `keySchema` and that of values is represented by
- * `valueSchema`.
+ * are Rows.
  */
-private[sql] class SparkSqlSerializer2(
-    keySchema: Array[DataType],
-    valueSchema: Array[DataType])
+private[sql] class SparkSqlSerializer2(rowSchema: Array[DataType])
   extends Serializer
   with Logging
   with Serializable{
 
-  def newInstance(): SerializerInstance =
-    new SparkSqlSerializer2Instance(keySchema, valueSchema)
+  def newInstance(): SerializerInstance = new SparkSqlSerializer2Instance(rowSchema)
 
   override def supportsRelocationOfSerializedObjects: Boolean = {
     // SparkSqlSerializer2 is stateless and writes no stream headers
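A small, hypothetical call-site sketch of the reworked SparkSqlSerializer2 constructor, mirroring how Exchange now builds it from a single row schema. The object name and the chosen data types are assumptions of this example, and it would have to live in the org.apache.spark.sql.execution package because the classes are private[sql].

package org.apache.spark.sql.execution

import org.apache.spark.serializer.Serializer
import org.apache.spark.sql.types.{DataType, IntegerType, LongType}

// Sketch only: after this patch the serializer is parameterized by one row schema instead of
// separate key and value schemas, matching the single (partitionId, row) shuffle format.
object Serializer2CallSiteSketch {
  def pickSerializer(rowDataTypes: Array[DataType]): Serializer = {
    require(SparkSqlSerializer2.support(rowDataTypes), "row schema must be supported")
    new SparkSqlSerializer2(rowDataTypes) // was: new SparkSqlSerializer2(keySchema, valueSchema)
  }

  def main(args: Array[String]): Unit = {
    val serializer = pickSerializer(Array(IntegerType, LongType))
    println(serializer.newInstance()) // a SparkSqlSerializer2Instance built from the row schema
  }
}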
