Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

package org.elasticsearch.gpu.codec;

import org.apache.lucene.store.Directory;
import org.apache.lucene.store.FSDirectory;
import org.apache.lucene.store.IOContext;
import org.apache.lucene.store.IndexInput;
Expand Down Expand Up @@ -132,14 +133,15 @@ public void testGetContiguousMemorySegmentBelowMaxChunkSize() throws IOException
var data = randomByteArrayOfLength(dataSize);

try (FSDirectory dir = new MMapDirectory(createTempDir(), maxChunkSize)) {
Directory wrappedDir = maybeWrapDirectoryInFilterDirectory(dir);
try (IndexOutput out = dir.createOutput("tests.bin7", IOContext.DEFAULT)) {
out.writeBytes(data, 0, dataSize);
}

try (IndexInput in = dir.openInput("tests.bin7", IOContext.DEFAULT)) {

var msai = (MemorySegmentAccessInput) in;
var holder = MemorySegmentUtils.getContiguousMemorySegment(msai, dir, "tests.bin9");
var holder = MemorySegmentUtils.getContiguousMemorySegment(msai, wrappedDir, "tests.bin9");

assertThat(holder, isA(MemorySegmentUtils.DirectMemorySegmentHolder.class));
assertNotNull(holder.memorySegment());
Expand All @@ -155,13 +157,14 @@ public void testGetContiguousMemorySegmentAboveMaxChunkSize() throws IOException
var data = randomByteArrayOfLength(dataSize);

try (FSDirectory dir = new MMapDirectory(createTempDir(), maxChunkSize)) {
Directory wrappedDir = maybeWrapDirectoryInFilterDirectory(dir);
try (IndexOutput out = dir.createOutput("tests.bin10", IOContext.DEFAULT)) {
out.writeBytes(data, 0, dataSize);
}

try (IndexInput in = dir.openInput("tests.bin10", IOContext.DEFAULT)) {
var msai = (MemorySegmentAccessInput) in;
try (var holder = MemorySegmentUtils.getContiguousMemorySegment(msai, dir, "tests.bin12")) {
try (var holder = MemorySegmentUtils.getContiguousMemorySegment(msai, wrappedDir, "tests.bin12")) {

assertThat(holder, isA(MemorySegmentUtils.FileBackedMemorySegmentHolder.class));
assertNotNull(holder.memorySegment());
Expand All @@ -171,4 +174,91 @@ public void testGetContiguousMemorySegmentAboveMaxChunkSize() throws IOException
}
}
}

/**
 * When the padded source data fits inside a single mmap chunk, packing should produce an
 * arena-backed contiguous segment whose rows contain only the packed bytes (padding stripped).
 */
public void testGetContiguousPackedMemorySegmentBelowMaxChunkSize() throws IOException {
    var maxChunkSize = 2 * 1024 * 1024;
    int packedRowSize = randomIntBetween(4, 200);
    int paddingSize = randomIntBetween(1, 20);
    int sourceRowPitch = packedRowSize + paddingSize;
    int numVectors = randomIntBetween(1, 100);
    int dataSize = numVectors * sourceRowPitch;
    assumeTrue("data must fit below max chunk size", dataSize < maxChunkSize);

    var source = randomByteArrayOfLength(dataSize);

    try (FSDirectory dir = new MMapDirectory(createTempDir(), maxChunkSize)) {
        Directory wrappedDir = maybeWrapDirectoryInFilterDirectory(dir);
        try (IndexOutput out = dir.createOutput("tests.bin13", IOContext.DEFAULT)) {
            out.writeBytes(source, 0, dataSize);
        }
        try (IndexInput in = dir.openInput("tests.bin13", IOContext.DEFAULT)) {
            var input = (MemorySegmentAccessInput) in;
            try (
                var segmentHolder = MemorySegmentUtils.getContiguousPackedMemorySegment(
                    input,
                    wrappedDir,
                    "tests.bin14",
                    numVectors,
                    sourceRowPitch,
                    packedRowSize
                )
            ) {
                // Fits in one chunk, so the utility copies into an off-heap arena segment.
                assertThat(segmentHolder, isA(MemorySegmentUtils.ArenaMemorySegmentHolder.class));
                assertNotNull(segmentHolder.memorySegment());
                assertThat(segmentHolder.memorySegment().byteSize(), equalTo((long) numVectors * packedRowSize));

                // Each packed row must equal the leading packedRowSize bytes of its padded source row.
                byte[] actual = segmentHolder.memorySegment().toArray(ValueLayout.JAVA_BYTE);
                for (int row = 0; row < numVectors; row++) {
                    int srcBase = row * sourceRowPitch;
                    int dstBase = row * packedRowSize;
                    for (int col = 0; col < packedRowSize; col++) {
                        assertEquals(source[srcBase + col], actual[dstBase + col]);
                    }
                }
            }
        }
    }
}

/**
 * When the padded source data spans multiple mmap chunks, packing should fall back to a
 * file-backed contiguous segment while still stripping the per-row padding.
 */
public void testGetContiguousPackedMemorySegmentAboveMaxChunkSize() throws IOException {
    var maxChunkSize = 1000;
    int packedRowSize = randomIntBetween(4, 200);
    int paddingSize = randomIntBetween(1, 20);
    int sourceRowPitch = packedRowSize + paddingSize;
    int numVectors = randomIntBetween(1, 100);
    int dataSize = numVectors * sourceRowPitch;
    assumeTrue("data must exceed max chunk size", dataSize > maxChunkSize);

    var source = randomByteArrayOfLength(dataSize);

    try (FSDirectory dir = new MMapDirectory(createTempDir(), maxChunkSize)) {
        Directory wrappedDir = maybeWrapDirectoryInFilterDirectory(dir);
        try (IndexOutput out = dir.createOutput("tests.bin15", IOContext.DEFAULT)) {
            out.writeBytes(source, 0, dataSize);
        }

        try (IndexInput in = dir.openInput("tests.bin15", IOContext.DEFAULT)) {
            var input = (MemorySegmentAccessInput) in;
            try (
                var segmentHolder = MemorySegmentUtils.getContiguousPackedMemorySegment(
                    input,
                    wrappedDir,
                    "tests.bin16",
                    numVectors,
                    sourceRowPitch,
                    packedRowSize
                )
            ) {
                // Too large for one chunk, so the utility materializes a file-backed segment.
                assertThat(segmentHolder, isA(MemorySegmentUtils.FileBackedMemorySegmentHolder.class));
                assertNotNull(segmentHolder.memorySegment());
                assertThat(segmentHolder.memorySegment().byteSize(), equalTo((long) numVectors * packedRowSize));

                // Each packed row must equal the leading packedRowSize bytes of its padded source row.
                byte[] actual = segmentHolder.memorySegment().toArray(ValueLayout.JAVA_BYTE);
                for (int row = 0; row < numVectors; row++) {
                    int srcBase = row * sourceRowPitch;
                    int dstBase = row * packedRowSize;
                    for (int col = 0; col < packedRowSize; col++) {
                        assertEquals(source[srcBase + col], actual[dstBase + col]);
                    }
                }
            }
        }
    }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@
import org.apache.logging.log4j.status.StatusLogger;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.FilterDirectory;
import org.apache.lucene.tests.util.LuceneTestCase;
import org.apache.lucene.tests.util.LuceneTestCase.SuppressCodecs;
import org.apache.lucene.tests.util.TestRuleMarkFailure;
Expand Down Expand Up @@ -3207,6 +3209,20 @@ public static BytesRef embedInRandomBytes(BytesRef bytesRef) {
return new BytesRef(newBytesArray, offset, bytesRef.length);
}

/**
 * Randomly layers zero to three no-op {@link FilterDirectory} wrappers around the given
 * {@link Directory}, mimicking how Elasticsearch stacks directory wrappers in production
 * (e.g. {@code Store.StoreDirectory -> ByteSizeCachingDirectory -> MMapDirectory}).
 * Use this when exercising code that accepts a {@link Directory} and must tolerate wrapping.
 */
public static Directory maybeWrapDirectoryInFilterDirectory(Directory dir) {
    Directory result = dir;
    // Anonymous subclass: FilterDirectory is abstract but adds no behavior here — each
    // layer simply delegates, which is exactly the pass-through wrapping we want to simulate.
    for (int remaining = randomIntBetween(0, 3); remaining > 0; remaining--) {
        result = new FilterDirectory(result) {
        };
    }
    return result;
}

private static boolean previousFailureSkipsRemaining;
@Rule
public final TestWatcher previousFailureSkipsRemainingRule = new TestWatcher() {
Expand Down
Loading