Skip to content

Commit e4c42b0

Browse files
committed
Add cosine_distance for sparse vectors
1 parent 4d2443d commit e4c42b0

File tree

3 files changed

+53
-0
lines changed

3 files changed

+53
-0
lines changed

core/trino-main/src/main/java/io/trino/operator/scalar/MathFunctions.java

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1408,6 +1408,30 @@ public static Double cosineSimilarity(
14081408
return dotProduct / (normLeftMap * normRightMap);
14091409
}
14101410

1411+
@Description("Calculates the cosine distance between the give sparse vectors")
1412+
@ScalarFunction
1413+
@SqlNullable
1414+
@SqlType(StandardTypes.DOUBLE)
1415+
public static Double cosineDistance(
1416+
@OperatorDependency(
1417+
operator = IDENTICAL,
1418+
argumentTypes = {"varchar", "varchar"},
1419+
convention = @Convention(arguments = {BLOCK_POSITION, BLOCK_POSITION}, result = NULLABLE_RETURN)) BlockPositionIsIdentical varcharIdentical,
1420+
@OperatorDependency(
1421+
operator = HASH_CODE,
1422+
argumentTypes = "varchar",
1423+
convention = @Convention(arguments = BLOCK_POSITION, result = FAIL_ON_NULL)) BlockPositionHashCode varcharHashCode,
1424+
@SqlType("map(varchar,double)") SqlMap leftMap,
1425+
@SqlType("map(varchar,double)") SqlMap rightMap)
1426+
{
1427+
Double cosineSimilarity = cosineSimilarity(varcharIdentical, varcharHashCode, leftMap, rightMap);
1428+
if (cosineSimilarity == null) {
1429+
return null;
1430+
}
1431+
1432+
return 1.0 - cosineSimilarity;
1433+
}
1434+
14111435
private static double mapDotProduct(BlockPositionIsIdentical varcharIdentical, BlockPositionHashCode varcharHashCode, SqlMap leftMap, SqlMap rightMap)
14121436
{
14131437
int leftRawOffset = leftMap.getRawOffset();

core/trino-main/src/test/java/io/trino/operator/scalar/TestMathFunctions.java

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3457,6 +3457,25 @@ public void testCosineSimilarity()
34573457
.isNull(DOUBLE);
34583458
}
34593459

3460+
@Test
3461+
public void testCosineDistance()
3462+
{
3463+
assertThat(assertions.function("cosine_distance", "map(ARRAY['a', 'b'], ARRAY[1.0E0, 2.0E0])", "map(ARRAY['c', 'b'], ARRAY[1.0E0, 3.0E0])"))
3464+
.isEqualTo(1 - (2 * 3 / (Math.sqrt(5) * Math.sqrt(10))));
3465+
3466+
assertThat(assertions.function("cosine_distance", "map(ARRAY['a', 'b', 'c'], ARRAY[1.0E0, 2.0E0, -1.0E0])", "map(ARRAY['c', 'b'], ARRAY[1.0E0, 3.0E0])"))
3467+
.isEqualTo(1 - ((2 * 3 + -1 * 1) / (Math.sqrt(1 + 4 + 1) * Math.sqrt(1 + 9))));
3468+
3469+
assertThat(assertions.function("cosine_distance", "map(ARRAY['a', 'b', 'c'], ARRAY[1.0E0, 2.0E0, -1.0E0])", "map(ARRAY['d', 'e'], ARRAY[1.0E0, 3.0E0])"))
3470+
.isEqualTo(1.0);
3471+
3472+
assertThat(assertions.function("cosine_distance", "null", "map(ARRAY['c', 'b'], ARRAY[1.0E0, 3.0E0])"))
3473+
.isNull();
3474+
3475+
assertThat(assertions.function("cosine_distance", "map(ARRAY['a', 'b'], ARRAY[1.0E0, null])", "map(ARRAY['c', 'b'], ARRAY[1.0E0, 3.0E0])"))
3476+
.isNull();
3477+
}
3478+
34603479
@Test
34613480
public void testInverseNormalCdf()
34623481
{

docs/src/main/sphinx/functions/math.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,16 @@ SELECT cosine_distance(ARRAY[1.0, 2.0], ARRAY[3.0, 4.0]);
205205
```
206206
:::
207207

208+
:::{function} cosine_distance(map(varchar, double), map(varchar, double)) -> double
209+
:no-index:
210+
Calculates the cosine distance between two sparse vectors:
211+
212+
```sql
213+
SELECT cosine_distance(MAP(ARRAY['a'], ARRAY[1.0]), MAP(ARRAY['a'], ARRAY[2.0]));
214+
-- 0.0
215+
```
216+
:::
217+
208218
:::{function} cosine_similarity(array(double), array(double)) -> double
209219
Calculates the cosine similarity of two dense vectors:
210220

0 commit comments

Comments
 (0)