Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,11 @@
import org.apache.lucene.store.IOContext;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.IndexOutput;
import org.apache.lucene.util.LongValues;
import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.hnsw.IntToIntFunction;
import org.apache.lucene.util.packed.PackedInts;
import org.apache.lucene.util.packed.PackedLongValues;
import org.elasticsearch.index.codec.vectors.cluster.HierarchicalKMeans;
import org.elasticsearch.index.codec.vectors.cluster.KMeansResult;
import org.elasticsearch.logging.LogManager;
Expand Down Expand Up @@ -46,7 +49,7 @@ public DefaultIVFVectorsWriter(SegmentWriteState state, FlatVectorsWriter rawVec
}

@Override
long[] buildAndWritePostingsLists(
LongValues buildAndWritePostingsLists(
FieldInfo fieldInfo,
CentroidSupplier centroidSupplier,
FloatVectorValues floatVectorValues,
Expand Down Expand Up @@ -81,7 +84,7 @@ long[] buildAndWritePostingsLists(
}
}
// write the posting lists
final long[] offsets = new long[centroidSupplier.size()];
final PackedLongValues.Builder offsets = PackedLongValues.monotonicBuilder(PackedInts.COMPACT);
DocIdsWriter docIdsWriter = new DocIdsWriter();
DiskBBQBulkWriter bulkWriter = new DiskBBQBulkWriter.OneBitDiskBBQBulkWriter(ES91OSQVectorsScorer.BULK_SIZE, postingsOutput);
OnHeapQuantizedVectors onHeapQuantizedVectors = new OnHeapQuantizedVectors(
Expand All @@ -93,7 +96,7 @@ long[] buildAndWritePostingsLists(
float[] centroid = centroidSupplier.centroid(c);
int[] cluster = assignmentsByCluster[c];
// TODO align???
offsets[c] = postingsOutput.getFilePointer();
offsets.add(postingsOutput.getFilePointer());
int size = cluster.length;
postingsOutput.writeVInt(size);
postingsOutput.writeInt(Float.floatToIntBits(VectorUtil.dotProduct(centroid, centroid)));
Expand All @@ -109,11 +112,11 @@ long[] buildAndWritePostingsLists(
printClusterQualityStatistics(assignmentsByCluster);
}

return offsets;
return offsets.build();
}

@Override
long[] buildAndWritePostingsLists(
LongValues buildAndWritePostingsLists(
FieldInfo fieldInfo,
CentroidSupplier centroidSupplier,
FloatVectorValues floatVectorValues,
Expand Down Expand Up @@ -199,7 +202,7 @@ long[] buildAndWritePostingsLists(
}
// now we can read the quantized vectors from the temporary file
try (IndexInput quantizedVectorsInput = mergeState.segmentInfo.dir.openInput(quantizedVectorsTempName, IOContext.DEFAULT)) {
final long[] offsets = new long[centroidSupplier.size()];
final PackedLongValues.Builder offsets = PackedLongValues.monotonicBuilder(PackedInts.COMPACT);
OffHeapQuantizedVectors offHeapQuantizedVectors = new OffHeapQuantizedVectors(
quantizedVectorsInput,
fieldInfo.getVectorDimension()
Expand All @@ -210,9 +213,9 @@ long[] buildAndWritePostingsLists(
float[] centroid = centroidSupplier.centroid(c);
int[] cluster = assignmentsByCluster[c];
boolean[] isOverspill = isOverspillByCluster[c];
// TODO align???
offsets[c] = postingsOutput.getFilePointer();
offsets.add(postingsOutput.getFilePointer());
int size = cluster.length;
// TODO align???
postingsOutput.writeVInt(size);
postingsOutput.writeInt(Float.floatToIntBits(VectorUtil.dotProduct(centroid, centroid)));
offHeapQuantizedVectors.reset(size, ord -> isOverspill[ord], ord -> cluster[ord]);
Expand All @@ -226,7 +229,7 @@ long[] buildAndWritePostingsLists(
if (logger.isDebugEnabled()) {
printClusterQualityStatistics(assignmentsByCluster);
}
return offsets;
return offsets.build();
}
}

Expand Down Expand Up @@ -270,7 +273,7 @@ void writeCentroids(
FieldInfo fieldInfo,
CentroidSupplier centroidSupplier,
float[] globalCentroid,
long[] offsets,
LongValues offsets,
IndexOutput centroidOutput
) throws IOException {

Expand Down Expand Up @@ -302,7 +305,7 @@ void writeCentroids(
// write the centroids
centroidOutput.writeBytes(buffer.array(), buffer.array().length);
// write the offset of this posting list
centroidOutput.writeLong(offsets[i]);
centroidOutput.writeLong(offsets.get(i));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.IndexOutput;
import org.apache.lucene.store.RandomAccessInput;
import org.apache.lucene.util.LongValues;
import org.apache.lucene.util.VectorUtil;
import org.elasticsearch.core.IOUtils;
import org.elasticsearch.core.SuppressForbidden;
Expand Down Expand Up @@ -126,11 +127,11 @@ abstract void writeCentroids(
FieldInfo fieldInfo,
CentroidSupplier centroidSupplier,
float[] globalCentroid,
long[] centroidOffset,
LongValues centroidOffset,
IndexOutput centroidOutput
) throws IOException;

abstract long[] buildAndWritePostingsLists(
abstract LongValues buildAndWritePostingsLists(
FieldInfo fieldInfo,
CentroidSupplier centroidSupplier,
FloatVectorValues floatVectorValues,
Expand All @@ -139,7 +140,7 @@ abstract long[] buildAndWritePostingsLists(
int[] overspillAssignments
) throws IOException;

abstract long[] buildAndWritePostingsLists(
abstract LongValues buildAndWritePostingsLists(
FieldInfo fieldInfo,
CentroidSupplier centroidSupplier,
FloatVectorValues floatVectorValues,
Expand All @@ -160,25 +161,24 @@ abstract CentroidSupplier createCentroidSupplier(
public final void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException {
rawVectorDelegate.flush(maxDoc, sortMap);
for (FieldWriter fieldWriter : fieldWriters) {
float[] globalCentroid = new float[fieldWriter.fieldInfo.getVectorDimension()];
final float[] globalCentroid = new float[fieldWriter.fieldInfo.getVectorDimension()];
// build a float vector values with random access
final FloatVectorValues floatVectorValues = getFloatVectorValues(fieldWriter.fieldInfo, fieldWriter.delegate, maxDoc);
// build centroids
final CentroidAssignments centroidAssignments = calculateCentroids(fieldWriter.fieldInfo, floatVectorValues, globalCentroid);
// wrap centroids with a supplier
final CentroidSupplier centroidSupplier = new OnHeapCentroidSupplier(centroidAssignments.centroids());
// write posting lists
final long[] offsets = buildAndWritePostingsLists(
final LongValues offsets = buildAndWritePostingsLists(
fieldWriter.fieldInfo,
centroidSupplier,
floatVectorValues,
ivfClusters,
centroidAssignments.assignments(),
centroidAssignments.overspillAssignments()
);
assert offsets.length == centroidSupplier.size();
final long centroidOffset = ivfCentroids.alignFilePointer(Float.BYTES);
// write centroids
final long centroidOffset = ivfCentroids.alignFilePointer(Float.BYTES);
writeCentroids(fieldWriter.fieldInfo, centroidSupplier, globalCentroid, offsets, ivfCentroids);
final long centroidLength = ivfCentroids.getFilePointer() - centroidOffset;
// write meta file
Expand Down Expand Up @@ -338,7 +338,7 @@ private void mergeOneFieldIVF(FieldInfo fieldInfo, MergeState mergeState) throws
calculatedGlobalCentroid
);
// write posting lists
final long[] offsets = buildAndWritePostingsLists(
final LongValues offsets = buildAndWritePostingsLists(
fieldInfo,
centroidSupplier,
floatVectorValues,
Expand All @@ -347,9 +347,8 @@ private void mergeOneFieldIVF(FieldInfo fieldInfo, MergeState mergeState) throws
assignments,
overspillAssignments
);
assert offsets.length == centroidSupplier.size();
centroidOffset = ivfCentroids.alignFilePointer(Float.BYTES);
// write centroids
centroidOffset = ivfCentroids.alignFilePointer(Float.BYTES);
writeCentroids(fieldInfo, centroidSupplier, calculatedGlobalCentroid, offsets, ivfCentroids);
centroidLength = ivfCentroids.getFilePointer() - centroidOffset;
// write meta
Expand Down