1919
2020import java .util .Iterator ;
2121
22- import scala .Function1 ;
23-
2422import org .apache .spark .sql .catalyst .InternalRow ;
25- import org .apache .spark .sql .catalyst . util . ObjectPool ;
26- import org .apache .spark .sql .catalyst . util . UniqueObjectPool ;
23+ import org .apache .spark .sql .types . StructField ;
24+ import org .apache .spark .sql .types . StructType ;
2725import org .apache .spark .unsafe .PlatformDependent ;
2826import org .apache .spark .unsafe .map .BytesToBytesMap ;
2927import org .apache .spark .unsafe .memory .MemoryLocation ;
@@ -40,48 +38,26 @@ public final class UnsafeFixedWidthAggregationMap {
4038 * An empty aggregation buffer, encoded in UnsafeRow format. When inserting a new key into the
4139 * map, we copy this buffer and use it as the value.
4240 */
43- private final byte [] emptyBuffer ;
41+ private final byte [] emptyAggregationBuffer ;
4442
45- /**
46- * An empty row used by `initProjection`
47- */
48- private static final InternalRow emptyRow = new GenericInternalRow ();
43+ private final StructType aggregationBufferSchema ; // schema of the aggregation buffer (the map's values); its length() is the field count used when pointing an UnsafeRow at a stored value
4944
50- /**
51- * Whether can the empty aggregation buffer be reuse without calling `initProjection` or not.
52- */
53- private final boolean reuseEmptyBuffer ;
45+ private final StructType groupingKeySchema ; // schema of the grouping key (the map's keys); also used to build the key converter, and its length() is the field count used when pointing an UnsafeRow at a stored key
5446
5547 /**
56- * The projection used to initialize the emptyBuffer
48+ * Encodes grouping keys as UnsafeRows.
5749 */
58- private final Function1 <InternalRow , InternalRow > initProjection ;
59-
60- /**
61- * Encodes grouping keys or buffers as UnsafeRows.
62- */
63- private final UnsafeRowConverter keyConverter ;
64- private final UnsafeRowConverter bufferConverter ;
50+ private final UnsafeRowConverter groupingKeyToUnsafeRowConverter ;
6551
6652 /**
6753 * A hashmap which maps from opaque bytearray keys to bytearray values.
6854 */
6955 private final BytesToBytesMap map ;
7056
71- /**
72- * An object pool for objects that are used in grouping keys.
73- */
74- private final UniqueObjectPool keyPool ;
75-
76- /**
77- * An object pool for objects that are used in aggregation buffers.
78- */
79- private final ObjectPool bufferPool ;
80-
8157 /**
8258 * Re-used pointer to the current aggregation buffer
8359 */
84- private final UnsafeRow currentBuffer = new UnsafeRow ();
60+ private final UnsafeRow currentAggregationBuffer = new UnsafeRow ();
8561
8662 /**
8763 * Scratch space that is used when encoding grouping keys into UnsafeRow format.
@@ -93,59 +69,86 @@ public final class UnsafeFixedWidthAggregationMap {
9369
9470 private final boolean enablePerfMetrics ;
9571
72+ /**
73+ * @return true if UnsafeFixedWidthAggregationMap supports grouping keys with the given schema,
74+ * i.e. every field's data type is contained in UnsafeRow.readableFieldTypes; false otherwise.
75+ */
76+ public static boolean supportsGroupKeySchema (StructType schema ) {
77+ for (StructField field : schema .fields ()) {
78+ if (!UnsafeRow .readableFieldTypes .contains (field .dataType ())) {
79+ return false ; // a single unsupported field type rules out the whole schema
80+ }
81+ }
82+ return true ; // no field was rejected (also the result when the loop body never runs)
83+ }
84+
85+ /**
86+ * @return true if UnsafeFixedWidthAggregationMap supports aggregation buffers with the given
87+ * schema, i.e. every field's data type is contained in UnsafeRow.settableFieldTypes; false otherwise.
88+ */
89+ public static boolean supportsAggregationBufferSchema (StructType schema ) {
90+ for (StructField field : schema .fields ()) {
91+ if (!UnsafeRow .settableFieldTypes .contains (field .dataType ())) {
92+ return false ; // a single unsupported field type rules out the whole schema
93+ }
94+ }
95+ return true ; // no field was rejected (also the result when the loop body never runs)
96+ }
97+
9698 /**
9799 * Create a new UnsafeFixedWidthAggregationMap.
98100 *
99- * @param initProjection the default value for new keys (a "zero" of the agg. function)
100- * @param keyConverter the converter of the grouping key , used for row conversion.
101- * @param bufferConverter the converter of the aggregation buffer , used for row conversion.
101+ * @param emptyAggregationBuffer the default value for new keys (a "zero" of the agg. function); converted to UnsafeRow bytes once, here
102+ * @param aggregationBufferSchema the schema of the aggregation buffer, used for row conversion.
103+ * @param groupingKeySchema the schema of the grouping key, used for row conversion.
102104 * @param memoryManager the memory manager used to allocate our Unsafe memory structures.
103105 * @param initialCapacity the initial capacity of the map (a sizing hint to avoid re-hashing).
104106 * @param enablePerfMetrics if true, performance metrics will be recorded (has minor perf impact)
105107 */
106108 public UnsafeFixedWidthAggregationMap (
107- Function1 < InternalRow , InternalRow > initProjection ,
108- UnsafeRowConverter keyConverter ,
109- UnsafeRowConverter bufferConverter ,
109+ InternalRow emptyAggregationBuffer ,
110+ StructType aggregationBufferSchema ,
111+ StructType groupingKeySchema ,
110112 TaskMemoryManager memoryManager ,
111113 int initialCapacity ,
112114 boolean enablePerfMetrics ) {
113- this .initProjection = initProjection ;
114- this . keyConverter = keyConverter ;
115- this .bufferConverter = bufferConverter ;
116- this .enablePerfMetrics = enablePerfMetrics ;
117-
115+ this .emptyAggregationBuffer =
116+ convertToUnsafeRow ( emptyAggregationBuffer , aggregationBufferSchema ) ; // encode the "zero" buffer once; these bytes are what gets inserted as the value for each newly seen key
117+ this .aggregationBufferSchema = aggregationBufferSchema ;
118+ this .groupingKeyToUnsafeRowConverter = new UnsafeRowConverter ( groupingKeySchema ) ;
119+ this . groupingKeySchema = groupingKeySchema ;
118120 this .map = new BytesToBytesMap (memoryManager , initialCapacity , enablePerfMetrics );
119- this .keyPool = new UniqueObjectPool ( 100 ) ;
120- this . bufferPool = new ObjectPool ( initialCapacity );
121+ this .enablePerfMetrics = enablePerfMetrics ;
122+ }
121123
122- InternalRow initRow = initProjection .apply (emptyRow );
123- int emptyBufferSize = bufferConverter .getSizeRequirement (initRow );
124- this .emptyBuffer = new byte [emptyBufferSize ];
125- int writtenLength = bufferConverter .writeRow (
126- initRow , emptyBuffer , PlatformDependent .BYTE_ARRAY_OFFSET , emptyBufferSize ,
127- bufferPool );
128- assert (writtenLength == emptyBuffer .length ): "Size requirement calculation was wrong!" ;
129- // re-use the empty buffer only when there is no object saved in pool.
130- reuseEmptyBuffer = bufferPool .size () == 0 ;
124+ /**
125+ * Convert a Java object row into UnsafeRow format, writing it into a newly allocated byte array (sized by the converter's getSizeRequirement) and returning that array.
126+ */
127+ private static byte [] convertToUnsafeRow (InternalRow javaRow , StructType schema ) {
128+ final UnsafeRowConverter converter = new UnsafeRowConverter (schema );
129+ final int size = converter .getSizeRequirement (javaRow );
130+ final byte [] unsafeRow = new byte [size ];
131+ final int writtenLength =
132+ converter .writeRow (javaRow , unsafeRow , PlatformDependent .BYTE_ARRAY_OFFSET , size );
133+ assert (writtenLength == unsafeRow .length ): "Size requirement calculation was wrong!" ; // sanity check only; evaluated only when JVM assertions are enabled (-ea)
134+ return unsafeRow ;
131135 }
132136
133137 /**
134138 * Return the aggregation buffer for the current group. For efficiency, all calls to this method
135139 * return the same object.
136140 */
137141 public UnsafeRow getAggregationBuffer (InternalRow groupingKey ) {
138- final int groupingKeySize = keyConverter .getSizeRequirement (groupingKey );
142+ final int groupingKeySize = groupingKeyToUnsafeRowConverter .getSizeRequirement (groupingKey );
139143 // Make sure that the buffer is large enough to hold the key. If it's not, grow it:
140144 if (groupingKeySize > groupingKeyConversionScratchSpace .length ) {
141145 groupingKeyConversionScratchSpace = new byte [groupingKeySize ];
142146 }
143- final int actualGroupingKeySize = keyConverter .writeRow (
147+ final int actualGroupingKeySize = groupingKeyToUnsafeRowConverter .writeRow (
144148 groupingKey ,
145149 groupingKeyConversionScratchSpace ,
146150 PlatformDependent .BYTE_ARRAY_OFFSET ,
147- groupingKeySize ,
148- keyPool );
151+ groupingKeySize );
149152 assert (groupingKeySize == actualGroupingKeySize ) : "Size requirement calculation was wrong!" ;
150153
151154 // Probe our map using the serialized key
@@ -156,32 +159,25 @@ public UnsafeRow getAggregationBuffer(InternalRow groupingKey) {
156159 if (!loc .isDefined ()) {
157160 // This is the first time that we've seen this grouping key, so we'll insert a copy of the
158161 // empty aggregation buffer into the map:
159- if (!reuseEmptyBuffer ) {
160- // There is some objects referenced by emptyBuffer, so generate a new one
161- InternalRow initRow = initProjection .apply (emptyRow );
162- bufferConverter .writeRow (initRow , emptyBuffer , PlatformDependent .BYTE_ARRAY_OFFSET ,
163- groupingKeySize , bufferPool );
164- }
165162 loc .putNewKey (
166163 groupingKeyConversionScratchSpace ,
167164 PlatformDependent .BYTE_ARRAY_OFFSET ,
168165 groupingKeySize ,
169- emptyBuffer ,
166+ emptyAggregationBuffer ,
170167 PlatformDependent .BYTE_ARRAY_OFFSET ,
171- emptyBuffer .length
168+ emptyAggregationBuffer .length
172169 );
173170 }
174171
175172 // Reset the pointer to point to the value that we just stored or looked up:
176173 final MemoryLocation address = loc .getValueAddress ();
177- currentBuffer .pointTo (
174+ currentAggregationBuffer .pointTo (
178175 address .getBaseObject (),
179176 address .getBaseOffset (),
180- bufferConverter .numFields (),
181- loc .getValueLength (),
182- bufferPool
177+ aggregationBufferSchema .length (),
178+ loc .getValueLength ()
183179 );
184- return currentBuffer ;
180+ return currentAggregationBuffer ;
185181 }
186182
187183 /**
@@ -217,16 +213,14 @@ public MapEntry next() {
217213 entry .key .pointTo (
218214 keyAddress .getBaseObject (),
219215 keyAddress .getBaseOffset (),
220- keyConverter .numFields (),
221- loc .getKeyLength (),
222- keyPool
216+ groupingKeySchema .length (),
217+ loc .getKeyLength ()
223218 );
224219 entry .value .pointTo (
225220 valueAddress .getBaseObject (),
226221 valueAddress .getBaseOffset (),
227- bufferConverter .numFields (),
228- loc .getValueLength (),
229- bufferPool
222+ aggregationBufferSchema .length (),
223+ loc .getValueLength ()
230224 );
231225 return entry ;
232226 }
@@ -254,8 +248,6 @@ public void printPerfMetrics() {
254248 System .out .println ("Number of hash collisions: " + map .getNumHashCollisions ());
255249 System .out .println ("Time spent resizing (ns): " + map .getTimeSpentResizingNs ());
256250 System .out .println ("Total memory consumption (bytes): " + map .getTotalMemoryConsumption ());
257- System .out .println ("Number of unique objects in keys: " + keyPool .size ());
258- System .out .println ("Number of objects in buffers: " + bufferPool .size ());
259251 }
260252
261253}
0 commit comments