Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/changelog/141872.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
area: Vector Search
issues:
- 141746
pr: 141872
summary: "[GPU] Handle segments too big for MSAI segment access"
type: bug
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,7 @@

import com.nvidia.cuvs.CuVSMatrix;

import org.apache.lucene.store.MemorySegmentAccessInput;

import java.io.IOException;
import java.lang.foreign.MemorySegment;

interface DatasetUtils {

Expand All @@ -22,18 +20,5 @@ static DatasetUtils getInstance() {
}

/** Returns a Dataset over the vectors of type {@code dataType} in the input. */
CuVSMatrix fromInput(MemorySegmentAccessInput input, int numVectors, int dims, CuVSMatrix.DataType dataType) throws IOException;

CuVSMatrix fromInput(
MemorySegmentAccessInput input,
int numVectors,
int dims,
int rowStride,
int columnStride,
CuVSMatrix.DataType dataType
) throws IOException;

/** Returns a Dataset over an input slice */
CuVSMatrix fromSlice(MemorySegmentAccessInput input, long pos, long len, int numVectors, int dims, CuVSMatrix.DataType dataType)
throws IOException;
CuVSMatrix fromInput(MemorySegment input, int numVectors, int dims, CuVSMatrix.DataType dataType);
Copy link
Copy Markdown
Contributor

@mayya-sharipova mayya-sharipova Feb 9, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

May be not related to this PR, do we need DatasetUtils class at all? Looks like we only have a single DatasetUtilsImpl now.

}
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,6 @@
import com.nvidia.cuvs.CuVSMatrix;
import com.nvidia.cuvs.spi.CuVSProvider;

import org.apache.lucene.store.MemorySegmentAccessInput;

import java.io.IOException;
import java.lang.foreign.MemorySegment;
import java.lang.invoke.MethodHandle;

Expand All @@ -23,7 +20,6 @@ class DatasetUtilsImpl implements DatasetUtils {
private static final DatasetUtils INSTANCE = new DatasetUtilsImpl();

private static final MethodHandle createDataset$mh = CuVSProvider.provider().newNativeMatrixBuilder();
private static final MethodHandle createDatasetWithStrides$mh = CuVSProvider.provider().newNativeMatrixBuilderWithStrides();

static DatasetUtils getInstance() {
return INSTANCE;
Expand All @@ -43,95 +39,21 @@ static CuVSMatrix fromMemorySegment(MemorySegment memorySegment, int size, int d
}
}

static CuVSMatrix fromMemorySegment(
MemorySegment memorySegment,
int size,
int dimensions,
int rowStride,
int columnStride,
CuVSMatrix.DataType dataType
) {
try {
return (CuVSMatrix) createDatasetWithStrides$mh.invokeExact(memorySegment, size, dimensions, rowStride, columnStride, dataType);
} catch (Throwable e) {
if (e instanceof Error err) {
throw err;
} else if (e instanceof RuntimeException re) {
throw re;
} else {
throw new RuntimeException(e);
}
}
}

private DatasetUtilsImpl() {}

@Override
public CuVSMatrix fromInput(MemorySegmentAccessInput input, int numVectors, int dims, CuVSMatrix.DataType dataType) throws IOException {
if (numVectors < 0 || dims < 0) {
throwIllegalArgumentException(numVectors, dims);
public CuVSMatrix fromInput(MemorySegment input, int numVectors, int dims, CuVSMatrix.DataType dataType) {
if (input == null) {
throw new IllegalArgumentException("input cannot be null");
}
return createCuVSMatrix(input, 0L, input.length(), numVectors, dims, dataType);
}

@Override
public CuVSMatrix fromInput(
MemorySegmentAccessInput input,
int numVectors,
int dims,
int rowStride,
int columnStride,
CuVSMatrix.DataType dataType
) throws IOException {
if (numVectors < 0 || dims < 0) {
throwIllegalArgumentException(numVectors, dims);
}
return createCuVSMatrix(input, 0L, input.length(), numVectors, dims, rowStride, columnStride, dataType);
}

@Override
public CuVSMatrix fromSlice(MemorySegmentAccessInput input, long pos, long len, int numVectors, int dims, CuVSMatrix.DataType dataType)
throws IOException {
if (pos < 0 || len < 0) {
throw new IllegalArgumentException("pos and len must be positive");
}
return createCuVSMatrix(input, pos, len, numVectors, dims, dataType);
}

private static CuVSMatrix createCuVSMatrix(
MemorySegmentAccessInput input,
long pos,
long len,
int numVectors,
int dims,
CuVSMatrix.DataType dataType
) throws IOException {
MemorySegment ms = input.segmentSliceOrNull(pos, len);
assert ms != null; // TODO: this can be null if larger than 16GB or ...
final int byteSize = dataType == CuVSMatrix.DataType.FLOAT ? Float.BYTES : Byte.BYTES;
if (((long) numVectors * dims * byteSize) > ms.byteSize()) {
throwIllegalArgumentException(ms, numVectors, dims);
}
return fromMemorySegment(ms, numVectors, dims, dataType);
}

private static CuVSMatrix createCuVSMatrix(
MemorySegmentAccessInput input,
long pos,
long len,
int numVectors,
int dims,
int rowStride,
int columnStride,
CuVSMatrix.DataType dataType
) throws IOException {
MemorySegment ms = input.segmentSliceOrNull(pos, len);
assert ms != null;
final int byteSize = dataType == CuVSMatrix.DataType.FLOAT ? Float.BYTES : Byte.BYTES;
if (((long) numVectors * rowStride * byteSize) > ms.byteSize()) {
throwIllegalArgumentException(ms, numVectors, dims);
if (((long) numVectors * dims * byteSize) > input.byteSize()) {
throwIllegalArgumentException(input, numVectors, dims);
}
return fromMemorySegment(ms, numVectors, dims, rowStride, columnStride, dataType);
return fromMemorySegment(input, numVectors, dims, dataType);
}

static void throwIllegalArgumentException(MemorySegment ms, int numVectors, int dims) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,6 @@
import org.elasticsearch.logging.Logger;

import java.io.IOException;
import java.lang.foreign.Arena;
import java.lang.foreign.MemorySegment;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
Expand All @@ -64,6 +62,8 @@
import static org.elasticsearch.gpu.codec.ES92GpuHnswVectorsFormat.LUCENE99_HNSW_VECTOR_INDEX_EXTENSION;
import static org.elasticsearch.gpu.codec.ES92GpuHnswVectorsFormat.LUCENE99_VERSION_CURRENT;
import static org.elasticsearch.gpu.codec.ES92GpuHnswVectorsFormat.MIN_NUM_VECTORS_FOR_GPU_BUILD;
import static org.elasticsearch.gpu.codec.MemorySegmentUtils.getContiguousMemorySegment;
import static org.elasticsearch.gpu.codec.MemorySegmentUtils.getContiguousPackedMemorySegment;

/**
* Writer that builds an Nvidia Carga Graph on GPU and then writes it into the Lucene99 HNSW format,
Expand Down Expand Up @@ -595,31 +595,27 @@ private void mergeByteVectorField(
// TODO: revert to directly pass data mapped with DatasetUtils.getInstance() to generateGpuGraphAndWriteMeta
// when cuvs has fixed this problem
int packedRowSize = fieldInfo.getVectorDimension();
long packedVectorsDataSize = (long) numVectors * packedRowSize;

try (var arena = Arena.ofConfined()) {
var packedSegment = arena.allocate(packedVectorsDataSize, 64);
MemorySegment sourceSegment = memorySegmentAccessInput.segmentSliceOrNull(0, memorySegmentAccessInput.length());

for (int i = 0; i < numVectors; i++) {
MemorySegment.copy(
sourceSegment,
(long) i * sourceRowPitch,
packedSegment,
(long) i * packedRowSize,
packedRowSize
);
}

try (
var dataset = DatasetUtilsImpl.fromMemorySegment(packedSegment, numVectors, packedRowSize, dataType);
var resourcesHolder = new ResourcesHolder(
cuVSResourceManager,
cuVSResourceManager.acquire(numVectors, fieldInfo.getVectorDimension(), dataType, cagraIndexParams)
)
) {
generateGpuGraphAndWriteMeta(resourcesHolder, fieldInfo, dataset, cagraIndexParams);
}
try (
var packedSegmentHolder = getContiguousPackedMemorySegment(
memorySegmentAccessInput,
mergeState.segmentInfo.dir,
mergeState.segmentInfo.name,
numVectors,
sourceRowPitch,
packedRowSize
);
var dataset = DatasetUtilsImpl.fromMemorySegment(
packedSegmentHolder.memorySegment(),
numVectors,
packedRowSize,
dataType
);
var resourcesHolder = new ResourcesHolder(
cuVSResourceManager,
cuVSResourceManager.acquire(numVectors, fieldInfo.getVectorDimension(), dataType, cagraIndexParams)
)
) {
generateGpuGraphAndWriteMeta(resourcesHolder, fieldInfo, dataset, cagraIndexParams);
}
} else {
logger.info(
Expand Down Expand Up @@ -690,10 +686,15 @@ private void mergeFloatVectorField(
IndexInput slice = vectorValues.getSlice();
var input = FilterIndexInput.unwrapOnlyTest(slice);
if (input instanceof MemorySegmentAccessInput memorySegmentAccessInput) {
// Direct access to mmapped file
// Fast path, possible direct access to mmapped file
try (
var memorySegmentHolder = getContiguousMemorySegment(
memorySegmentAccessInput,
mergeState.segmentInfo.dir,
mergeState.segmentInfo.name
);
var dataset = DatasetUtils.getInstance()
.fromInput(memorySegmentAccessInput, numVectors, fieldInfo.getVectorDimension(), dataType);
.fromInput(memorySegmentHolder.memorySegment(), numVectors, fieldInfo.getVectorDimension(), dataType);
var resourcesHolder = new ResourcesHolder(
cuVSResourceManager,
cuVSResourceManager.acquire(numVectors, fieldInfo.getVectorDimension(), dataType, cagraIndexParams)
Expand Down
Loading