 import scala.collection.Iterator;
 import scala.math.Ordering;
 
+import com.google.common.annotations.VisibleForTesting;
+
 import org.apache.spark.SparkEnv;
 import org.apache.spark.TaskContext;
 import org.apache.spark.sql.AbstractScalaRowIterator;
 import org.apache.spark.sql.catalyst.InternalRow;
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
 import org.apache.spark.sql.catalyst.expressions.UnsafeRowConverter;
+import org.apache.spark.sql.catalyst.util.ObjectPool;
 import org.apache.spark.sql.types.StructType;
 import org.apache.spark.unsafe.PlatformDependent;
 import org.apache.spark.util.collection.unsafe.sort.PrefixComparator;
@@ -41,61 +44,70 @@ final class UnsafeExternalRowSorter {
 
   private final StructType schema;
   private final UnsafeRowConverter rowConverter;
-  private final RowComparator rowComparator;
-  private final PrefixComparator prefixComparator;
   private final Function1<InternalRow, Long> prefixComputer;
+  private final ObjectPool objPool = new ObjectPool(128);
+  private final UnsafeExternalSorter sorter;
+  private byte[] rowConversionBuffer = new byte[1024 * 8];
 
   public UnsafeExternalRowSorter(
       StructType schema,
       Ordering<InternalRow> ordering,
       PrefixComparator prefixComparator,
       // TODO: if possible, avoid this boxing of the return value
-      Function1<InternalRow, Long> prefixComputer) {
+      Function1<InternalRow, Long> prefixComputer) throws IOException {
     this.schema = schema;
     this.rowConverter = new UnsafeRowConverter(schema);
-    this.rowComparator = new RowComparator(ordering, schema);
-    this.prefixComparator = prefixComparator;
     this.prefixComputer = prefixComputer;
-  }
-
-  public Iterator<InternalRow> sort(Iterator<InternalRow> inputIterator) throws IOException {
     final SparkEnv sparkEnv = SparkEnv.get();
     final TaskContext taskContext = TaskContext.get();
-    byte[] rowConversionBuffer = new byte[1024 * 8];
-    final UnsafeExternalSorter sorter = new UnsafeExternalSorter(
+    sorter = new UnsafeExternalSorter(
       taskContext.taskMemoryManager(),
       sparkEnv.shuffleMemoryManager(),
       sparkEnv.blockManager(),
       taskContext,
-      rowComparator,
+      new RowComparator(ordering, schema.length(), objPool),
       prefixComparator,
       4096,
       sparkEnv.conf()
     );
+  }
+
+  @VisibleForTesting
+  void insertRow(InternalRow row) throws IOException {
+    final int sizeRequirement = rowConverter.getSizeRequirement(row);
+    if (sizeRequirement > rowConversionBuffer.length) {
+      rowConversionBuffer = new byte[sizeRequirement];
+    } else {
+      // Zero out the buffer that's used to hold the current row. This is necessary in order
+      // to ensure that rows hash properly, since garbage data from the previous row could
+      // otherwise end up as padding in this row. As a performance optimization, we only zero
+      // out the portion of the buffer that we'll actually write to.
+      Arrays.fill(rowConversionBuffer, 0, sizeRequirement, (byte) 0);
+    }
+    final int bytesWritten = rowConverter.writeRow(
+      row, rowConversionBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, objPool);
+    assert (bytesWritten == sizeRequirement);
+    final long prefix = prefixComputer.apply(row);
+    sorter.insertRecord(
+      rowConversionBuffer,
+      PlatformDependent.BYTE_ARRAY_OFFSET,
+      sizeRequirement,
+      prefix
+    );
+  }
+
+  @VisibleForTesting
+  void spill() throws IOException {
+    sorter.spill();
+  }
+
+  private void cleanupResources() {
+    sorter.freeMemory();
+  }
+
+  @VisibleForTesting
+  Iterator<InternalRow> sort() throws IOException {
     try {
-      while (inputIterator.hasNext()) {
-        final InternalRow row = inputIterator.next();
-        final int sizeRequirement = rowConverter.getSizeRequirement(row);
-        if (sizeRequirement > rowConversionBuffer.length) {
-          rowConversionBuffer = new byte[sizeRequirement];
-        } else {
-          // Zero out the buffer that's used to hold the current row. This is necessary in order
-          // to ensure that rows hash properly, since garbage data from the previous row could
-          // otherwise end up as padding in this row. As a performance optimization, we only zero
-          // out the portion of the buffer that we'll actually write to.
-          Arrays.fill(rowConversionBuffer, 0, sizeRequirement, (byte) 0);
-        }
-        final int bytesWritten =
-          rowConverter.writeRow(row, rowConversionBuffer, PlatformDependent.BYTE_ARRAY_OFFSET);
-        assert (bytesWritten == sizeRequirement);
-        final long prefix = prefixComputer.apply(row);
-        sorter.insertRecord(
-          rowConversionBuffer,
-          PlatformDependent.BYTE_ARRAY_OFFSET,
-          sizeRequirement,
-          prefix
-        );
-      }
       final UnsafeSorterIterator sortedIterator = sorter.getSortedIterator();
       return new AbstractScalaRowIterator() {
 
@@ -113,7 +125,7 @@ public InternalRow next() {
             sortedIterator.loadNext();
             if (hasNext()) {
               row.pointTo(
-                sortedIterator.getBaseObject(), sortedIterator.getBaseOffset(), numFields, schema);
+                sortedIterator.getBaseObject(), sortedIterator.getBaseOffset(), numFields, objPool);
               return row;
             } else {
               final byte[] rowDataCopy = new byte[sortedIterator.getRecordLength()];
@@ -125,14 +137,12 @@ public InternalRow next() {
                 sortedIterator.getRecordLength()
               );
               row.backingArray = rowDataCopy;
-              row.pointTo(rowDataCopy, PlatformDependent.BYTE_ARRAY_OFFSET, numFields, schema);
+              row.pointTo(rowDataCopy, PlatformDependent.BYTE_ARRAY_OFFSET, numFields, objPool);
               sorter.freeMemory();
               return row;
             }
           } catch (IOException e) {
-            // TODO: we need to ensure that files are cleaned properly after an exception,
-            // so we need better cleanup methods than freeMemory().
-            sorter.freeMemory();
+            cleanupResources();
             // Scala iterators don't declare any checked exceptions, so we need to use this hack
             // to re-throw the exception:
             PlatformDependent.throwException(e);
@@ -141,30 +151,36 @@ public InternalRow next() {
         };
       };
     } catch (IOException e) {
-      // TODO: we need to ensure that files are cleaned properly after an exception,
-      // so we need better cleanup methods than freeMemory().
-      sorter.freeMemory();
+      cleanupResources();
       throw e;
     }
   }
 
+
+  public Iterator<InternalRow> sort(Iterator<InternalRow> inputIterator) throws IOException {
+    while (inputIterator.hasNext()) {
+      insertRow(inputIterator.next());
+    }
+    return sort();
+  }
+
   private static final class RowComparator extends RecordComparator {
-    private final StructType schema;
     private final Ordering<InternalRow> ordering;
     private final int numFields;
+    private final ObjectPool objPool;
     private final UnsafeRow row1 = new UnsafeRow();
     private final UnsafeRow row2 = new UnsafeRow();
 
-    public RowComparator(Ordering<InternalRow> ordering, StructType schema) {
-      this.schema = schema;
-      this.numFields = schema.length();
+    public RowComparator(Ordering<InternalRow> ordering, int numFields, ObjectPool objPool) {
+      this.numFields = numFields;
       this.ordering = ordering;
+      this.objPool = objPool;
     }
 
     @Override
     public int compare(Object baseObj1, long baseOff1, Object baseObj2, long baseOff2) {
-      row1.pointTo(baseObj1, baseOff1, numFields, schema);
-      row2.pointTo(baseObj2, baseOff2, numFields, schema);
+      row1.pointTo(baseObj1, baseOff1, numFields, objPool);
+      row2.pointTo(baseObj2, baseOff2, numFields, objPool);
       return ordering.compare(row1, row2);
     }
   }
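
For reference, a minimal usage sketch of how the refactored API fits together, not part of this diff: the constructor now allocates the underlying UnsafeExternalSorter eagerly, rows are fed in one at a time through insertRow(), and the no-argument sort() returns the sorted iterator; the public sort(Iterator) overload above is exactly this loop. The class and method names below (UnsafeExternalRowSorterUsageSketch, driveSort) are hypothetical, the caller is assumed to sit in the same package since insertRow(), spill() and the no-argument sort() are package-private @VisibleForTesting methods, and the schema, ordering, prefix comparator and prefix computer are assumed to be supplied by the surrounding operator or test.

import java.io.IOException;

import scala.Function1;
import scala.collection.Iterator;
import scala.math.Ordering;

import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.util.collection.unsafe.sort.PrefixComparator;

// Hypothetical caller; assumes it is compiled into the same package as
// UnsafeExternalRowSorter so the @VisibleForTesting methods are accessible.
final class UnsafeExternalRowSorterUsageSketch {

  static Iterator<InternalRow> driveSort(
      StructType schema,
      Ordering<InternalRow> ordering,
      PrefixComparator prefixComparator,
      Function1<InternalRow, Long> prefixComputer,
      Iterator<InternalRow> input) throws IOException {
    // The constructor now reads SparkEnv/TaskContext and allocates the underlying
    // UnsafeExternalSorter, so it must be invoked inside a running Spark task.
    final UnsafeExternalRowSorter sorter =
        new UnsafeExternalRowSorter(schema, ordering, prefixComparator, prefixComputer);
    while (input.hasNext()) {
      // Converts each row to the UnsafeRow format and appends it to the sorter's buffers.
      sorter.insertRow(input.next());
    }
    // sorter.spill();  // tests can force a spill at any point between inserts
    return sorter.sort();  // rows come back in sorted order
  }
}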