99
1010import org .apache .logging .log4j .LogManager ;
1111import org .apache .lucene .util .BytesRef ;
12+ import org .elasticsearch .ExceptionsHelper ;
1213import org .elasticsearch .Version ;
1314import org .elasticsearch .common .logging .DeprecationLogger ;
1415import org .elasticsearch .script .ScoreScript ;
1718import org .elasticsearch .xpack .vectors .query .VectorScriptDocValues .DenseVectorScriptDocValues ;
1819import org .elasticsearch .xpack .vectors .query .VectorScriptDocValues .SparseVectorScriptDocValues ;
1920
21+ import java .io .IOException ;
2022import java .nio .ByteBuffer ;
2123import java .util .List ;
2224import java .util .Map ;
@@ -37,9 +39,12 @@ public class ScoreScriptUtils {
3739 public static class DenseVectorFunction {
3840 final ScoreScript scoreScript ;
3941 final float [] queryVector ;
42+ final VectorScriptDocValues .DenseVectorScriptDocValues docValues ;
4043
41- public DenseVectorFunction (ScoreScript scoreScript , List <Number > queryVector ) {
42- this (scoreScript , queryVector , false );
44+ public DenseVectorFunction (ScoreScript scoreScript ,
45+ List <Number > queryVector ,
46+ Object field ) {
47+ this (scoreScript , queryVector , field , false );
4348 }
4449
4550 /**
@@ -51,6 +56,7 @@ public DenseVectorFunction(ScoreScript scoreScript, List<Number> queryVector) {
5156 */
5257 public DenseVectorFunction (ScoreScript scoreScript ,
5358 List <Number > queryVector ,
59+ Object field ,
5460 boolean normalizeQuery ) {
5561 this .scoreScript = scoreScript ;
5662
@@ -68,18 +74,22 @@ public DenseVectorFunction(ScoreScript scoreScript,
6874 this .queryVector [dim ] /= queryMagnitude ;
6975 }
7076 }
71- }
7277
73- BytesRef getEncodedVector (Object arg ) {
74- DenseVectorScriptDocValues docValues ;
75- if (arg instanceof DenseVectorScriptDocValues ) {
76- docValues = (DenseVectorScriptDocValues ) arg ;
78+ if (field instanceof DenseVectorScriptDocValues ) {
79+ docValues = (DenseVectorScriptDocValues ) field ;
7780 deprecationLogger .deprecatedAndMaybeLog ("vector_function_signature" , DEPRECATION_MESSAGE );
7881 } else {
79- String field = (String ) arg ;
80- docValues = (DenseVectorScriptDocValues ) scoreScript .getDoc ().get (field );
82+ String fieldName = (String ) field ;
83+ docValues = (DenseVectorScriptDocValues ) scoreScript .getDoc ().get (fieldName );
8184 }
85+ }
8286
87+ BytesRef getEncodedVector () {
88+ try {
89+ docValues .setNextDocId (scoreScript ._getDocId ());
90+ } catch (IOException e ) {
91+ throw ExceptionsHelper .convertToElastic (e );
92+ }
8393 return docValues .getEncodedValue ();
8494 }
8595
@@ -99,12 +109,12 @@ void validateDocVector(BytesRef vector) {
99109 // Calculate l1 norm (Manhattan distance) between a query's dense vector and documents' dense vectors
100110 public static final class L1Norm extends DenseVectorFunction {
101111
102- public L1Norm (ScoreScript scoreScript , List <Number > queryVector ) {
103- super (scoreScript , queryVector );
112+ public L1Norm (ScoreScript scoreScript , List <Number > queryVector , Object field ) {
113+ super (scoreScript , queryVector , field );
104114 }
105115
106- public double l1norm (Object arg ) {
107- BytesRef vector = getEncodedVector (arg );
116+ public double l1norm () {
117+ BytesRef vector = getEncodedVector ();
108118 validateDocVector (vector );
109119 ByteBuffer byteBuffer = ByteBuffer .wrap (vector .bytes , vector .offset , vector .length );
110120
@@ -120,12 +130,12 @@ public double l1norm(Object arg) {
120130 // Calculate l2 norm (Euclidean distance) between a query's dense vector and documents' dense vectors
121131 public static final class L2Norm extends DenseVectorFunction {
122132
123- public L2Norm (ScoreScript scoreScript , List <Number > queryVector ) {
124- super (scoreScript , queryVector );
133+ public L2Norm (ScoreScript scoreScript , List <Number > queryVector , Object field ) {
134+ super (scoreScript , queryVector , field );
125135 }
126136
127- public double l2norm (Object arg ) {
128- BytesRef vector = getEncodedVector (arg );
137+ public double l2norm () {
138+ BytesRef vector = getEncodedVector ();
129139 validateDocVector (vector );
130140 ByteBuffer byteBuffer = ByteBuffer .wrap (vector .bytes , vector .offset , vector .length );
131141
@@ -141,12 +151,12 @@ public double l2norm(Object arg) {
141151 // Calculate a dot product between a query's dense vector and documents' dense vectors
142152 public static final class DotProduct extends DenseVectorFunction {
143153
144- public DotProduct (ScoreScript scoreScript , List <Number > queryVector ) {
145- super (scoreScript , queryVector );
154+ public DotProduct (ScoreScript scoreScript , List <Number > queryVector , Object field ) {
155+ super (scoreScript , queryVector , field );
146156 }
147157
148- public double dotProduct (Object arg ) {
149- BytesRef vector = getEncodedVector (arg );
158+ public double dotProduct () {
159+ BytesRef vector = getEncodedVector ();
150160 validateDocVector (vector );
151161 ByteBuffer byteBuffer = ByteBuffer .wrap (vector .bytes , vector .offset , vector .length );
152162
@@ -161,12 +171,12 @@ public double dotProduct(Object arg) {
161171 // Calculate cosine similarity between a query's dense vector and documents' dense vectors
162172 public static final class CosineSimilarity extends DenseVectorFunction {
163173
164- public CosineSimilarity (ScoreScript scoreScript , List <Number > queryVector ) {
165- super (scoreScript , queryVector , true );
174+ public CosineSimilarity (ScoreScript scoreScript , List <Number > queryVector , Object field ) {
175+ super (scoreScript , queryVector , field , true );
166176 }
167177
168- public double cosineSimilarity (Object arg ) {
169- BytesRef vector = getEncodedVector (arg );
178+ public double cosineSimilarity () {
179+ BytesRef vector = getEncodedVector ();
170180 validateDocVector (vector );
171181 ByteBuffer byteBuffer = ByteBuffer .wrap (vector .bytes , vector .offset , vector .length );
172182
@@ -199,9 +209,13 @@ public static class SparseVectorFunction {
199209 final float [] queryValues ;
200210 final int [] queryDims ;
201211
212+ final VectorScriptDocValues .SparseVectorScriptDocValues docValues ;
213+
202214 // prepare queryVector once per script execution
203215 // queryVector represents a map of dimensions to values
204- public SparseVectorFunction (ScoreScript scoreScript , Map <String , Number > queryVector ) {
216+ public SparseVectorFunction (ScoreScript scoreScript ,
217+ Map <String , Number > queryVector ,
218+ Object field ) {
205219 this .scoreScript = scoreScript ;
206220 //break vector into two arrays dims and values
207221 int n = queryVector .size ();
@@ -220,19 +234,23 @@ public SparseVectorFunction(ScoreScript scoreScript, Map<String, Number> queryVe
220234 // Sort dimensions in the ascending order and sort values in the same order as their corresponding dimensions
221235 sortSparseDimsFloatValues (queryDims , queryValues , n );
222236
223- deprecationLogger .deprecatedAndMaybeLog ("sparse_vector_function" , SparseVectorFieldMapper .DEPRECATION_MESSAGE );
224- }
225-
226- BytesRef getEncodedVector (Object arg ) {
227- SparseVectorScriptDocValues docValues ;
228- if (arg instanceof SparseVectorScriptDocValues ) {
229- docValues = (SparseVectorScriptDocValues ) arg ;
237+ if (field instanceof SparseVectorScriptDocValues ) {
238+ docValues = (SparseVectorScriptDocValues ) field ;
230239 deprecationLogger .deprecatedAndMaybeLog ("vector_function_signature" , DEPRECATION_MESSAGE );
231240 } else {
232- String field = (String ) arg ;
233- docValues = (SparseVectorScriptDocValues ) scoreScript .getDoc ().get (field );
241+ String fieldName = (String ) field ;
242+ docValues = (SparseVectorScriptDocValues ) scoreScript .getDoc ().get (fieldName );
234243 }
235244
245+ deprecationLogger .deprecatedAndMaybeLog ("sparse_vector_function" , SparseVectorFieldMapper .DEPRECATION_MESSAGE );
246+ }
247+
248+ BytesRef getEncodedVector () {
249+ try {
250+ docValues .setNextDocId (scoreScript ._getDocId ());
251+ } catch (IOException e ) {
252+ throw ExceptionsHelper .convertToElastic (e );
253+ }
236254 return docValues .getEncodedValue ();
237255 }
238256
@@ -245,12 +263,12 @@ public void validateDocVector(BytesRef vector) {
245263
246264 // Calculate l1 norm (Manhattan distance) between a query's sparse vector and documents' sparse vectors
247265 public static final class L1NormSparse extends SparseVectorFunction {
248- public L1NormSparse (ScoreScript scoreScript ,Map <String , Number > queryVector ) {
249- super (scoreScript , queryVector );
266+ public L1NormSparse (ScoreScript scoreScript ,Map <String , Number > queryVector , Object docVector ) {
267+ super (scoreScript , queryVector , docVector );
250268 }
251269
252- public double l1normSparse (Object arg ) {
253- BytesRef vector = getEncodedVector (arg );
270+ public double l1normSparse () {
271+ BytesRef vector = getEncodedVector ();
254272 validateDocVector (vector );
255273
256274 int [] docDims = VectorEncoderDecoder .decodeSparseVectorDims (scoreScript ._getIndexVersion (), vector );
@@ -285,12 +303,12 @@ public double l1normSparse(Object arg) {
285303
286304 // Calculate l2 norm (Euclidean distance) between a query's sparse vector and documents' sparse vectors
287305 public static final class L2NormSparse extends SparseVectorFunction {
288- public L2NormSparse (ScoreScript scoreScript , Map <String , Number > queryVector ) {
289- super (scoreScript , queryVector );
306+ public L2NormSparse (ScoreScript scoreScript , Map <String , Number > queryVector , Object docVector ) {
307+ super (scoreScript , queryVector , docVector );
290308 }
291309
292- public double l2normSparse (Object arg ) {
293- BytesRef vector = getEncodedVector (arg );
310+ public double l2normSparse () {
311+ BytesRef vector = getEncodedVector ();
294312 validateDocVector (vector );
295313
296314 int [] docDims = VectorEncoderDecoder .decodeSparseVectorDims (scoreScript ._getIndexVersion (), vector );
@@ -328,12 +346,12 @@ public double l2normSparse(Object arg) {
328346
329347 // Calculate a dot product between a query's sparse vector and documents' sparse vectors
330348 public static final class DotProductSparse extends SparseVectorFunction {
331- public DotProductSparse (ScoreScript scoreScript , Map <String , Number > queryVector ) {
332- super (scoreScript , queryVector );
349+ public DotProductSparse (ScoreScript scoreScript , Map <String , Number > queryVector , Object docVector ) {
350+ super (scoreScript , queryVector , docVector );
333351 }
334352
335- public double dotProductSparse (Object arg ) {
336- BytesRef vector = getEncodedVector (arg );
353+ public double dotProductSparse () {
354+ BytesRef vector = getEncodedVector ();
337355 validateDocVector (vector );
338356
339357 int [] docDims = VectorEncoderDecoder .decodeSparseVectorDims (scoreScript ._getIndexVersion (), vector );
@@ -346,17 +364,17 @@ public double dotProductSparse(Object arg) {
346364 public static final class CosineSimilaritySparse extends SparseVectorFunction {
347365 final double queryVectorMagnitude ;
348366
349- public CosineSimilaritySparse (ScoreScript scoreScript , Map <String , Number > queryVector ) {
350- super (scoreScript , queryVector );
367+ public CosineSimilaritySparse (ScoreScript scoreScript , Map <String , Number > queryVector , Object docVector ) {
368+ super (scoreScript , queryVector , docVector );
351369 double dotProduct = 0 ;
352370 for (int i = 0 ; i < queryDims .length ; i ++) {
353371 dotProduct += queryValues [i ] * queryValues [i ];
354372 }
355373 this .queryVectorMagnitude = Math .sqrt (dotProduct );
356374 }
357375
358- public double cosineSimilaritySparse (Object arg ) {
359- BytesRef vector = getEncodedVector (arg );
376+ public double cosineSimilaritySparse () {
377+ BytesRef vector = getEncodedVector ();
360378 validateDocVector (vector );
361379
362380 int [] docDims = VectorEncoderDecoder .decodeSparseVectorDims (scoreScript ._getIndexVersion (), vector );
0 commit comments