
Commit fc4c3a8

Sketch how the converters will be used in UnsafeGeneratedAggregate
1 parent: 53ba9b7

1 file changed: +32 -18 lines

sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeGeneratedAggregate.scala

Lines changed: 32 additions & 18 deletions
@@ -35,10 +35,10 @@ import org.apache.spark.unsafe.memory.MemoryAllocator
  */
 @DeveloperApi
 case class UnsafeGeneratedAggregate(
-  partial: Boolean,
-  groupingExpressions: Seq[Expression],
-  aggregateExpressions: Seq[NamedExpression],
-  child: SparkPlan)
+    partial: Boolean,
+    groupingExpressions: Seq[Expression],
+    aggregateExpressions: Seq[NamedExpression],
+    child: SparkPlan)
   extends UnaryNode {

   override def requiredChildDistribution: Seq[Distribution] =
@@ -267,17 +267,25 @@ case class UnsafeGeneratedAggregate(
     // We're going to need to allocate a lot of empty aggregation buffers, so let's do it
     // once and keep a copy of the serialized buffer and copy it into the hash map when we see
     // new keys:
-    val javaAggregationBuffer: MutableRow =
-      newAggregationBuffer(EmptyRow).asInstanceOf[MutableRow]
-    val numberOfFieldsInAggregationBuffer: Int = javaAggregationBuffer.schema.fields.length
-    val aggregationBufferSchema: StructType = javaAggregationBuffer.schema
-    // TODO perform that conversion to an UnsafeRow
-    // Allocate some scratch space for holding the keys that we use to index into the hash map.
-    val unsafeRowBuffer: Array[Long] = new Array[Long](1024)
+    val (emptyAggregationBuffer: Array[Long], numberOfColumnsInAggBuffer: Int) = {
+      val javaBuffer: MutableRow = newAggregationBuffer(EmptyRow).asInstanceOf[MutableRow]
+      val converter = new UnsafeRowConverter(javaBuffer.schema.fields.map(_.dataType))
+      val buffer = new Array[Long](converter.getSizeRequirement(javaBuffer))
+      converter.writeRow(javaBuffer, buffer, PlatformDependent.LONG_ARRAY_OFFSET)
+      (buffer, javaBuffer.schema.fields.length)
+    }

     // TODO: there's got to be an actual way of obtaining this up front.
     var groupProjectionSchema: StructType = null

+    val keyToUnsafeRowConverter: UnsafeRowConverter = {
+      new UnsafeRowConverter(groupProjectionSchema.fields.map(_.dataType))
+    }
+
+    // Allocate some scratch space for holding the keys that we use to index into the hash map.
+    // 16 KB ought to be enough for anyone (TODO)
+    val unsafeRowBuffer: Array[Long] = new Array[Long](1024 * 16 / 8)
+
     while (iter.hasNext) {
       // Zero out the buffer that's used to hold the current row. This is necessary in order
       // to ensure that rows hash properly, since garbage data from the previous row could
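
Note on the hunk above: the empty aggregation buffer is now serialized once, up front, so its bytes can simply be copied into the hash map whenever a new key is seen. Below is a minimal standalone sketch of that pattern, using only the UnsafeRowConverter calls that appear in this diff (getSizeRequirement, and writeRow against a long[] at PlatformDependent.LONG_ARRAY_OFFSET). The helper name and the import paths are assumptions; this API is still being sketched in this branch.

// Assumed package paths; the classes are the ones this diff relies on.
import org.apache.spark.sql.catalyst.expressions.{MutableRow, UnsafeRowConverter}
import org.apache.spark.unsafe.PlatformDependent

// Hypothetical helper: serialize a JVM-object row into a fresh long[] once, so the
// resulting bytes can be copied many times (e.g. as the initial value for new map keys).
def serializeRowToLongArray(row: MutableRow): Array[Long] = {
  val converter = new UnsafeRowConverter(row.schema.fields.map(_.dataType))
  // Size the backing array the same way the diff does, via getSizeRequirement.
  val buffer = new Array[Long](converter.getSizeRequirement(row))
  // Write the row's fields into the array, starting at the array's data offset.
  converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET)
  buffer
}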
@@ -291,7 +299,13 @@ case class UnsafeGeneratedAggregate(
       val currentGroup: Row = groupProjection(currentJavaRow)
       // Convert the current group into an UnsafeRow so that we can use it as a key for our
       // aggregation hash map
-      // --- TODO ---
+      val groupProjectionSize = keyToUnsafeRowConverter.getSizeRequirement(currentGroup)
+      if (groupProjectionSize > unsafeRowBuffer.length) {
+        throw new IllegalStateException("Group projection does not fit into buffer")
+      }
+      keyToUnsafeRowConverter.writeRow(
+        currentGroup, unsafeRowBuffer, PlatformDependent.LONG_ARRAY_OFFSET)
+
       val keyLengthInBytes: Int = 0
       val loc: BytesToBytesMap#Location =
         buffers.lookup(unsafeRowBuffer, PlatformDependent.LONG_ARRAY_OFFSET, keyLengthInBytes)
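
Taken together with the zeroing described in the surrounding context, the per-row key path in this hunk amounts to: clear the scratch buffer, check that the converted group key fits, write it in place, then use the buffer as the lookup key. A sketch under the same assumptions as above; java.util.Arrays.fill stands in here for the zeroing step, which is outside this hunk.

// Sketch of one loop iteration's key handling; keyToUnsafeRowConverter,
// unsafeRowBuffer and currentGroup are the names introduced in this diff.
java.util.Arrays.fill(unsafeRowBuffer, 0L)  // clear garbage left over from the previous row
val keySize = keyToUnsafeRowConverter.getSizeRequirement(currentGroup)
if (keySize > unsafeRowBuffer.length) {
  // Same guard as in the diff: the fixed-size scratch buffer cannot grow (yet).
  throw new IllegalStateException("Group projection does not fit into buffer")
}
keyToUnsafeRowConverter.writeRow(
  currentGroup, unsafeRowBuffer, PlatformDependent.LONG_ARRAY_OFFSET)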
@@ -308,18 +322,18 @@ case class UnsafeGeneratedAggregate(
           unsafeRowBuffer,
           PlatformDependent.LONG_ARRAY_OFFSET,
           keyLengthInBytes,
-          null, // empty agg buffer
+          emptyAggregationBuffer,
           PlatformDependent.LONG_ARRAY_OFFSET,
-          0 // length of the aggregation buffer
+          emptyAggregationBuffer.length
         )
       }
       // Reset our pointer to point to the buffer stored in the hash map
       val address = loc.getValueAddress
       currentBuffer.set(
         address.getBaseObject,
         address.getBaseOffset,
-        numberOfFieldsInAggregationBuffer,
-        javaAggregationBuffer.schema
+        numberOfColumnsInAggBuffer,
+        null
       )
       // Target the projection at the current aggregation buffer and then project the updated
       // values.
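
The hunk above replaces the earlier null/0 placeholders with the pre-serialized empty aggregation buffer. The surrounding insert-or-update flow against the off-heap BytesToBytesMap presumably looks like the sketch below; the call that owns these arguments is not visible in the hunk, so Location.putNewKey and Location.isDefined are assumptions about the map's API, while buffers, currentBuffer, keyLengthInBytes and numberOfColumnsInAggBuffer are names taken from the diff.

// Sketch: look the key up in the off-heap map; on a miss, copy in the shared
// pre-serialized empty aggregation buffer, then point the reusable UnsafeRow at
// whatever buffer the map now stores for this key.
val loc: BytesToBytesMap#Location =
  buffers.lookup(unsafeRowBuffer, PlatformDependent.LONG_ARRAY_OFFSET, keyLengthInBytes)
if (!loc.isDefined) {
  loc.putNewKey(
    unsafeRowBuffer,
    PlatformDependent.LONG_ARRAY_OFFSET,
    keyLengthInBytes,
    emptyAggregationBuffer,
    PlatformDependent.LONG_ARRAY_OFFSET,
    emptyAggregationBuffer.length)
}
val address = loc.getValueAddress
currentBuffer.set(
  address.getBaseObject,
  address.getBaseOffset,
  numberOfColumnsInAggBuffer,
  null)  // the diff passes null here too; the buffer's schema is not yet threaded through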
@@ -346,8 +360,8 @@ case class UnsafeGeneratedAggregate(
           value.set(
             valueAddress.getBaseObject,
             valueAddress.getBaseOffset,
-            aggregationBufferSchema.fields.length,
-            aggregationBufferSchema
+            numberOfColumnsInAggBuffer,
+            null
           )
           // TODO: once the iterator has been fully consumed, we need to free the map so that
           // its off-heap memory is reclaimed. This may mean that we'll have to perform an extra
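
On the TODO in the last two context lines (freeing the map's off-heap memory once the result iterator is exhausted): one way this is handled elsewhere in Spark is to wrap the returned iterator so a cleanup callback runs on completion. A sketch only, assuming BytesToBytesMap.free() and Spark's CompletionIterator, neither of which appears in this diff; resultIter is a hypothetical name for the iterator built over the map.

import org.apache.spark.util.CompletionIterator

// Wrap the result iterator so the map's off-heap memory is released exactly once,
// after the last row has been produced. `resultIter` is an assumed name; `buffers`
// is the BytesToBytesMap from the diff.
CompletionIterator[Row, Iterator[Row]](resultIter, buffers.free())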
