@@ -96,16 +96,19 @@ public static float l2Squared(List<Number> queryVector, KNNVectorScriptDocValues
96
96
* @return cosine score
97
97
*/
98
98
public static float cosinesimilOptimized (float [] queryVector , float [] inputVector , float normQueryVector ) {
99
+ requireEqualDimension (queryVector , inputVector );
100
+ float dotProduct = 0.0f ;
99
101
float normInputVector = 0.0f ;
100
102
for (int i = 0 ; i < queryVector .length ; i ++) {
103
+ dotProduct += queryVector [i ] * inputVector [i ];
101
104
normInputVector += inputVector [i ] * inputVector [i ];
102
105
}
103
106
float normalizedProduct = normQueryVector * normInputVector ;
104
107
if (normalizedProduct == 0 ) {
105
108
logger .debug ("Invalid vectors for cosine. Returning minimum score to put this result to end" );
106
109
return 0.0f ;
107
110
}
108
- return (float ) (VectorUtil . dotProduct ( queryVector , inputVector ) / (Math .sqrt (normalizedProduct )));
111
+ return (float ) (dotProduct / (Math .sqrt (normalizedProduct )));
109
112
}
110
113
111
114
/**
@@ -140,28 +143,12 @@ public static float cosineSimilarity(List<Number> queryVector, KNNVectorScriptDo
140
143
*/
141
144
public static float cosinesimil (float [] queryVector , float [] inputVector ) {
142
145
requireEqualDimension (queryVector , inputVector );
143
- int numZeroInInput = 0 ;
144
- int numZeroInQuery = 0 ;
145
- float cosine = 0.0f ;
146
- for (int i = 0 ; i < inputVector .length ; i ++) {
147
- if (inputVector [i ] == 0 ) {
148
- numZeroInInput ++;
149
- }
150
-
151
- if (queryVector [i ] == 0 ) {
152
- numZeroInQuery ++;
153
- }
154
- }
155
- if (numZeroInInput == inputVector .length || numZeroInQuery == queryVector .length ) {
156
- return cosine ;
157
- }
158
146
try {
159
- cosine = VectorUtil .cosine (queryVector , inputVector );
160
- } catch (IllegalArgumentException e ) {
147
+ return VectorUtil .cosine (queryVector , inputVector );
148
+ } catch (IllegalArgumentException | AssertionError e ) {
161
149
logger .debug ("Invalid vectors for cosine. Returning minimum score to put this result to end" );
162
150
return 0.0f ;
163
151
}
164
- return cosine ;
165
152
}
166
153
167
154
/**
0 commit comments