Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,20 @@ public class UnsafeAlignedOffset {

private static final int UAO_SIZE = Platform.unaligned() ? 4 : 8;

private static int TEST_UAO_SIZE = 0;

// used for test only
public static void setUaoSize(int size) {
assert size == 0 || size == 4 || size == 8;
TEST_UAO_SIZE = size;
}

public static int getUaoSize() {
return UAO_SIZE;
return TEST_UAO_SIZE == 0 ? UAO_SIZE : TEST_UAO_SIZE;
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: you can use Utils.isTesting instead

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Other tests might not set UAO size manually, then we'll get 0 in this case.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should handle the logic of setting/reverting the size in the test code.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

then, we need to figure out all the test cases where used UnsafeAlignedOffset, right?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need to do that? The default value of TEST_UAO_SIZE should be the same as UAO_SIZE, only if you want to test other values then you need to change it.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, I get your point. Let me update it later.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmm...unfortunately, unsafe project doesn't depend on core. So we can not access the Utils.

}

public static int getSize(Object object, long offset) {
switch (UAO_SIZE) {
switch (getUaoSize()) {
case 4:
return Platform.getInt(object, offset);
case 8:
Expand All @@ -46,7 +54,7 @@ public static int getSize(Object object, long offset) {
}

public static void putSize(Object object, long offset, int value) {
switch (UAO_SIZE) {
switch (getUaoSize()) {
case 4:
Platform.putInt(object, offset, value);
break;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,13 @@
* probably be using sorting instead of hashing for better cache locality.
*
* The key and values under the hood are stored together, in the following format:
* Bytes 0 to 4: len(k) (key length in bytes) + len(v) (value length in bytes) + 4
* Bytes 4 to 8: len(k)
* Bytes 8 to 8 + len(k): key data
* Bytes 8 + len(k) to 8 + len(k) + len(v): value data
* Bytes 8 + len(k) + len(v) to 8 + len(k) + len(v) + 8: pointer to next pair
* First uaoSize bytes: len(k) (key length in bytes) + len(v) (value length in bytes) + uaoSize
* Next uaoSize bytes: len(k)
* Next len(k) bytes: key data
* Next len(v) bytes: value data
* Last 8 bytes: pointer to next pair
*
* This means that the first four bytes store the entire record (key + value) length. This format
* It means first uaoSize bytes store the entire record (key + value + uaoSize) length. This format
* is compatible with {@link org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter},
* so we can pass records from this map directly into the sorter to sort records in place.
*/
Expand Down Expand Up @@ -706,7 +706,7 @@ public boolean append(Object kbase, long koff, int klen, Object vbase, long voff
// Here, we'll copy the data into our data pages. Because we only store a relative offset from
// the key address instead of storing the absolute address of the value, the key and value
// must be stored in the same memory page.
// (8 byte key length) (key) (value) (8 byte pointer to next value)
// (total length) (key length) (key) (value) (8 byte pointer to next value)
int uaoSize = UnsafeAlignedOffset.getUaoSize();
final long recordLength = (2L * uaoSize) + klen + vlen + 8;
if (currentPage == null || currentPage.size() - pageCursor < recordLength) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ public void expandPointerArray(LongArray newArray) {

/**
* Inserts a record to be sorted. Assumes that the record pointer points to a record length
* stored as a 4-byte integer, followed by the record's bytes.
* stored as a uaoSize(4 or 8) bytes integer, followed by the record's bytes.
*
* @param recordPointer pointer to a record in a data page, encoded by {@link TaskMemoryManager}.
* @param keyPrefix a user-defined key prefix
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,12 @@
import org.apache.spark.memory.TaskMemoryManager;
import org.apache.spark.sql.types.*;
import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.UnsafeAlignedOffset;

/**
* An implementation of `RowBasedKeyValueBatch` in which key-value records have variable lengths.
*
* The format for each record looks like this:
* The format for each record looks like this (in case of uaoSize = 4):
* [4 bytes total size = (klen + vlen + 4)] [4 bytes key size = klen]
* [UnsafeRow for key of length klen] [UnsafeRow for Value of length vlen]
* [8 bytes pointer to next]
Expand All @@ -41,18 +42,19 @@ public final class VariableLengthRowBasedKeyValueBatch extends RowBasedKeyValueB
@Override
public UnsafeRow appendRow(Object kbase, long koff, int klen,
Object vbase, long voff, int vlen) {
final long recordLength = 8L + klen + vlen + 8;
int uaoSize = UnsafeAlignedOffset.getUaoSize();
final long recordLength = 2 * uaoSize + klen + vlen + 8L;
// if run out of max supported rows or page size, return null
if (numRows >= capacity || page == null || page.size() - pageCursor < recordLength) {
return null;
}

long offset = page.getBaseOffset() + pageCursor;
final long recordOffset = offset;
Platform.putInt(base, offset, klen + vlen + 4);
Platform.putInt(base, offset + 4, klen);
UnsafeAlignedOffset.putSize(base, offset, klen + vlen + uaoSize);
UnsafeAlignedOffset.putSize(base, offset + uaoSize, klen);

offset += 8;
offset += 2 * uaoSize;
Platform.copyMemory(kbase, koff, base, offset, klen);
offset += klen;
Platform.copyMemory(vbase, voff, base, offset, vlen);
Expand All @@ -61,11 +63,11 @@ public UnsafeRow appendRow(Object kbase, long koff, int klen,

pageCursor += recordLength;

keyOffsets[numRows] = recordOffset + 8;
keyOffsets[numRows] = recordOffset + 2 * uaoSize;

keyRowId = numRows;
keyRow.pointTo(base, recordOffset + 8, klen);
valueRow.pointTo(base, recordOffset + 8 + klen, vlen);
keyRow.pointTo(base, recordOffset + 2 * uaoSize, klen);
valueRow.pointTo(base, recordOffset + 2 * uaoSize + klen, vlen);
numRows++;
return valueRow;
}
Expand All @@ -79,7 +81,7 @@ public UnsafeRow getKeyRow(int rowId) {
assert(rowId < numRows);
if (keyRowId != rowId) { // if keyRowId == rowId, desired keyRow is already cached
long offset = keyOffsets[rowId];
int klen = Platform.getInt(base, offset - 4);
int klen = UnsafeAlignedOffset.getSize(base, offset - UnsafeAlignedOffset.getUaoSize());
keyRow.pointTo(base, offset, klen);
// set keyRowId so we can check if desired row is cached
keyRowId = rowId;
Expand All @@ -99,9 +101,10 @@ public UnsafeRow getValueFromKey(int rowId) {
getKeyRow(rowId);
}
assert(rowId >= 0);
int uaoSize = UnsafeAlignedOffset.getUaoSize();
long offset = keyRow.getBaseOffset();
int klen = keyRow.getSizeInBytes();
int vlen = Platform.getInt(base, offset - 8) - klen - 4;
int vlen = UnsafeAlignedOffset.getSize(base, offset - uaoSize * 2) - klen - uaoSize;
valueRow.pointTo(base, offset + klen, vlen);
return valueRow;
}
Expand Down Expand Up @@ -141,14 +144,15 @@ public boolean next() {
return false;
}

totalLength = Platform.getInt(base, offsetInPage) - 4;
currentklen = Platform.getInt(base, offsetInPage + 4);
int uaoSize = UnsafeAlignedOffset.getUaoSize();
totalLength = UnsafeAlignedOffset.getSize(base, offsetInPage) - uaoSize;
currentklen = UnsafeAlignedOffset.getSize(base, offsetInPage + uaoSize);
currentvlen = totalLength - currentklen;

key.pointTo(base, offsetInPage + 8, currentklen);
value.pointTo(base, offsetInPage + 8 + currentklen, currentvlen);
key.pointTo(base, offsetInPage + 2 * uaoSize, currentklen);
value.pointTo(base, offsetInPage + 2 * uaoSize + currentklen, currentvlen);

offsetInPage += 8 + totalLength + 8;
offsetInPage += 2 * uaoSize + totalLength + 8;
recordsInPage -= 1;
return true;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import org.apache.spark.storage.BlockManager;
import org.apache.spark.unsafe.KVIterator;
import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.UnsafeAlignedOffset;
import org.apache.spark.unsafe.array.LongArray;
import org.apache.spark.unsafe.map.BytesToBytesMap;
import org.apache.spark.unsafe.memory.MemoryBlock;
Expand Down Expand Up @@ -141,9 +142,10 @@ public UnsafeKVExternalSorter(

// Get encoded memory address
// baseObject + baseOffset point to the beginning of the key data in the map, but that
// the KV-pair's length data is stored in the word immediately before that address
// the KV-pair's length data is stored at 2 * uaoSize bytes immediately before that address
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sorry I don't get it here why the address is related to uaoSize?

Copy link
Copy Markdown
Member Author

@Ngone51 Ngone51 Apr 15, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The record format is:

(total length) (key length) (key) (value) (8 byte pointer to next value)
      |             |
   uaoSize       uaoSize

And we now get keyOffset, so we need back walk for 2 usao sizes to get the address of the record.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so you mean we need to add back 2 * uaoSize to cover the space of storing total length and key length? If that's the case then it makes sense.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes.

MemoryBlock page = loc.getMemoryPage();
long address = taskMemoryManager.encodePageNumberAndOffset(page, baseOffset - 8);
long address = taskMemoryManager.encodePageNumberAndOffset(page,
baseOffset - 2 * UnsafeAlignedOffset.getUaoSize());

// Compute prefix
row.pointTo(baseObject, baseOffset, loc.getKeyLength());
Expand Down Expand Up @@ -262,10 +264,11 @@ public int compare(
Object baseObj2,
long baseOff2,
int baseLen2) {
int uaoSize = UnsafeAlignedOffset.getUaoSize();
// Note that since ordering doesn't need the total length of the record, we just pass 0
// into the row.
row1.pointTo(baseObj1, baseOff1 + 4, 0);
row2.pointTo(baseObj2, baseOff2 + 4, 0);
row1.pointTo(baseObj1, baseOff1 + uaoSize, 0);
row2.pointTo(baseObj2, baseOff2 + uaoSize, 0);
return ordering.compare(row1, row2);
}
}
Expand All @@ -289,11 +292,12 @@ public boolean next() throws IOException {
long recordOffset = underlying.getBaseOffset();
int recordLen = underlying.getRecordLength();

// Note that recordLen = keyLen + valueLen + 4 bytes (for the keyLen itself)
// Note that recordLen = keyLen + valueLen + uaoSize (for the keyLen itself)
int uaoSize = UnsafeAlignedOffset.getUaoSize();
int keyLen = Platform.getInt(baseObj, recordOffset);
int valueLen = recordLen - keyLen - 4;
key.pointTo(baseObj, recordOffset + 4, keyLen);
value.pointTo(baseObj, recordOffset + 4 + keyLen, valueLen);
int valueLen = recordLen - keyLen - uaoSize;
key.pointTo(baseObj, recordOffset + uaoSize, keyLen);
value.pointTo(baseObj, recordOffset + uaoSize + keyLen, valueLen);

return true;
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import org.apache.spark.sql.hive.test.TestHiveSingleton
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SQLTestUtils
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.UnsafeAlignedOffset


class ScalaAggregateFunction(schema: StructType) extends UserDefinedAggregateFunction {
Expand Down Expand Up @@ -1055,30 +1056,35 @@ class HashAggregationQueryWithControlledFallbackSuite extends AggregationQuerySu
Seq("true", "false").foreach { enableTwoLevelMaps =>
withSQLConf(SQLConf.ENABLE_TWOLEVEL_AGG_MAP.key ->
enableTwoLevelMaps) {
(1 to 3).foreach { fallbackStartsAt =>
withSQLConf("spark.sql.TungstenAggregate.testFallbackStartsAt" ->
s"${(fallbackStartsAt - 1).toString}, ${fallbackStartsAt.toString}") {
// Create a new df to make sure its physical operator picks up
// spark.sql.TungstenAggregate.testFallbackStartsAt.
// todo: remove it?
val newActual = Dataset.ofRows(spark, actual.logicalPlan)

QueryTest.getErrorMessageInCheckAnswer(newActual, expectedAnswer) match {
case Some(errorMessage) =>
val newErrorMessage =
s"""
|The following aggregation query failed when using HashAggregate with
|controlled fallback (it falls back to bytes to bytes map once it has processed
|${fallbackStartsAt - 1} input rows and to sort-based aggregation once it has
|processed $fallbackStartsAt input rows). The query is ${actual.queryExecution}
|
|$errorMessage
""".stripMargin

fail(newErrorMessage)
case None => // Success
Seq(4, 8).foreach { uaoSize =>
UnsafeAlignedOffset.setUaoSize(uaoSize)
(1 to 3).foreach { fallbackStartsAt =>
withSQLConf("spark.sql.TungstenAggregate.testFallbackStartsAt" ->
s"${(fallbackStartsAt - 1).toString}, ${fallbackStartsAt.toString}") {
// Create a new df to make sure its physical operator picks up
// spark.sql.TungstenAggregate.testFallbackStartsAt.
// todo: remove it?
val newActual = Dataset.ofRows(spark, actual.logicalPlan)

QueryTest.getErrorMessageInCheckAnswer(newActual, expectedAnswer) match {
case Some(errorMessage) =>
val newErrorMessage =
s"""
|The following aggregation query failed when using HashAggregate with
|controlled fallback (it falls back to bytes to bytes map once it has
|processed ${fallbackStartsAt - 1} input rows and to sort-based aggregation
|once it has processed $fallbackStartsAt input rows).
|The query is ${actual.queryExecution}
|$errorMessage
""".stripMargin

fail(newErrorMessage)
case None => // Success
}
Copy link
Copy Markdown
Member

@kiszk kiszk Apr 14, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not confident, but do we need to call setUaoSize(0) at the end of this test? This is because TEST_UAO_SIZE is static.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, yeah. I think we have to reset it.

}
}
// reset static uaoSize to avoid affect other tests
UnsafeAlignedOffset.setUaoSize(0)
}
}
}
Expand Down