|
11 | 11 |
|
12 | 12 | import org.apache.lucene.codecs.hnsw.FlatVectorsReader; |
13 | 13 | import org.apache.lucene.index.FieldInfo; |
14 | | -import org.apache.lucene.index.FloatVectorValues; |
15 | 14 | import org.apache.lucene.index.SegmentReadState; |
16 | 15 | import org.apache.lucene.index.VectorSimilarityFunction; |
17 | 16 | import org.apache.lucene.search.KnnCollector; |
|
20 | 19 | import org.apache.lucene.util.VectorUtil; |
21 | 20 | import org.apache.lucene.util.hnsw.NeighborQueue; |
22 | 21 | import org.apache.lucene.util.quantization.OptimizedScalarQuantizer; |
| 22 | +import org.elasticsearch.index.codec.vectors.reflect.OffHeapStats; |
23 | 23 | import org.elasticsearch.simdvec.ES91OSQVectorsScorer; |
24 | 24 | import org.elasticsearch.simdvec.ESVectorUtil; |
25 | 25 |
|
26 | 26 | import java.io.IOException; |
| 27 | +import java.util.Map; |
27 | 28 | import java.util.function.IntPredicate; |
28 | 29 |
|
29 | 30 | import static org.apache.lucene.codecs.lucene102.Lucene102BinaryQuantizedVectorsFormat.QUERY_BITS; |
|
38 | 39 | * Default implementation of {@link IVFVectorsReader}. It scores the posting lists centroids using |
39 | 40 | * brute force and then scores the top ones using the posting list. |
40 | 41 | */ |
41 | | -public class DefaultIVFVectorsReader extends IVFVectorsReader { |
| 42 | +public class DefaultIVFVectorsReader extends IVFVectorsReader implements OffHeapStats { |
42 | 43 | private static final float FOUR_BIT_SCALE = 1f / ((1 << 4) - 1); |
43 | 44 |
|
44 | 45 | public DefaultIVFVectorsReader(SegmentReadState state, FlatVectorsReader rawVectorsReader) throws IOException { |
@@ -163,57 +164,9 @@ static float int4QuantizedScore( |
163 | 164 | } |
164 | 165 | } |
165 | 166 |
|
166 | | - static class OffHeapCentroidFloatVectorValues extends FloatVectorValues { |
167 | | - private final int numCentroids; |
168 | | - private final IndexInput input; |
169 | | - private final int dimension; |
170 | | - private final float[] centroid; |
171 | | - private final long centroidByteSize; |
172 | | - private int ord = -1; |
173 | | - |
174 | | - OffHeapCentroidFloatVectorValues(int numCentroids, IndexInput input, int dimension) { |
175 | | - this.numCentroids = numCentroids; |
176 | | - this.input = input; |
177 | | - this.dimension = dimension; |
178 | | - this.centroid = new float[dimension]; |
179 | | - this.centroidByteSize = dimension + 3 * Float.BYTES + Short.BYTES; |
180 | | - } |
181 | | - |
182 | | - @Override |
183 | | - public float[] vectorValue(int ord) throws IOException { |
184 | | - if (ord < 0 || ord >= numCentroids) { |
185 | | - throw new IllegalArgumentException("ord must be in [0, " + numCentroids + "]"); |
186 | | - } |
187 | | - if (ord == this.ord) { |
188 | | - return centroid; |
189 | | - } |
190 | | - readQuantizedCentroid(ord); |
191 | | - return centroid; |
192 | | - } |
193 | | - |
194 | | - private void readQuantizedCentroid(int centroidOrdinal) throws IOException { |
195 | | - if (centroidOrdinal == ord) { |
196 | | - return; |
197 | | - } |
198 | | - input.seek(numCentroids * centroidByteSize + (long) Float.BYTES * dimension * centroidOrdinal); |
199 | | - input.readFloats(centroid, 0, centroid.length); |
200 | | - ord = centroidOrdinal; |
201 | | - } |
202 | | - |
203 | | - @Override |
204 | | - public int dimension() { |
205 | | - return dimension; |
206 | | - } |
207 | | - |
208 | | - @Override |
209 | | - public int size() { |
210 | | - return numCentroids; |
211 | | - } |
212 | | - |
213 | | - @Override |
214 | | - public FloatVectorValues copy() throws IOException { |
215 | | - return new OffHeapCentroidFloatVectorValues(numCentroids, input.clone(), dimension); |
216 | | - } |
| 167 | + @Override |
| 168 | + public Map<String, Long> getOffHeapByteSize(FieldInfo fieldInfo) { |
| 169 | + return Map.of(); |
217 | 170 | } |
218 | 171 |
|
219 | 172 | private static class MemorySegmentPostingsVisitor implements PostingVisitor { |
|
0 commit comments