Commit 5fd0b5a
Fix _all boosting.
_all boosting used to rely on the TokenStream not eagerly consuming the input java.io.Reader, an assumption that eager tokenizers such as the whitespace tokenizer violate. This fixes the issue by recording each entry's start offset and using binary search to find the right boost for a token's start offset. Close elastic#4315
Parent: 95889b4

3 files changed: +93 −13 lines
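
Before the per-file diffs, here is a minimal standalone sketch of the technique, with hypothetical class and method names (BoostLookup is not part of the patch): each value appended to _all records the start offset where its text begins in the concatenated stream, and a token's boost is recovered by binary-searching for the greatest recorded start offset that is less than or equal to the token's start offset.

    import java.util.ArrayList;
    import java.util.List;

    // Hypothetical minimal model of the fix, not the actual patch.
    class BoostLookup {
        private final List<Integer> startOffsets = new ArrayList<>();
        private final List<Float> boosts = new ArrayList<>();
        private int nextOffset = 0;

        // Register one field value; the next value starts after this text
        // plus one separating space, mirroring AllEntries.addText().
        void add(String text, float boost) {
            startOffsets.add(nextOffset);
            boosts.add(boost);
            nextOffset += text.length() + 1;
        }

        // Pick the entry with the greatest start offset <= the token's start offset.
        float boost(int tokenStartOffset) {
            int lo = 0, hi = startOffsets.size() - 1;
            while (lo <= hi) {
                int mid = (lo + hi) >>> 1;
                if (tokenStartOffset < startOffsets.get(mid)) {
                    hi = mid - 1;
                } else {
                    lo = mid + 1;
                }
            }
            return boosts.get(Math.max(0, hi)); // hi lands on the matching entry
        }

        public static void main(String[] args) {
            BoostLookup lookup = new BoostLookup();
            lookup.add("all", 2.0f);    // text occupies offsets 0..2
            lookup.add("your", 1.0f);   // offsets 4..7
            lookup.add("boosts", 0.5f); // offsets 9..14
            System.out.println(lookup.boost(0)); // 2.0
            System.out.println(lookup.boost(4)); // 1.0
            System.out.println(lookup.boost(9)); // 0.5
        }
    }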

src/main/java/org/elasticsearch/common/lucene/all/AllEntries.java (+33, −5)
@@ -40,14 +40,20 @@ public class AllEntries extends Reader {
     public static class Entry {
         private final String name;
         private final FastStringReader reader;
+        private final int startOffset;
         private final float boost;
 
-        public Entry(String name, FastStringReader reader, float boost) {
+        public Entry(String name, FastStringReader reader, int startOffset, float boost) {
             this.name = name;
             this.reader = reader;
+            this.startOffset = startOffset;
             this.boost = boost;
         }
 
+        public int startOffset() {
+            return startOffset;
+        }
+
         public String name() {
             return this.name;
         }
@@ -75,7 +81,15 @@ public void addText(String name, String text, float boost) {
         if (boost != 1.0f) {
             customBoost = true;
         }
-        Entry entry = new Entry(name, new FastStringReader(text), boost);
+        final int lastStartOffset;
+        if (entries.isEmpty()) {
+            lastStartOffset = -1;
+        } else {
+            final Entry last = entries.get(entries.size() - 1);
+            lastStartOffset = last.startOffset() + last.reader().length();
+        }
+        final int startOffset = lastStartOffset + 1; // +1 because we insert a space between tokens
+        Entry entry = new Entry(name, new FastStringReader(text), startOffset, boost);
         entries.add(entry);
     }
 
@@ -129,8 +143,22 @@ public Set<String> fields() {
         return fields;
     }
 
-    public Entry current() {
-        return this.current;
+    // compute the boost for a token with the given startOffset
+    public float boost(int startOffset) {
+        int lo = 0, hi = entries.size() - 1;
+        while (lo <= hi) {
+            final int mid = (lo + hi) >>> 1;
+            final int midOffset = entries.get(mid).startOffset();
+            if (startOffset < midOffset) {
+                hi = mid - 1;
+            } else {
+                lo = mid + 1;
+            }
+        }
+        final int index = Math.max(0, hi); // protection against broken token streams
+        assert entries.get(index).startOffset() <= startOffset;
+        assert index == entries.size() - 1 || entries.get(index + 1).startOffset() > startOffset;
+        return entries.get(index).boost();
     }
 
     @Override
@@ -186,7 +214,7 @@ public int read(char[] cbuf, int off, int len) throws IOException {
     @Override
     public void close() {
         if (current != null) {
-            current.reader().close();
+            // no need to close, these are readers on strings
             current = null;
         }
     }
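
The two asserts in boost(int) above spell out the binary search's exit invariant: hi lands on the last entry whose startOffset is less than or equal to the token's startOffset, and Math.max(0, hi) guards against a broken tokenizer reporting an offset before the first entry. As a worked fragment against the API in this diff (the same values the new test uses):

    AllEntries entries = new AllEntries();
    entries.addText("field1", "all", 2.0f);    // startOffset = -1 + 1 = 0
    entries.addText("field2", "your", 1.0f);   // startOffset = (0 + 3) + 1 = 4
    entries.addText("field1", "boosts", 0.5f); // startOffset = (4 + 4) + 1 = 9
    entries.reset();
    // Any token starting in [4, 9) came from "your" and gets its boost:
    float boost = entries.boost(4); // 1.0f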

src/main/java/org/elasticsearch/common/lucene/all/AllTokenStream.java (+9, −8)
@@ -22,6 +22,7 @@
 import org.apache.lucene.analysis.Analyzer;
 import org.apache.lucene.analysis.TokenFilter;
 import org.apache.lucene.analysis.TokenStream;
+import org.apache.lucene.analysis.tokenattributes.OffsetAttribute;
 import org.apache.lucene.analysis.tokenattributes.PayloadAttribute;
 import org.apache.lucene.util.BytesRef;
 
@@ -42,11 +43,13 @@ public static TokenStream allTokenStream(String allFieldName, AllEntries allEntr
 
     private final AllEntries allEntries;
 
+    private final OffsetAttribute offsetAttribute;
     private final PayloadAttribute payloadAttribute;
 
     AllTokenStream(TokenStream input, AllEntries allEntries) {
         super(input);
         this.allEntries = allEntries;
+        offsetAttribute = addAttribute(OffsetAttribute.class);
         payloadAttribute = addAttribute(PayloadAttribute.class);
     }
 
@@ -59,14 +62,12 @@ public final boolean incrementToken() throws IOException {
         if (!input.incrementToken()) {
             return false;
         }
-        if (allEntries.current() != null) {
-            float boost = allEntries.current().boost();
-            if (boost != 1.0f) {
-                encodeFloat(boost, payloadSpare.bytes, payloadSpare.offset);
-                payloadAttribute.setPayload(payloadSpare);
-            } else {
-                payloadAttribute.setPayload(null);
-            }
+        final float boost = allEntries.boost(offsetAttribute.startOffset());
+        if (boost != 1.0f) {
+            encodeFloat(boost, payloadSpare.bytes, payloadSpare.offset);
+            payloadAttribute.setPayload(payloadSpare);
+        } else {
+            payloadAttribute.setPayload(null);
         }
         return true;
     }
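
For boosts other than 1.0f, the filter attaches the boost to the token as a 4-byte payload. Here is a round-trip sketch using Lucene's PayloadHelper, assuming the unqualified encodeFloat above is a static import of that same helper (the new test below decodes with PayloadHelper.decodeFloat):

    import org.apache.lucene.analysis.payloads.PayloadHelper;

    public class BoostPayloadDemo {
        public static void main(String[] args) {
            // Encode the boost as the 4 big-endian bytes of its IEEE 754 bit pattern...
            byte[] payload = PayloadHelper.encodeFloat(0.5f);
            // ...and decode it back the way the test below does.
            float boost = PayloadHelper.decodeFloat(payload, 0);
            System.out.println(payload.length + " bytes -> " + boost); // 4 bytes -> 0.5
        }
    }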

src/test/java/org/elasticsearch/common/lucene/all/SimpleAllTests.java (+51, −0)
@@ -19,6 +19,11 @@
 
 package org.elasticsearch.common.lucene.all;
 
+import org.apache.lucene.analysis.TokenStream;
+import org.apache.lucene.analysis.core.WhitespaceAnalyzer;
+import org.apache.lucene.analysis.payloads.PayloadHelper;
+import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
+import org.apache.lucene.analysis.tokenattributes.PayloadAttribute;
 import org.apache.lucene.document.Document;
 import org.apache.lucene.document.Field;
 import org.apache.lucene.document.StoredField;
@@ -27,6 +32,7 @@
 import org.apache.lucene.search.*;
 import org.apache.lucene.store.Directory;
 import org.apache.lucene.store.RAMDirectory;
+import org.apache.lucene.util.BytesRef;
 import org.elasticsearch.common.lucene.Lucene;
 import org.elasticsearch.test.ElasticsearchTestCase;
 import org.junit.Test;
@@ -40,6 +46,51 @@
  */
 public class SimpleAllTests extends ElasticsearchTestCase {
 
+    @Test
+    public void testBoostOnEagerTokenizer() throws Exception {
+        AllEntries allEntries = new AllEntries();
+        allEntries.addText("field1", "all", 2.0f);
+        allEntries.addText("field2", "your", 1.0f);
+        allEntries.addText("field1", "boosts", 0.5f);
+        allEntries.reset();
+        // the whitespace analyzer's tokenizer reads characters eagerly, unlike the standard tokenizer
+        final TokenStream ts = AllTokenStream.allTokenStream("any", allEntries, new WhitespaceAnalyzer(Lucene.VERSION));
+        final CharTermAttribute termAtt = ts.addAttribute(CharTermAttribute.class);
+        final PayloadAttribute payloadAtt = ts.addAttribute(PayloadAttribute.class);
+        ts.reset();
+        for (int i = 0; i < 3; ++i) {
+            assertTrue(ts.incrementToken());
+            final String term;
+            final float boost;
+            switch (i) {
+            case 0:
+                term = "all";
+                boost = 2;
+                break;
+            case 1:
+                term = "your";
+                boost = 1;
+                break;
+            case 2:
+                term = "boosts";
+                boost = 0.5f;
+                break;
+            default:
+                throw new AssertionError();
+            }
+            assertEquals(term, termAtt.toString());
+            final BytesRef payload = payloadAtt.getPayload();
+            if (payload == null || payload.length == 0) {
+                assertEquals(boost, 1f, 0.001f);
+            } else {
+                assertEquals(4, payload.length);
+                final float b = PayloadHelper.decodeFloat(payload.bytes, payload.offset);
+                assertEquals(boost, b, 0.001f);
+            }
+        }
+        assertFalse(ts.incrementToken());
+    }
+
     @Test
     public void testAllEntriesRead() throws Exception {
         AllEntries allEntries = new AllEntries();