Skip to content

Commit

Permalink
Use iterator that can advance, add random value unit test
Browse files Browse the repository at this point in the history
Signed-off-by: bowenlan-amzn <[email protected]>
  • Loading branch information
bowenlan-amzn committed Jan 7, 2025
1 parent dbc7b5e commit 941a5c2
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 39 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@
import org.opensearch.search.DocValueFormat;
import org.opensearch.search.lookup.SearchLookup;
import org.opensearch.search.query.BitmapDocValuesQuery;
import org.opensearch.search.query.BitmapIndexQuery;

import java.io.IOException;
import java.math.BigInteger;
Expand All @@ -97,7 +98,6 @@
import java.util.function.Function;
import java.util.function.Supplier;

import org.opensearch.search.query.BitmapIndexQuery;
import org.roaringbitmap.RoaringBitmap;

/**
Expand Down Expand Up @@ -1555,6 +1555,7 @@ public Scorer get(long leadCost) throws IOException {
final BytesRef encoded = new BytesRef(new byte[Integer.BYTES]);
Query query = new PointInSetQuery(field, 1, Integer.BYTES, new PointInSetQuery.Stream() {
final Iterator<Integer> iterator = bitmap.iterator();

@Override
public BytesRef next() {
int value;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,13 @@
import org.apache.lucene.util.BytesRefIterator;
import org.apache.lucene.util.DocIdSetBuilder;
import org.apache.lucene.util.RamUsageEstimator;
import org.roaringbitmap.RoaringBitmap;

import java.io.IOException;
import java.util.Iterator;
import java.util.Objects;

import org.roaringbitmap.PeekableIntIterator;
import org.roaringbitmap.RoaringBitmap;

/**
* A query that matches all documents that contain a set of integer numbers represented by bitmap
*
Expand All @@ -50,12 +51,19 @@ public BitmapIndexQuery(String field, RoaringBitmap bitmap) {
this.field = field;
}

private static BytesRefIterator bitmapEncodedIterator(RoaringBitmap bitmap) {
return new BytesRefIterator() {
private final Iterator<Integer> iterator = bitmap.iterator();
/**
 * Iterator over the values of a bitmap, handed out as encoded {@link BytesRef} points,
 * that can additionally skip forward to a target value.
 */
interface BitmapIterator extends BytesRefIterator {
/**
 * Wraps {@code IntIterator.next()} and returns the next bitmap value in encoded form.
 * Callers in this file invoke it without handling a checked exception.
 * NOTE(review): callers loop until the result is {@code null}, so exhaustion is presumably
 * signalled by {@code null} - confirm against the implementation.
 */
BytesRef next();

/**
 * Exposes {@code PeekableIntIterator.advanceIfNeeded}: advances as long as the next value
 * is smaller than the target.
 *
 * @param target encoded point value to advance to
 */
void advance(byte[] target);
}

private static BitmapIterator bitmapEncodedIterator(RoaringBitmap bitmap) {
return new BitmapIterator() {
private final PeekableIntIterator iterator = bitmap.getIntIterator();
private final BytesRef encoded = new BytesRef(new byte[Integer.BYTES]);

@Override
public BytesRef next() {
int value;
if (iterator.hasNext()) {
Expand All @@ -66,6 +74,10 @@ public BytesRef next() {
IntPoint.encodeDimension(value, encoded.bytes, 0);
return encoded;
}

/**
 * Skips bitmap values that are smaller than the given encoded point value.
 * The target is decoded to an int with {@code IntPoint.decodeDimension} and passed to
 * {@code PeekableIntIterator.advanceIfNeeded}, so a following {@code next()} call resumes
 * from the position the underlying iterator was advanced to.
 *
 * @param target encoded point value (read from offset 0) to advance to
 */
@Override
public void advance(byte[] target) {
    iterator.advanceIfNeeded(IntPoint.decodeDimension(target, 0));
}
};
}

Expand All @@ -85,8 +97,7 @@ public Scorer scorer(LeafReaderContext context) throws IOException {
public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOException {
final Weight weight = this;
LeafReader reader = context.reader();
// get point value
// only works for one dimension
// get the point values, which should have a single dimension since the bitmap stores integers
PointValues values = reader.getPointValues(field);
if (values == null) {
return null;
Expand Down Expand Up @@ -118,21 +129,20 @@ public long cost() {

@Override
public boolean isCacheable(LeafReaderContext ctx) {
// Always cacheable: this query depends only on points, a segment-immutable structure.
return true;
}
};
}

private class MergePointVisitor implements PointValues.IntersectVisitor {
private final DocIdSetBuilder result;
private final BytesRefIterator iterator;
private final BitmapIterator iterator;
private BytesRef nextQueryPoint;
private final ArrayUtil.ByteArrayComparator comparator;
private DocIdSetBuilder.BulkAdder adder;

public MergePointVisitor(DocIdSetBuilder result)
throws IOException {
public MergePointVisitor(DocIdSetBuilder result) throws IOException {
this.result = result;
this.comparator = ArrayUtil.getUnsignedComparator(Integer.BYTES);
this.iterator = bitmapEncodedIterator(bitmap);
Expand Down Expand Up @@ -175,11 +185,8 @@ private boolean matches(byte[] packedValue) {
return true;
} else if (cmp < 0) {
// Query point is before index point, so we move to next query point
try {
nextQueryPoint = iterator.next();
} catch (IOException e) {
throw new RuntimeException(e);
}
iterator.advance(packedValue);
nextQueryPoint = iterator.next();
} else {
// Query point is after index point, so we don't collect and we return:
break;
Expand All @@ -191,19 +198,14 @@ private boolean matches(byte[] packedValue) {
@Override
public PointValues.Relation compare(byte[] minPackedValue, byte[] maxPackedValue) {
while (nextQueryPoint != null) {
int cmpMin =
comparator.compare(nextQueryPoint.bytes, nextQueryPoint.offset, minPackedValue, 0);
int cmpMin = comparator.compare(nextQueryPoint.bytes, nextQueryPoint.offset, minPackedValue, 0);
if (cmpMin < 0) {
// query point is before the start of this cell
try {
nextQueryPoint = iterator.next();
} catch (IOException e) {
throw new RuntimeException(e);
}
iterator.advance(minPackedValue);
nextQueryPoint = iterator.next();
continue;
}
int cmpMax =
comparator.compare(nextQueryPoint.bytes, nextQueryPoint.offset, maxPackedValue, 0);
int cmpMax = comparator.compare(nextQueryPoint.bytes, nextQueryPoint.offset, maxPackedValue, 0);
if (cmpMax > 0) {
// query point is after the end of this cell
return PointValues.Relation.CELL_OUTSIDE_QUERY;
Expand Down Expand Up @@ -260,7 +262,7 @@ public int hashCode() {

/**
 * Estimates the memory used by this query.
 *
 * @return the shallow size of a {@code BitmapIndexQuery} instance plus the estimated
 *         sizes of the field name and of the bitmap
 */
@Override
public long ramBytesUsed() {
    // A single return: the block previously held two identical consecutive return statements,
    // the second of which was unreachable.
    return RamUsageEstimator.shallowSizeOfInstance(BitmapIndexQuery.class)
        + RamUsageEstimator.sizeOfObject(field)
        + RamUsageEstimator.sizeOfObject(bitmap);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,9 @@
import org.apache.lucene.document.Field;
import org.apache.lucene.document.IntField;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.DocValues;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.SortedNumericDocValues;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.Weight;
import org.apache.lucene.store.Directory;
import org.opensearch.test.OpenSearchTestCase;
Expand All @@ -28,7 +23,6 @@

import java.io.IOException;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Set;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,17 +23,22 @@
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.Weight;
import org.apache.lucene.store.Directory;
import org.opensearch.common.Randomness;
import org.opensearch.test.OpenSearchTestCase;
import org.junit.After;
import org.junit.Before;
import org.opensearch.test.OpenSearchTestCase;
import org.roaringbitmap.RoaringBitmap;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Random;
import java.util.Set;

import org.roaringbitmap.RoaringBitmap;

public class BitmapIndexQueryTests extends OpenSearchTestCase {
private Directory dir;
private IndexWriter w;
Expand Down Expand Up @@ -86,11 +91,11 @@ public void testScore() throws IOException {
assertEquals(expected, actual);
}

// use doc values to get the actual value of the matching docs
// cannot directly check the docId because test can randomize segment numbers
static List<Integer> getMatchingValues(Weight weight, IndexReader reader) throws IOException {
List<Integer> actual = new LinkedList<>();
for (LeafReaderContext leaf : reader.leaves()) {
// use doc values to get the actual value of the matching docs and assert
// cannot directly check the docId because test can randomize segment numbers
SortedNumericDocValues dv = DocValues.getSortedNumeric(leaf.reader(), "product_id");
Scorer scorer = weight.scorer(leaf);
DocIdSetIterator disi = scorer.iterator();
Expand Down Expand Up @@ -138,4 +143,40 @@ public void testScoreMutilValues() throws IOException {
assertEquals(expected, actual);
}

/**
 * Indexes one document per integer in {@code [0, 10_000]}, queries with a bitmap built from
 * five randomly drawn values (duplicates collapse), and checks that the matching documents
 * carry exactly the queried values.
 */
public void testRandomDocumentsAndQueries() throws IOException {
    final Random rnd = Randomness.get();
    final int maxValue = 10_000; // the range of query values should be within indexed values
    final int draws = 5;

    // Index every value from 0 up to and including maxValue, one document each.
    int docValue = 0;
    while (docValue <= maxValue) {
        Document doc = new Document();
        doc.add(new IntField("product_id", docValue, Field.Store.NO));
        w.addDocument(doc);
        docValue++;
    }

    w.commit();
    reader = DirectoryReader.open(w);
    searcher = newSearcher(reader);

    // Draw the random query values; each lies in [1, maxValue].
    Set<Integer> wanted = new HashSet<>();
    for (int remaining = draws; remaining > 0; remaining--) {
        wanted.add(rnd.nextInt(maxValue) + 1);
    }
    int[] wantedValues = wanted.stream().mapToInt(Integer::intValue).toArray();
    RoaringBitmap bitmap = new RoaringBitmap();
    bitmap.add(wantedValues);

    Weight weight = searcher.createWeight(
        searcher.rewrite(new BitmapIndexQuery("product_id", bitmap)),
        ScoreMode.COMPLETE_NO_SCORES,
        1f
    );
    Set<Integer> matched = new HashSet<>(getMatchingValues(weight, searcher.getIndexReader()));

    // Compare as sorted lists so the assertion is independent of set iteration order.
    List<Integer> expected = new ArrayList<>(wanted);
    List<Integer> actual = new ArrayList<>(matched);
    Collections.sort(expected);
    Collections.sort(actual);
    assertEquals(expected, actual);
}

}

0 comments on commit 941a5c2

Please sign in to comment.