Skip to content

Commit ed95d9f

Browse files
committed
Combine fetching and validation of the doc vectors.
1 parent 62e9ced commit ed95d9f

File tree

1 file changed

+8
-18
lines changed

1 file changed

+8
-18
lines changed

x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/ScoreScriptUtils.java

Lines changed: 8 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -90,10 +90,9 @@ BytesRef getEncodedVector() {
9090
} catch (IOException e) {
9191
throw ExceptionsHelper.convertToElastic(e);
9292
}
93-
return docValues.getEncodedValue();
94-
}
9593

96-
void validateDocVector(BytesRef vector) {
94+
// Validate the encoded vector's length.
95+
BytesRef vector = docValues.getEncodedValue();
9796
if (vector == null) {
9897
throw new IllegalArgumentException("A document doesn't have a value for a vector field!");
9998
}
@@ -103,6 +102,7 @@ void validateDocVector(BytesRef vector) {
103102
throw new IllegalArgumentException("The query vector has a different number of dimensions [" +
104103
queryVector.length + "] than the document vectors [" + vectorLength + "].");
105104
}
105+
return vector;
106106
}
107107
}
108108

@@ -115,7 +115,6 @@ public L1Norm(ScoreScript scoreScript, List<Number> queryVector, Object field) {
115115

116116
public double l1norm() {
117117
BytesRef vector = getEncodedVector();
118-
validateDocVector(vector);
119118
ByteBuffer byteBuffer = ByteBuffer.wrap(vector.bytes, vector.offset, vector.length);
120119

121120
double l1norm = 0;
@@ -136,7 +135,6 @@ public L2Norm(ScoreScript scoreScript, List<Number> queryVector, Object field) {
136135

137136
public double l2norm() {
138137
BytesRef vector = getEncodedVector();
139-
validateDocVector(vector);
140138
ByteBuffer byteBuffer = ByteBuffer.wrap(vector.bytes, vector.offset, vector.length);
141139

142140
double l2norm = 0;
@@ -157,7 +155,6 @@ public DotProduct(ScoreScript scoreScript, List<Number> queryVector, Object fiel
157155

158156
public double dotProduct() {
159157
BytesRef vector = getEncodedVector();
160-
validateDocVector(vector);
161158
ByteBuffer byteBuffer = ByteBuffer.wrap(vector.bytes, vector.offset, vector.length);
162159

163160
double dotProduct = 0;
@@ -177,7 +174,6 @@ public CosineSimilarity(ScoreScript scoreScript, List<Number> queryVector, Objec
177174

178175
public double cosineSimilarity() {
179176
BytesRef vector = getEncodedVector();
180-
validateDocVector(vector);
181177
ByteBuffer byteBuffer = ByteBuffer.wrap(vector.bytes, vector.offset, vector.length);
182178

183179
double dotProduct = 0.0;
@@ -251,13 +247,12 @@ BytesRef getEncodedVector() {
251247
} catch (IOException e) {
252248
throw ExceptionsHelper.convertToElastic(e);
253249
}
254-
return docValues.getEncodedValue();
255-
}
256250

257-
public void validateDocVector(BytesRef vector) {
251+
BytesRef vector = docValues.getEncodedValue();
258252
if (vector == null) {
259253
throw new IllegalArgumentException("A document doesn't have a value for a vector field!");
260254
}
255+
return vector;
261256
}
262257
}
263258

@@ -269,10 +264,9 @@ public L1NormSparse(ScoreScript scoreScript,Map<String, Number> queryVector, Obj
269264

270265
public double l1normSparse() {
271266
BytesRef vector = getEncodedVector();
272-
validateDocVector(vector);
273-
274267
int[] docDims = VectorEncoderDecoder.decodeSparseVectorDims(scoreScript._getIndexVersion(), vector);
275268
float[] docValues = VectorEncoderDecoder.decodeSparseVector(scoreScript._getIndexVersion(), vector);
269+
276270
int queryIndex = 0;
277271
int docIndex = 0;
278272
double l1norm = 0;
@@ -309,10 +303,9 @@ public L2NormSparse(ScoreScript scoreScript, Map<String, Number> queryVector, Ob
309303

310304
public double l2normSparse() {
311305
BytesRef vector = getEncodedVector();
312-
validateDocVector(vector);
313-
314306
int[] docDims = VectorEncoderDecoder.decodeSparseVectorDims(scoreScript._getIndexVersion(), vector);
315307
float[] docValues = VectorEncoderDecoder.decodeSparseVector(scoreScript._getIndexVersion(), vector);
308+
316309
int queryIndex = 0;
317310
int docIndex = 0;
318311
double l2norm = 0;
@@ -352,10 +345,9 @@ public DotProductSparse(ScoreScript scoreScript, Map<String, Number> queryVector
352345

353346
public double dotProductSparse() {
354347
BytesRef vector = getEncodedVector();
355-
validateDocVector(vector);
356-
357348
int[] docDims = VectorEncoderDecoder.decodeSparseVectorDims(scoreScript._getIndexVersion(), vector);
358349
float[] docValues = VectorEncoderDecoder.decodeSparseVector(scoreScript._getIndexVersion(), vector);
350+
359351
return intDotProductSparse(queryValues, queryDims, docValues, docDims);
360352
}
361353
}
@@ -375,8 +367,6 @@ public CosineSimilaritySparse(ScoreScript scoreScript, Map<String, Number> query
375367

376368
public double cosineSimilaritySparse() {
377369
BytesRef vector = getEncodedVector();
378-
validateDocVector(vector);
379-
380370
int[] docDims = VectorEncoderDecoder.decodeSparseVectorDims(scoreScript._getIndexVersion(), vector);
381371
float[] docValues = VectorEncoderDecoder.decodeSparseVector(scoreScript._getIndexVersion(), vector);
382372

0 commit comments

Comments
 (0)