Skip to content

Commit 62e9ced

Browse files
committed
Move the document field argument from the per-document scoring methods to the vector function constructors.
1 parent 8c9722d commit 62e9ced

File tree

2 files changed

+134
-115
lines changed

2 files changed

+134
-115
lines changed

x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/ScoreScriptUtils.java

Lines changed: 69 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
import org.apache.logging.log4j.LogManager;
1111
import org.apache.lucene.util.BytesRef;
12+
import org.elasticsearch.ExceptionsHelper;
1213
import org.elasticsearch.Version;
1314
import org.elasticsearch.common.logging.DeprecationLogger;
1415
import org.elasticsearch.script.ScoreScript;
@@ -17,6 +18,7 @@
1718
import org.elasticsearch.xpack.vectors.query.VectorScriptDocValues.DenseVectorScriptDocValues;
1819
import org.elasticsearch.xpack.vectors.query.VectorScriptDocValues.SparseVectorScriptDocValues;
1920

21+
import java.io.IOException;
2022
import java.nio.ByteBuffer;
2123
import java.util.List;
2224
import java.util.Map;
@@ -37,9 +39,12 @@ public class ScoreScriptUtils {
3739
public static class DenseVectorFunction {
3840
final ScoreScript scoreScript;
3941
final float[] queryVector;
42+
final VectorScriptDocValues.DenseVectorScriptDocValues docValues;
4043

41-
public DenseVectorFunction(ScoreScript scoreScript, List<Number> queryVector) {
42-
this(scoreScript, queryVector, false);
44+
public DenseVectorFunction(ScoreScript scoreScript,
45+
List<Number> queryVector,
46+
Object field) {
47+
this(scoreScript, queryVector, field, false);
4348
}
4449

4550
/**
@@ -51,6 +56,7 @@ public DenseVectorFunction(ScoreScript scoreScript, List<Number> queryVector) {
5156
*/
5257
public DenseVectorFunction(ScoreScript scoreScript,
5358
List<Number> queryVector,
59+
Object field,
5460
boolean normalizeQuery) {
5561
this.scoreScript = scoreScript;
5662

@@ -68,18 +74,22 @@ public DenseVectorFunction(ScoreScript scoreScript,
6874
this.queryVector[dim] /= queryMagnitude;
6975
}
7076
}
71-
}
7277

73-
BytesRef getEncodedVector(Object arg) {
74-
DenseVectorScriptDocValues docValues;
75-
if (arg instanceof DenseVectorScriptDocValues) {
76-
docValues = (DenseVectorScriptDocValues) arg;
78+
if (field instanceof DenseVectorScriptDocValues) {
79+
docValues = (DenseVectorScriptDocValues) field;
7780
deprecationLogger.deprecatedAndMaybeLog("vector_function_signature", DEPRECATION_MESSAGE);
7881
} else {
79-
String field = (String) arg;
80-
docValues = (DenseVectorScriptDocValues) scoreScript.getDoc().get(field);
82+
String fieldName = (String) field;
83+
docValues = (DenseVectorScriptDocValues) scoreScript.getDoc().get(fieldName);
8184
}
85+
}
8286

87+
BytesRef getEncodedVector() {
88+
try {
89+
docValues.setNextDocId(scoreScript._getDocId());
90+
} catch (IOException e) {
91+
throw ExceptionsHelper.convertToElastic(e);
92+
}
8393
return docValues.getEncodedValue();
8494
}
8595

@@ -99,12 +109,12 @@ void validateDocVector(BytesRef vector) {
99109
// Calculate l1 norm (Manhattan distance) between a query's dense vector and documents' dense vectors
100110
public static final class L1Norm extends DenseVectorFunction {
101111

102-
public L1Norm(ScoreScript scoreScript, List<Number> queryVector) {
103-
super(scoreScript, queryVector);
112+
public L1Norm(ScoreScript scoreScript, List<Number> queryVector, Object field) {
113+
super(scoreScript, queryVector, field);
104114
}
105115

106-
public double l1norm(Object arg) {
107-
BytesRef vector = getEncodedVector(arg);
116+
public double l1norm() {
117+
BytesRef vector = getEncodedVector();
108118
validateDocVector(vector);
109119
ByteBuffer byteBuffer = ByteBuffer.wrap(vector.bytes, vector.offset, vector.length);
110120

@@ -120,12 +130,12 @@ public double l1norm(Object arg) {
120130
// Calculate l2 norm (Euclidean distance) between a query's dense vector and documents' dense vectors
121131
public static final class L2Norm extends DenseVectorFunction {
122132

123-
public L2Norm(ScoreScript scoreScript, List<Number> queryVector) {
124-
super(scoreScript, queryVector);
133+
public L2Norm(ScoreScript scoreScript, List<Number> queryVector, Object field) {
134+
super(scoreScript, queryVector, field);
125135
}
126136

127-
public double l2norm(Object arg) {
128-
BytesRef vector = getEncodedVector(arg);
137+
public double l2norm() {
138+
BytesRef vector = getEncodedVector();
129139
validateDocVector(vector);
130140
ByteBuffer byteBuffer = ByteBuffer.wrap(vector.bytes, vector.offset, vector.length);
131141

@@ -141,12 +151,12 @@ public double l2norm(Object arg) {
141151
// Calculate a dot product between a query's dense vector and documents' dense vectors
142152
public static final class DotProduct extends DenseVectorFunction {
143153

144-
public DotProduct(ScoreScript scoreScript, List<Number> queryVector) {
145-
super(scoreScript, queryVector);
154+
public DotProduct(ScoreScript scoreScript, List<Number> queryVector, Object field) {
155+
super(scoreScript, queryVector, field);
146156
}
147157

148-
public double dotProduct(Object arg) {
149-
BytesRef vector = getEncodedVector(arg);
158+
public double dotProduct() {
159+
BytesRef vector = getEncodedVector();
150160
validateDocVector(vector);
151161
ByteBuffer byteBuffer = ByteBuffer.wrap(vector.bytes, vector.offset, vector.length);
152162

@@ -161,12 +171,12 @@ public double dotProduct(Object arg) {
161171
// Calculate cosine similarity between a query's dense vector and documents' dense vectors
162172
public static final class CosineSimilarity extends DenseVectorFunction {
163173

164-
public CosineSimilarity(ScoreScript scoreScript, List<Number> queryVector) {
165-
super(scoreScript, queryVector, true);
174+
public CosineSimilarity(ScoreScript scoreScript, List<Number> queryVector, Object field) {
175+
super(scoreScript, queryVector, field, true);
166176
}
167177

168-
public double cosineSimilarity(Object arg) {
169-
BytesRef vector = getEncodedVector(arg);
178+
public double cosineSimilarity() {
179+
BytesRef vector = getEncodedVector();
170180
validateDocVector(vector);
171181
ByteBuffer byteBuffer = ByteBuffer.wrap(vector.bytes, vector.offset, vector.length);
172182

@@ -199,9 +209,13 @@ public static class SparseVectorFunction {
199209
final float[] queryValues;
200210
final int[] queryDims;
201211

212+
final VectorScriptDocValues.SparseVectorScriptDocValues docValues;
213+
202214
// prepare queryVector once per script execution
203215
// queryVector represents a map of dimensions to values
204-
public SparseVectorFunction(ScoreScript scoreScript, Map<String, Number> queryVector) {
216+
public SparseVectorFunction(ScoreScript scoreScript,
217+
Map<String, Number> queryVector,
218+
Object field) {
205219
this.scoreScript = scoreScript;
206220
//break vector into two arrays dims and values
207221
int n = queryVector.size();
@@ -220,19 +234,23 @@ public SparseVectorFunction(ScoreScript scoreScript, Map<String, Number> queryVe
220234
// Sort dimensions in the ascending order and sort values in the same order as their corresponding dimensions
221235
sortSparseDimsFloatValues(queryDims, queryValues, n);
222236

223-
deprecationLogger.deprecatedAndMaybeLog("sparse_vector_function", SparseVectorFieldMapper.DEPRECATION_MESSAGE);
224-
}
225-
226-
BytesRef getEncodedVector(Object arg) {
227-
SparseVectorScriptDocValues docValues;
228-
if (arg instanceof SparseVectorScriptDocValues) {
229-
docValues = (SparseVectorScriptDocValues) arg;
237+
if (field instanceof SparseVectorScriptDocValues) {
238+
docValues = (SparseVectorScriptDocValues) field;
230239
deprecationLogger.deprecatedAndMaybeLog("vector_function_signature", DEPRECATION_MESSAGE);
231240
} else {
232-
String field = (String) arg;
233-
docValues = (SparseVectorScriptDocValues) scoreScript.getDoc().get(field);
241+
String fieldName = (String) field;
242+
docValues = (SparseVectorScriptDocValues) scoreScript.getDoc().get(fieldName);
234243
}
235244

245+
deprecationLogger.deprecatedAndMaybeLog("sparse_vector_function", SparseVectorFieldMapper.DEPRECATION_MESSAGE);
246+
}
247+
248+
BytesRef getEncodedVector() {
249+
try {
250+
docValues.setNextDocId(scoreScript._getDocId());
251+
} catch (IOException e) {
252+
throw ExceptionsHelper.convertToElastic(e);
253+
}
236254
return docValues.getEncodedValue();
237255
}
238256

@@ -245,12 +263,12 @@ public void validateDocVector(BytesRef vector) {
245263

246264
// Calculate l1 norm (Manhattan distance) between a query's sparse vector and documents' sparse vectors
247265
public static final class L1NormSparse extends SparseVectorFunction {
248-
public L1NormSparse(ScoreScript scoreScript,Map<String, Number> queryVector) {
249-
super(scoreScript, queryVector);
266+
public L1NormSparse(ScoreScript scoreScript,Map<String, Number> queryVector, Object docVector) {
267+
super(scoreScript, queryVector, docVector);
250268
}
251269

252-
public double l1normSparse(Object arg) {
253-
BytesRef vector = getEncodedVector(arg);
270+
public double l1normSparse() {
271+
BytesRef vector = getEncodedVector();
254272
validateDocVector(vector);
255273

256274
int[] docDims = VectorEncoderDecoder.decodeSparseVectorDims(scoreScript._getIndexVersion(), vector);
@@ -285,12 +303,12 @@ public double l1normSparse(Object arg) {
285303

286304
// Calculate l2 norm (Euclidean distance) between a query's sparse vector and documents' sparse vectors
287305
public static final class L2NormSparse extends SparseVectorFunction {
288-
public L2NormSparse(ScoreScript scoreScript, Map<String, Number> queryVector) {
289-
super(scoreScript, queryVector);
306+
public L2NormSparse(ScoreScript scoreScript, Map<String, Number> queryVector, Object docVector) {
307+
super(scoreScript, queryVector, docVector);
290308
}
291309

292-
public double l2normSparse(Object arg) {
293-
BytesRef vector = getEncodedVector(arg);
310+
public double l2normSparse() {
311+
BytesRef vector = getEncodedVector();
294312
validateDocVector(vector);
295313

296314
int[] docDims = VectorEncoderDecoder.decodeSparseVectorDims(scoreScript._getIndexVersion(), vector);
@@ -328,12 +346,12 @@ public double l2normSparse(Object arg) {
328346

329347
// Calculate a dot product between a query's sparse vector and documents' sparse vectors
330348
public static final class DotProductSparse extends SparseVectorFunction {
331-
public DotProductSparse(ScoreScript scoreScript, Map<String, Number> queryVector) {
332-
super(scoreScript, queryVector);
349+
public DotProductSparse(ScoreScript scoreScript, Map<String, Number> queryVector, Object docVector) {
350+
super(scoreScript, queryVector, docVector);
333351
}
334352

335-
public double dotProductSparse(Object arg) {
336-
BytesRef vector = getEncodedVector(arg);
353+
public double dotProductSparse() {
354+
BytesRef vector = getEncodedVector();
337355
validateDocVector(vector);
338356

339357
int[] docDims = VectorEncoderDecoder.decodeSparseVectorDims(scoreScript._getIndexVersion(), vector);
@@ -346,17 +364,17 @@ public double dotProductSparse(Object arg) {
346364
public static final class CosineSimilaritySparse extends SparseVectorFunction {
347365
final double queryVectorMagnitude;
348366

349-
public CosineSimilaritySparse(ScoreScript scoreScript, Map<String, Number> queryVector) {
350-
super(scoreScript, queryVector);
367+
public CosineSimilaritySparse(ScoreScript scoreScript, Map<String, Number> queryVector, Object docVector) {
368+
super(scoreScript, queryVector, docVector);
351369
double dotProduct = 0;
352370
for (int i = 0; i< queryDims.length; i++) {
353371
dotProduct += queryValues[i] * queryValues[i];
354372
}
355373
this.queryVectorMagnitude = Math.sqrt(dotProduct);
356374
}
357375

358-
public double cosineSimilaritySparse(Object arg) {
359-
BytesRef vector = getEncodedVector(arg);
376+
public double cosineSimilaritySparse() {
377+
BytesRef vector = getEncodedVector();
360378
validateDocVector(vector);
361379

362380
int[] docDims = VectorEncoderDecoder.decodeSparseVectorDims(scoreScript._getIndexVersion(), vector);

0 commit comments

Comments
 (0)