
Commit 8a8f9df

Add skeleton for GeneratedAggregate integration.
This typechecks properly and sketches how I'm intending to use row pointers and the hashmap. This has been a useful exercise for figuring out whether my interfaces will be sufficient.
1 parent 5d55cef commit 8a8f9df
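
As a reader's aid (not part of the commit), here is a minimal sketch of the lookup-or-insert-then-update cycle this skeleton is building toward. It uses only the BytesToBytesMap and UnsafeRow interfaces exercised in the diff below; the helper itself, its parameters, and the key encoding are hypothetical stand-ins for the parts still marked TODO:

    import org.apache.spark.sql.catalyst.expressions.UnsafeRow
    import org.apache.spark.sql.types.StructType
    import org.apache.spark.unsafe.PlatformDependent
    import org.apache.spark.unsafe.map.BytesToBytesMap

    // Hypothetical helper: probe the map with a serialized grouping key, install a
    // pre-serialized empty aggregation buffer the first time a key is seen, then point a
    // reusable UnsafeRow at the map-owned bytes so an update projection can mutate them
    // in place, with no per-row allocation.
    def updateGroup(
        map: BytesToBytesMap,
        keyBuffer: Array[Long],
        keyLengthInBytes: Int,
        emptyAggBuffer: Array[Long],
        aggBufferLengthInBytes: Int,
        numAggFields: Int,
        aggSchema: StructType,
        reusedPointer: UnsafeRow)(update: UnsafeRow => Unit): Unit = {
      val loc = map.lookup(keyBuffer, PlatformDependent.LONG_ARRAY_OFFSET, keyLengthInBytes)
      if (!loc.isDefined) {
        // First sighting of this key: copy in the empty aggregation buffer.
        loc.storeKeyAndValue(
          keyBuffer, PlatformDependent.LONG_ARRAY_OFFSET, keyLengthInBytes,
          emptyAggBuffer, PlatformDependent.LONG_ARRAY_OFFSET, aggBufferLengthInBytes)
      }
      // Re-point the mutable row at the buffer stored in the map and update it in place.
      val address = loc.getValueAddress
      reusedPointer.set(address.getBaseObject, address.getBaseOffset, numAggFields, aggSchema)
      update(reusedPointer)
    }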

File tree

4 files changed: +373, -0 lines

sql/catalyst/pom.xml

Lines changed: 5 additions & 0 deletions
@@ -50,6 +50,11 @@
       <artifactId>spark-core_${scala.binary.version}</artifactId>
       <version>${project.version}</version>
     </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-unsafe_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+    </dependency>
     <dependency>
       <groupId>org.scalacheck</groupId>
       <artifactId>scalacheck_${scala.binary.version}</artifactId>

sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java

Lines changed: 1 addition & 0 deletions
@@ -33,6 +33,7 @@
 
 
 // TODO: pick a better name for this class, since this is potentially confusing.
+// Maybe call it UnsafeMutableRow?
 
 /**
  * An Unsafe implementation of Row which is backed by raw memory instead of Java objects.
sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeGeneratedAggregate.scala

Lines changed: 358 additions & 0 deletions
@@ -0,0 +1,358 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.execution

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.trees._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.PlatformDependent
import org.apache.spark.unsafe.array.ByteArrayMethods
import org.apache.spark.unsafe.map.BytesToBytesMap
import org.apache.spark.unsafe.memory.MemoryAllocator

// TODO: finish cleaning up documentation instead of just copying it

/**
 * TODO: copy of GeneratedAggregate that uses unsafe / offheap row implementations + hashtables.
 */
@DeveloperApi
case class UnsafeGeneratedAggregate(
    partial: Boolean,
    groupingExpressions: Seq[Expression],
    aggregateExpressions: Seq[NamedExpression],
    child: SparkPlan)
  extends UnaryNode {

  override def requiredChildDistribution: Seq[Distribution] =
    if (partial) {
      UnspecifiedDistribution :: Nil
    } else {
      if (groupingExpressions == Nil) {
        AllTuples :: Nil
      } else {
        ClusteredDistribution(groupingExpressions) :: Nil
      }
    }

  override def output: Seq[Attribute] = aggregateExpressions.map(_.toAttribute)

  override def execute(): RDD[Row] = {
    val aggregatesToCompute = aggregateExpressions.flatMap { a =>
      a.collect { case agg: AggregateExpression => agg }
    }

    // If you add any new function support, please add tests in org.apache.spark.sql.SQLQuerySuite
    // (in test "aggregation with codegen").
    val computeFunctions = aggregatesToCompute.map {
      case c @ Count(expr) =>
        // If we're evaluating UnscaledValue(x), we can do Count on x directly, since its
        // UnscaledValue will be null if and only if x is null; helps with Average on decimals
        val toCount = expr match {
          case UnscaledValue(e) => e
          case _ => expr
        }
        val currentCount = AttributeReference("currentCount", LongType, nullable = false)()
        val initialValue = Literal(0L)
        val updateFunction = If(IsNotNull(toCount), Add(currentCount, Literal(1L)), currentCount)
        val result = currentCount

        AggregateEvaluation(currentCount :: Nil, initialValue :: Nil, updateFunction :: Nil, result)

      case s @ Sum(expr) =>
        val calcType =
          expr.dataType match {
            case DecimalType.Fixed(_, _) =>
              DecimalType.Unlimited
            case _ =>
              expr.dataType
          }

        val currentSum = AttributeReference("currentSum", calcType, nullable = true)()
        val initialValue = Literal.create(null, calcType)

        // Coalesce avoids double calculation...
        // but really, common subexpression elimination would be better....
        val zero = Cast(Literal(0), calcType)
        val updateFunction = Coalesce(
          Add(
            Coalesce(currentSum :: zero :: Nil),
            Cast(expr, calcType)
          ) :: currentSum :: zero :: Nil)
        val result =
          expr.dataType match {
            case DecimalType.Fixed(_, _) =>
              Cast(currentSum, s.dataType)
            case _ => currentSum
          }

        AggregateEvaluation(currentSum :: Nil, initialValue :: Nil, updateFunction :: Nil, result)

      case cs @ CombineSum(expr) =>
        val calcType =
          expr.dataType match {
            case DecimalType.Fixed(_, _) =>
              DecimalType.Unlimited
            case _ =>
              expr.dataType
          }

        val currentSum = AttributeReference("currentSum", calcType, nullable = true)()
        val initialValue = Literal.create(null, calcType)

        // Coalesce avoids double calculation...
        // but really, common subexpression elimination would be better....
        val zero = Cast(Literal(0), calcType)
        // If we're evaluating UnscaledValue(x), we can do Sum on x directly, since its
        // UnscaledValue will be null if and only if x is null; helps with Average on decimals
        val actualExpr = expr match {
          case UnscaledValue(e) => e
          case _ => expr
        }
        // The partial sum result can be null only when no input rows are present.
        val updateFunction = If(
          IsNotNull(actualExpr),
          Coalesce(
            Add(
              Coalesce(currentSum :: zero :: Nil),
              Cast(expr, calcType)) :: currentSum :: zero :: Nil),
          currentSum)

        val result =
          expr.dataType match {
            case DecimalType.Fixed(_, _) =>
              Cast(currentSum, cs.dataType)
            case _ => currentSum
          }

        AggregateEvaluation(currentSum :: Nil, initialValue :: Nil, updateFunction :: Nil, result)

      case m @ Max(expr) =>
        val currentMax = AttributeReference("currentMax", expr.dataType, nullable = true)()
        val initialValue = Literal.create(null, expr.dataType)
        val updateMax = MaxOf(currentMax, expr)

        AggregateEvaluation(
          currentMax :: Nil,
          initialValue :: Nil,
          updateMax :: Nil,
          currentMax)

      case m @ Min(expr) =>
        val currentMin = AttributeReference("currentMin", expr.dataType, nullable = true)()
        val initialValue = Literal.create(null, expr.dataType)
        val updateMin = MinOf(currentMin, expr)

        AggregateEvaluation(
          currentMin :: Nil,
          initialValue :: Nil,
          updateMin :: Nil,
          currentMin)

      case CollectHashSet(Seq(expr)) =>
        val set =
          AttributeReference("hashSet", new OpenHashSetUDT(expr.dataType), nullable = false)()
        val initialValue = NewSet(expr.dataType)
        val addToSet = AddItemToSet(expr, set)

        AggregateEvaluation(
          set :: Nil,
          initialValue :: Nil,
          addToSet :: Nil,
          set)

      case CombineSetsAndCount(inputSet) =>
        val elementType = inputSet.dataType.asInstanceOf[OpenHashSetUDT].elementType
        val set =
          AttributeReference("hashSet", new OpenHashSetUDT(elementType), nullable = false)()
        val initialValue = NewSet(elementType)
        val collectSets = CombineSets(set, inputSet)

        AggregateEvaluation(
          set :: Nil,
          initialValue :: Nil,
          collectSets :: Nil,
          CountSet(set))

      case o => sys.error(s"$o can't be codegened.")
    }

    val computationSchema = computeFunctions.flatMap(_.schema)

    val resultMap: Map[TreeNodeRef, Expression] =
      aggregatesToCompute.zip(computeFunctions).map {
        case (agg, func) => new TreeNodeRef(agg) -> func.result
      }.toMap

    val namedGroups = groupingExpressions.zipWithIndex.map {
      case (ne: NamedExpression, _) => (ne, ne)
      case (e, i) => (e, Alias(e, s"GroupingExpr$i")())
    }

    val groupMap: Map[Expression, Attribute] =
      namedGroups.map { case (k, v) => k -> v.toAttribute }.toMap

    // The set of expressions that produce the final output given the aggregation buffer and the
    // grouping expressions.
    val resultExpressions = aggregateExpressions.map(_.transform {
      case e: Expression if resultMap.contains(new TreeNodeRef(e)) => resultMap(new TreeNodeRef(e))
      case e: Expression if groupMap.contains(e) => groupMap(e)
    })

    child.execute().mapPartitions { iter =>
      // Builds a new custom class for holding the results of aggregation for a group.
      val initialValues = computeFunctions.flatMap(_.initialValues)
      val newAggregationBuffer = newProjection(initialValues, child.output)
      log.info(s"Initial values: ${initialValues.mkString(",")}")

      // A projection that computes the group given an input tuple.
      val groupProjection = newProjection(groupingExpressions, child.output)
      log.info(s"Grouping Projection: ${groupingExpressions.mkString(",")}")

      // A projection that is used to update the aggregate values for a group given a new tuple.
      // This projection should be targeted at the current values for the group and then applied
      // to a joined row of the current values with the new input row.
      val updateExpressions = computeFunctions.flatMap(_.update)
      val updateSchema = computeFunctions.flatMap(_.schema) ++ child.output
      val updateProjection = newMutableProjection(updateExpressions, updateSchema)()
      log.info(s"Update Expressions: ${updateExpressions.mkString(",")}")

      // A projection that produces the final result, given a computation.
      val resultProjectionBuilder =
        newMutableProjection(
          resultExpressions,
          (namedGroups.map(_._2.toAttribute) ++ computationSchema).toSeq)
      log.info(s"Result Projection: ${resultExpressions.mkString(",")}")

      val joinedRow = new JoinedRow3

      if (groupingExpressions.isEmpty) {
        // TODO: Codegening anything other than the updateProjection is probably overkill.
        val buffer = newAggregationBuffer(EmptyRow).asInstanceOf[MutableRow]
        var currentRow: Row = null
        updateProjection.target(buffer)

        while (iter.hasNext) {
          currentRow = iter.next()
          updateProjection(joinedRow(buffer, currentRow))
        }

        val resultProjection = resultProjectionBuilder()
        Iterator(resultProjection(buffer))
      } else {
        // TODO: if we knew how many groups to expect, we could size this hashmap appropriately
        val buffers = new BytesToBytesMap(MemoryAllocator.HEAP, 128)

        // Set up the mutable "pointers" that we'll re-use when pointing to key and value rows
        val keyPointer: UnsafeRow = new UnsafeRow()
        val currentBuffer: UnsafeRow = new UnsafeRow()

        // We're going to need to allocate a lot of empty aggregation buffers, so let's do it
        // once and keep a copy of the serialized buffer and copy it into the hash map when we
        // see new keys:
        val javaAggregationBuffer: MutableRow =
          newAggregationBuffer(EmptyRow).asInstanceOf[MutableRow]
        val numberOfFieldsInAggregationBuffer: Int = javaAggregationBuffer.schema.fields.length
        val aggregationBufferSchema: StructType = javaAggregationBuffer.schema
        // TODO: perform that conversion to an UnsafeRow
        // Allocate some scratch space for holding the keys that we use to index into the
        // hash map.
        val unsafeRowBuffer: Array[Long] = new Array[Long](1024)

        // TODO: there's got to be an actual way of obtaining this up front.
        var groupProjectionSchema: StructType = null

        while (iter.hasNext) {
          // Zero out the buffer that's used to hold the current row. This is necessary in order
          // to ensure that rows hash properly, since garbage data from the previous row could
          // otherwise end up as padding in this row. Note that zeroBytes expects a length in
          // bytes, so the word count is scaled by 8.
          ByteArrayMethods.zeroBytes(
            unsafeRowBuffer, PlatformDependent.LONG_ARRAY_OFFSET, unsafeRowBuffer.length * 8)
          // Grab the next row from our input iterator and compute its group projection.
          // In the long run, it might be nice to use Unsafe rows for this as well, but for now
          // we'll just rely on the existing code paths to compute the projection.
          val currentJavaRow = iter.next()
          val currentGroup: Row = groupProjection(currentJavaRow)
          // Convert the current group into an UnsafeRow so that we can use it as a key for our
          // aggregation hash map
          // --- TODO ---
          val keyLengthInBytes: Int = 0
          val loc: BytesToBytesMap#Location =
            buffers.lookup(unsafeRowBuffer, PlatformDependent.LONG_ARRAY_OFFSET, keyLengthInBytes)
          if (!loc.isDefined) {
            // This is the first time that we've seen this key, so we'll copy the empty
            // aggregation buffer row that we created earlier. TODO: this doesn't work very well
            // for aggregates where the size of the aggregate buffer is different for different
            // rows (even if the sizes of buffers don't grow once created, as is the case for
            // things like grabbing the first row's value for a string-valued column (or the
            // shortest string)).

            // TODO

            loc.storeKeyAndValue(
              unsafeRowBuffer,
              PlatformDependent.LONG_ARRAY_OFFSET,
              keyLengthInBytes,
              null, // empty agg buffer
              PlatformDependent.LONG_ARRAY_OFFSET,
              0 // length of the aggregation buffer
            )
          }
          // Reset our pointer to point to the buffer stored in the hash map
          val address = loc.getValueAddress
          currentBuffer.set(
            address.getBaseObject,
            address.getBaseOffset,
            numberOfFieldsInAggregationBuffer,
            javaAggregationBuffer.schema
          )
          // Target the projection at the current aggregation buffer and then project the
          // updated values.
          updateProjection.target(currentBuffer)(joinedRow(currentBuffer, currentJavaRow))
        }

        new Iterator[Row] {
          private[this] val resultIterator = buffers.iterator()
          private[this] val resultProjection = resultProjectionBuilder()
          private[this] val key: UnsafeRow = new UnsafeRow()
          private[this] val value: UnsafeRow = new UnsafeRow()

          def hasNext: Boolean = resultIterator.hasNext

          def next(): Row = {
            val currentGroup: BytesToBytesMap#Location = resultIterator.next()
            val keyAddress = currentGroup.getKeyAddress
            key.set(
              keyAddress.getBaseObject,
              keyAddress.getBaseOffset,
              groupProjectionSchema.fields.length,
              groupProjectionSchema)
            val valueAddress = currentGroup.getValueAddress
            value.set(
              valueAddress.getBaseObject,
              valueAddress.getBaseOffset,
              aggregationBufferSchema.fields.length,
              aggregationBufferSchema
            )
            resultProjection(joinedRow(key, value))
          }
        }
      }
    }
  }
}
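
The Coalesce-based update expressions for Sum and CombineSum above pack their null handling into a single Catalyst tree, which can be hard to read. As a plain-Scala analogue (not part of the commit; Option models SQL NULL, and Long stands in for calcType):

    // Sum's update: Coalesce(Add(Coalesce(currentSum, zero), input), currentSum, zero)
    def updateSum(currentSum: Option[Long], input: Option[Long]): Option[Long] =
      input.map(v => currentSum.getOrElse(0L) + v) // non-null input: add it to (sum or 0)
        .orElse(currentSum)                        // null input: keep the existing sum
        .orElse(Some(0L))                          // both null: fall back to the zero literal

    // CombineSum additionally guards with If(IsNotNull(input), ..., currentSum), so a partial
    // sum stays null until the first non-null input arrives:
    def updateCombineSum(currentSum: Option[Long], input: Option[Long]): Option[Long] =
      if (input.isDefined) updateSum(currentSum, input) else currentSum

For example, folding updateSum over Seq(None, Some(3L), Some(4L)) from an initial None yields Some(7L), while updateCombineSum leaves an all-None input at None.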

unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java

Lines changed: 9 additions & 0 deletions
@@ -31,6 +31,15 @@ private ByteArrayMethods() {
     // Private constructor, since this class only contains static methods.
   }
 
+  public static void zeroBytes(
+      Object baseObject,
+      long baseOffset,
+      long lengthInBytes) {
+    for (long i = 0; i < lengthInBytes; i++) {
+      PlatformDependent.UNSAFE.putByte(baseObject, baseOffset + i, (byte) 0);
+    }
+  }
+
   /**
    * Optimized equality check for equal-length byte arrays.
    * @return true if the arrays are equal, false otherwise
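A note on units for the new method above: lengthInBytes is a byte count, so callers holding an Array[Long] scratch buffer (as UnsafeGeneratedAggregate does) must scale the word count by 8. A minimal usage sketch (the 1024-word size is illustrative):

    val scratch = new Array[Long](1024)
    // Zero all 1024 * 8 = 8192 bytes backing the array, starting at its data offset.
    ByteArrayMethods.zeroBytes(scratch, PlatformDependent.LONG_ARRAY_OFFSET, scratch.length * 8L)

The byte-at-a-time loop keeps the code simple and portable; if zeroing ever shows up in profiles, a word-at-a-time variant (8-byte putLong strides with a byte-wise tail) would be a natural follow-up.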