@@ -35,10 +35,10 @@ import org.apache.spark.unsafe.memory.MemoryAllocator
  */
 @DeveloperApi
 case class UnsafeGeneratedAggregate(
-  partial: Boolean,
-  groupingExpressions: Seq[Expression],
-  aggregateExpressions: Seq[NamedExpression],
-  child: SparkPlan)
+    partial: Boolean,
+    groupingExpressions: Seq[Expression],
+    aggregateExpressions: Seq[NamedExpression],
+    child: SparkPlan)
   extends UnaryNode {

   override def requiredChildDistribution: Seq[Distribution] =
@@ -267,17 +267,25 @@ case class UnsafeGeneratedAggregate(
       // We're going to need to allocate a lot of empty aggregation buffers, so let's do it
       // once and keep a copy of the serialized buffer and copy it into the hash map when we see
       // new keys:
-      val javaAggregationBuffer: MutableRow =
-        newAggregationBuffer(EmptyRow).asInstanceOf[MutableRow]
-      val numberOfFieldsInAggregationBuffer: Int = javaAggregationBuffer.schema.fields.length
-      val aggregationBufferSchema: StructType = javaAggregationBuffer.schema
-      // TODO: perform that conversion to an UnsafeRow
-      // Allocate some scratch space for holding the keys that we use to index into the hash map.
-      val unsafeRowBuffer: Array[Long] = new Array[Long](1024)
+      val (emptyAggregationBuffer: Array[Long], numberOfColumnsInAggBuffer: Int) = {
+        val javaBuffer: MutableRow = newAggregationBuffer(EmptyRow).asInstanceOf[MutableRow]
+        val converter = new UnsafeRowConverter(javaBuffer.schema.fields.map(_.dataType))
+        val buffer = new Array[Long](converter.getSizeRequirement(javaBuffer))
+        converter.writeRow(javaBuffer, buffer, PlatformDependent.LONG_ARRAY_OFFSET)
+        (buffer, javaBuffer.schema.fields.length)
+      }

       // TODO: there's got to be an actual way of obtaining this up front.
       var groupProjectionSchema: StructType = null

+      val keyToUnsafeRowConverter: UnsafeRowConverter = {
+        new UnsafeRowConverter(groupProjectionSchema.fields.map(_.dataType))
+      }
+
+      // Allocate some scratch space for holding the keys that we use to index into the hash map.
+      // 16 KB ought to be enough for anyone (TODO)
+      val unsafeRowBuffer: Array[Long] = new Array[Long](1024 * 16 / 8)
+
       while (iter.hasNext) {
         // Zero out the buffer that's used to hold the current row. This is necessary in order
         // to ensure that rows hash properly, since garbage data from the previous row could
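The new block above builds the empty aggregation buffer once, serializes it into an Array[Long], and keeps that copy around as a template to paste into the hash map whenever a new key appears. Below is a minimal, self-contained sketch of that serialize-once, copy-per-new-key idea, using an ordinary Scala HashMap and plain Long columns as stand-ins for BytesToBytesMap and UnsafeRowConverter; every name in the sketch is illustrative, not a Spark API.

import scala.collection.mutable

object EmptyBufferTemplateSketch {
  // The "aggregation buffer" here is a fixed-width row of three Long columns,
  // say (count, sum, max), all starting at zero; a stand-in for the real buffer.
  val numberOfColumnsInAggBuffer: Int = 3

  // Built exactly once and reused as a template; this plays the role of
  // emptyAggregationBuffer in the hunk above.
  val emptyAggregationBuffer: Array[Long] = new Array[Long](numberOfColumnsInAggBuffer)

  // Heap-based stand-in for BytesToBytesMap: group key -> mutable buffer.
  val buffers = mutable.HashMap.empty[Seq[Long], Array[Long]]

  // On a miss, copy the pre-built template instead of rebuilding an empty
  // buffer for every new key.
  def bufferFor(groupKey: Seq[Long]): Array[Long] =
    buffers.getOrElseUpdate(groupKey, emptyAggregationBuffer.clone())

  def main(args: Array[String]): Unit = {
    val buf = bufferFor(Seq(42L))
    buf(0) += 1   // count
    buf(1) += 10  // sum
    println(bufferFor(Seq(42L)).mkString(", "))  // prints: 1, 10, 0
  }
}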
@@ -291,7 +299,13 @@ case class UnsafeGeneratedAggregate(
         val currentGroup: Row = groupProjection(currentJavaRow)
         // Convert the current group into an UnsafeRow so that we can use it as a key for our
         // aggregation hash map
-        // --- TODO ---
+        val groupProjectionSize = keyToUnsafeRowConverter.getSizeRequirement(currentGroup)
+        if (groupProjectionSize > unsafeRowBuffer.length) {
+          throw new IllegalStateException("Group projection does not fit into buffer")
+        }
+        keyToUnsafeRowConverter.writeRow(
+          currentGroup, unsafeRowBuffer, PlatformDependent.LONG_ARRAY_OFFSET)
+
         val keyLengthInBytes: Int = 0
         val loc: BytesToBytesMap#Location =
           buffers.lookup(unsafeRowBuffer, PlatformDependent.LONG_ARRAY_OFFSET, keyLengthInBytes)
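The hunk above sizes the serialized group key, checks it against a reusable scratch buffer, and writes the key into that buffer before probing the map. Here is a stand-alone sketch of the same check-then-write pattern, with a plain Seq[Long] standing in for the projected group row; the helper, its names, and the word-based sizing are illustrative simplifications of what the converter does.

object ScratchKeyBufferSketch {
  // Reusable scratch space for the serialized group key, sized in 8-byte words,
  // mirroring unsafeRowBuffer in the hunk above.
  private val scratch: Array[Long] = new Array[Long](1024 * 16 / 8)

  def writeKey(columns: Seq[Long]): Int = {
    // Size check before writing, analogous to getSizeRequirement vs. the buffer.
    val requiredWords = columns.length
    if (requiredWords > scratch.length) {
      throw new IllegalStateException("Group projection does not fit into buffer")
    }
    // Zero the region we are about to use so stale data from the previous key
    // cannot leak into hashing or equality checks, then copy the key in.
    java.util.Arrays.fill(scratch, 0, requiredWords, 0L)
    var i = 0
    while (i < requiredWords) {
      scratch(i) = columns(i)
      i += 1
    }
    requiredWords * 8  // the key length in bytes that a lookup would need
  }

  def main(args: Array[String]): Unit = {
    val keyLengthInBytes = writeKey(Seq(7L, 13L))
    println(s"wrote $keyLengthInBytes bytes into the scratch buffer")  // 16 bytes
  }
}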
@@ -308,18 +322,18 @@ case class UnsafeGeneratedAggregate(
             unsafeRowBuffer,
             PlatformDependent.LONG_ARRAY_OFFSET,
             keyLengthInBytes,
-            null,  // empty agg buffer
+            emptyAggregationBuffer,
             PlatformDependent.LONG_ARRAY_OFFSET,
-            0  // length of the aggregation buffer
+            emptyAggregationBuffer.length
           )
         }
         // Reset our pointer to point to the buffer stored in the hash map
         val address = loc.getValueAddress
         currentBuffer.set(
           address.getBaseObject,
           address.getBaseOffset,
-          numberOfFieldsInAggregationBuffer,
-          javaAggregationBuffer.schema
+          numberOfColumnsInAggBuffer,
+          null
         )
         // Target the projection at the current aggregation buffer and then project the updated
         // values.
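In the hunk above, the pre-serialized empty buffer is inserted the first time a key is seen, and currentBuffer is then re-pointed at the map's own value storage, so the update projection mutates that storage in place with no copying back. A rough sketch of such a re-pointable row view follows, with a heap Array[Long] standing in for the map's off-heap memory; LongRowView and its methods are made-up names, not part of the patch.

object BufferViewSketch {
  // A tiny stand-in for the mutable row used above: a window of numFields Long
  // slots starting at offset inside some backing Array[Long].
  final class LongRowView {
    private var base: Array[Long] = _
    private var offset: Int = 0
    private var numFields: Int = 0

    // Mirrors currentBuffer.set(baseObject, baseOffset, numFields, schema):
    // re-pointing the view allocates nothing and copies nothing.
    def set(base: Array[Long], offset: Int, numFields: Int): Unit = {
      this.base = base
      this.offset = offset
      this.numFields = numFields
    }

    def getLong(i: Int): Long = { require(i < numFields); base(offset + i) }
    def setLong(i: Int, v: Long): Unit = { require(i < numFields); base(offset + i) = v }
  }

  def main(args: Array[String]): Unit = {
    // One backing array holding two 2-column "aggregation buffers" back to back,
    // standing in for the map's value storage.
    val storage = new Array[Long](4)
    val view = new LongRowView

    view.set(storage, offset = 0, numFields = 2)
    view.setLong(0, 1L)
    view.setLong(1, 10L)

    view.set(storage, offset = 2, numFields = 2)  // retarget the same view, no copy
    view.setLong(0, 3L)
    view.setLong(1, 30L)

    println(storage.mkString(", "))  // 1, 10, 3, 30
  }
}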
@@ -346,8 +360,8 @@ case class UnsafeGeneratedAggregate(
           value.set(
             valueAddress.getBaseObject,
             valueAddress.getBaseOffset,
-            aggregationBufferSchema.fields.length,
-            aggregationBufferSchema
+            numberOfColumnsInAggBuffer,
+            null
           )
           // TODO: once the iterator has been fully consumed, we need to free the map so that
           // its off-heap memory is reclaimed. This may mean that we'll have to perform an extra
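The trailing TODO says the map's off-heap memory should be freed once the result iterator has been fully consumed. One possible shape for that, sketched here with a generic wrapper rather than anything taken from the patch, is to run the cleanup when the wrapped iterator reports exhaustion:

object FreeOnExhaustionSketch {
  // Wrap an iterator so that cleanup runs exactly once, after the last element
  // has been handed out (for example, to free an off-heap hash map).
  def withCleanup[A](underlying: Iterator[A])(cleanup: () => Unit): Iterator[A] =
    new Iterator[A] {
      private var cleaned = false
      override def hasNext: Boolean = {
        val more = underlying.hasNext
        if (!more && !cleaned) {
          cleaned = true
          cleanup()
        }
        more
      }
      override def next(): A = underlying.next()
    }

  def main(args: Array[String]): Unit = {
    val rows = withCleanup(Iterator(1, 2, 3))(() => println("freeing map memory"))
    rows.foreach(println)  // 1, 2, 3, then the cleanup message on the final hasNext
  }
}

A consumer that stops early would never trigger the cleanup, which is presumably part of what the truncated TODO is weighing.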