Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import io.trino.spi.block.Block;
import io.trino.spi.function.Description;
import io.trino.spi.function.ScalarFunction;
import io.trino.spi.function.SqlNullable;
import io.trino.spi.function.SqlType;
import io.trino.spi.type.StandardTypes;

Expand Down Expand Up @@ -59,10 +60,15 @@ public static double dotProduct(@SqlType("array(double)") Block first, @SqlType(
@Description("Calculates the cosine similarity between two vectors")
@ScalarFunction
@SqlType(StandardTypes.DOUBLE)
public static double cosineSimilarity(@SqlType("array(double)") Block first, @SqlType("array(double)") Block second)
@SqlNullable
public static Double cosineSimilarity(@SqlType("array(double)") Block first, @SqlType("array(double)") Block second)
{
checkCondition(first.getPositionCount() == second.getPositionCount(), INVALID_FUNCTION_ARGUMENT, "The arguments must have the same length");

if (first.hasNull() || second.hasNull()) {
return null;
}

double firstMagnitude = 0.0;
double secondMagnitude = 0.0;
double dotProduct = 0.0;
Expand All @@ -81,8 +87,13 @@ public static double cosineSimilarity(@SqlType("array(double)") Block first, @Sq
@Description("Calculates the cosine distance between two vectors")
@ScalarFunction
@SqlType(StandardTypes.DOUBLE)
public static double cosineDistance(@SqlType("array(double)") Block first, @SqlType("array(double)") Block second)
@SqlNullable
public static Double cosineDistance(@SqlType("array(double)") Block first, @SqlType("array(double)") Block second)
{
return 1.0 - cosineSimilarity(first, second);
Double cosineSimilarity = cosineSimilarity(first, second);
if (cosineSimilarity == null) {
return null;
}
return 1.0 - cosineSimilarity;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import java.lang.reflect.Modifier;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.IntStream;

Expand Down Expand Up @@ -198,6 +199,13 @@ private <T> void assertBlockPositions(Block block, T[] expectedValues)
for (int position = 0; position < block.getPositionCount(); position++) {
assertBlockPosition(block, position, expectedValues[position]);
}
if (Arrays.stream(expectedValues).anyMatch(Objects::isNull)) {
assertThat(block.hasNull()).isTrue();
assertThat(block.mayHaveNull()).isTrue();
}
else {
assertThat(block.hasNull()).isFalse();
}
}

protected static List<Block> splitBlock(Block block, int count)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,11 @@ void testCosineDistance()
.hasType(DOUBLE)
.isEqualTo(NaN);

assertThat(assertions.function("cosine_distance", "ARRAY[1, 2]", "ARRAY[3, null]"))
.isNull(DOUBLE);
assertThat(assertions.function("cosine_distance", "ARRAY[1, null]", "ARRAY[3, 4]"))
.isNull(DOUBLE);

assertTrinoExceptionThrownBy(assertions.function("cosine_distance", "ARRAY[]", "ARRAY[]")::evaluate)
.hasMessage("Vector magnitude cannot be zero");
assertTrinoExceptionThrownBy(assertions.function("cosine_distance", "ARRAY[]", "ARRAY[1]")::evaluate)
Expand All @@ -341,4 +346,109 @@ void testCosineDistance()
assertTrinoExceptionThrownBy(assertions.function("cosine_distance", "ARRAY[1, 2]", "ARRAY[1]")::evaluate)
.hasMessage("The arguments must have the same length");
}

/**
 * Exercises the {@code cosine_similarity} SQL function: ordinary vectors, REAL and DOUBLE
 * extremes, NaN/infinity inputs, null elements, and the argument-validation errors.
 */
@Test
void testCosineSimilarity()
{
    // Ordinary vectors. Several expectations are written as "1.0 - <constant>"
    // (presumably the matching cosine_distance expectation -- confirm against testCosineDistance).
    assertThat(assertions.function("cosine_similarity", "ARRAY[1]", "ARRAY[2]"))
            .hasType(DOUBLE)
            .isEqualTo(1.0);
    assertThat(assertions.function("cosine_similarity", "ARRAY[1, 2]", "ARRAY[3, 4]"))
            .hasType(DOUBLE)
            .isEqualTo(1.0 - 0.01613008990009257);
    assertThat(assertions.function("cosine_similarity", "ARRAY[4, 5, 6]", "ARRAY[4, 5, 6]"))
            .hasType(DOUBLE)
            .isEqualTo(1.0);
    assertThat(assertions.function("cosine_similarity", "ARRAY[REAL '1.1', REAL '2.2', REAL '3.3']", "ARRAY[REAL '4.4', REAL '5.5', REAL '6.6']"))
            .hasType(DOUBLE)
            .isEqualTo(1.0 - 0.025368154060122383);
    assertThat(assertions.function("cosine_similarity", "ARRAY[DOUBLE '1.1', DOUBLE '2.2', DOUBLE '3.3']", "ARRAY[DOUBLE '4.4', DOUBLE '5.5', DOUBLE '6.6']"))
            .hasType(DOUBLE)
            .isEqualTo(1.0 - 0.025368153802923676);
    assertThat(assertions.function("cosine_similarity", "ARRAY[1.1, 2.2, 3.3]", "ARRAY[4.4, 5.5, 6.6]"))
            .hasType(DOUBLE)
            .isEqualTo(1.0 - 0.025368153802923676);

    // real type's min and max
    assertThat(assertions.function("cosine_similarity", "ARRAY[REAL '3.4028235e+38f']", "ARRAY[REAL '3.4028235e+38f']"))
            .hasType(DOUBLE)
            .isEqualTo(1.0);
    assertThat(assertions.function("cosine_similarity", "ARRAY[REAL '-3.4028235e+38f']", "ARRAY[REAL '-3.4028235e+38f']"))
            .hasType(DOUBLE)
            .isEqualTo(1.0);
    assertThat(assertions.function("cosine_similarity", "ARRAY[REAL '3.4028235e+38f']", "ARRAY[REAL '-3.4028235e+38f']"))
            .hasType(DOUBLE)
            .isEqualTo(1.0 - 2.0);
    assertThat(assertions.function("cosine_similarity", "ARRAY[REAL '-3.4028235e+38f']", "ARRAY[REAL '3.4028235e+38f']"))
            .hasType(DOUBLE)
            .isEqualTo(1.0 - 2.0);
    assertThat(assertions.function("cosine_similarity", "ARRAY[REAL '1.4E-45']", "ARRAY[REAL '1.4E-45']"))
            .hasType(DOUBLE)
            .isEqualTo(1.0);
    assertThat(assertions.function("cosine_similarity", "ARRAY[REAL '-1.4E-45']", "ARRAY[REAL '-1.4E-45']"))
            .hasType(DOUBLE)
            .isEqualTo(1.0);
    assertThat(assertions.function("cosine_similarity", "ARRAY[REAL '1.4E-45']", "ARRAY[REAL '-1.4E-45']"))
            .hasType(DOUBLE)
            .isEqualTo(1.0 - 2.0);
    assertThat(assertions.function("cosine_similarity", "ARRAY[REAL '-1.4E-45']", "ARRAY[REAL '1.4E-45']"))
            .hasType(DOUBLE)
            .isEqualTo(1.0 - 2.0);
    assertThat(assertions.function("cosine_similarity", "ARRAY[REAL '3.4028235e+38f']", "ARRAY[REAL '1.4E-45']"))
            .hasType(DOUBLE)
            .isEqualTo(1.0);
    assertThat(assertions.function("cosine_similarity", "ARRAY[REAL '1.4E-45']", "ARRAY[REAL '3.4028235e+38f']"))
            .hasType(DOUBLE)
            .isEqualTo(1.0);

    // double type's min and max
    // The positive literals use E+308 (the largest finite double); E+309 would be out of the
    // finite range and would test infinity instead, which is covered separately below.
    // The result is expected to be NaN (presumably the intermediate products overflow -- confirm).
    assertThat(assertions.function("cosine_similarity", "ARRAY[DOUBLE '1.7976931348623157E+308']", "ARRAY[DOUBLE '1.7976931348623157E+308']"))
            .hasType(DOUBLE)
            .isEqualTo(NaN);
    assertThat(assertions.function("cosine_similarity", "ARRAY[DOUBLE '-1.7976931348623157E+308']", "ARRAY[DOUBLE '-1.7976931348623157E+308']"))
            .hasType(DOUBLE)
            .isEqualTo(NaN);
    assertThat(assertions.function("cosine_similarity", "ARRAY[DOUBLE '1.7976931348623157E+308']", "ARRAY[DOUBLE '-1.7976931348623157E+308']"))
            .hasType(DOUBLE)
            .isEqualTo(NaN);
    assertThat(assertions.function("cosine_similarity", "ARRAY[DOUBLE '-1.7976931348623157E+308']", "ARRAY[DOUBLE '1.7976931348623157E+308']"))
            .hasType(DOUBLE)
            .isEqualTo(NaN);

    // NaN and infinity
    assertThat(assertions.function("cosine_similarity", "ARRAY[1]", "ARRAY[nan()]"))
            .hasType(DOUBLE)
            .isEqualTo(NaN);
    assertThat(assertions.function("cosine_similarity", "ARRAY[nan()]", "ARRAY[1]"))
            .hasType(DOUBLE)
            .isEqualTo(NaN);
    assertThat(assertions.function("cosine_similarity", "ARRAY[1]", "ARRAY[-infinity()]"))
            .hasType(DOUBLE)
            .isEqualTo(NaN);
    assertThat(assertions.function("cosine_similarity", "ARRAY[-infinity()]", "ARRAY[1]"))
            .hasType(DOUBLE)
            .isEqualTo(NaN);
    assertThat(assertions.function("cosine_similarity", "ARRAY[1]", "ARRAY[infinity()]"))
            .hasType(DOUBLE)
            .isEqualTo(NaN);
    assertThat(assertions.function("cosine_similarity", "ARRAY[infinity()]", "ARRAY[1]"))
            .hasType(DOUBLE)
            .isEqualTo(NaN);

    // A null element in either argument makes the result SQL NULL
    assertThat(assertions.function("cosine_similarity", "ARRAY[1, 2]", "ARRAY[3, null]"))
            .isNull(DOUBLE);
    assertThat(assertions.function("cosine_similarity", "ARRAY[1, null]", "ARRAY[3, 4]"))
            .isNull(DOUBLE);

    // Invalid arguments: empty vectors and mismatched lengths
    assertTrinoExceptionThrownBy(assertions.function("cosine_similarity", "ARRAY[]", "ARRAY[]")::evaluate)
            .hasMessage("Vector magnitude cannot be zero");
    assertTrinoExceptionThrownBy(assertions.function("cosine_similarity", "ARRAY[]", "ARRAY[1]")::evaluate)
            .hasMessage("The arguments must have the same length");
    assertTrinoExceptionThrownBy(assertions.function("cosine_similarity", "ARRAY[1]", "ARRAY[]")::evaluate)
            .hasMessage("The arguments must have the same length");
    assertTrinoExceptionThrownBy(assertions.function("cosine_similarity", "ARRAY[1]", "ARRAY[1, 2]")::evaluate)
            .hasMessage("The arguments must have the same length");
    assertTrinoExceptionThrownBy(assertions.function("cosine_similarity", "ARRAY[1, 2]", "ARRAY[1]")::evaluate)
            .hasMessage("The arguments must have the same length");
}
}
14 changes: 14 additions & 0 deletions core/trino-spi/src/main/java/io/trino/spi/block/ArrayBlock.java
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,20 @@ public boolean mayHaveNull()
return valueIsNull != null;
}

/**
 * Returns {@code true} when any of this block's positions is flagged in {@code valueIsNull}.
 */
@Override
public boolean hasNull()
{
    // No null-flag array at all: nothing can be marked null.
    if (valueIsNull == null) {
        return false;
    }
    // Walk the flags for this block's window, which starts at arrayOffset.
    int end = arrayOffset + positionCount;
    for (int index = arrayOffset; index < end; index++) {
        if (valueIsNull[index]) {
            return true;
        }
    }
    return false;
}

@Override
public String toString()
{
Expand Down
13 changes: 13 additions & 0 deletions core/trino-spi/src/main/java/io/trino/spi/block/Block.java
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,19 @@ default boolean mayHaveNull()
return true;
}

/**
 * Reports whether at least one position of this block is null.
 * This default implementation checks positions one by one, so it is expected to be O(N).
 *
 * @return {@code true} if {@link #isNull(int)} is true for any position, {@code false} otherwise
 */
default boolean hasNull()
{
    int position = 0;
    while (position < getPositionCount()) {
        if (isNull(position)) {
            // Stop at the first null found.
            return true;
        }
        position++;
    }
    return false;
}

/**
* Is the specified position null?
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,20 @@ public boolean mayHaveNull()
return valueIsNull != null;
}

/**
 * Returns {@code true} when any of this block's positions is flagged in {@code valueIsNull}.
 */
@Override
public boolean hasNull()
{
    if (valueIsNull != null) {
        // Flags for this block start at arrayOffset; stop at the first one that is set.
        int position = 0;
        while (position < positionCount) {
            if (valueIsNull[arrayOffset + position]) {
                return true;
            }
            position++;
        }
    }
    // Either there is no flag array or no flag in the window was set.
    return false;
}

@Override
public boolean isNull(int position)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,12 @@ public boolean mayHaveNull()
return mayHaveNull && dictionary.mayHaveNull();
}

/**
 * Returns {@code true} only when at least one position of this block resolves to a null
 * dictionary entry.
 * <p>
 * A null entry in the dictionary is not sufficient on its own: the dictionary may hold
 * entries that none of this block's positions reference, so the positions are checked
 * individually once the cheap pre-checks cannot rule nulls out.
 */
@Override
public boolean hasNull()
{
    // Fast path: nulls are impossible when this block or its dictionary has none.
    if (!mayHaveNull || !dictionary.hasNull()) {
        return false;
    }
    // The dictionary contains a null; confirm that some position actually points at it.
    for (int position = 0; position < getPositionCount(); position++) {
        if (isNull(position)) {
            return true;
        }
    }
    return false;
}

@Override
public boolean isNull(int position)
{
Expand Down
14 changes: 14 additions & 0 deletions core/trino-spi/src/main/java/io/trino/spi/block/Fixed12Block.java
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,20 @@ public boolean mayHaveNull()
return valueIsNull != null;
}

/**
 * Returns {@code true} when any of this block's positions is flagged in {@code valueIsNull}.
 */
@Override
public boolean hasNull()
{
    // A missing flag array means no position is marked null.
    if (valueIsNull == null) {
        return false;
    }
    // Inspect the flags from positionOffset up to positionOffset + positionCount.
    int limit = positionOffset + positionCount;
    for (int index = positionOffset; index < limit; index++) {
        if (valueIsNull[index]) {
            return true;
        }
    }
    return false;
}

@Override
public boolean isNull(int position)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,20 @@ public boolean mayHaveNull()
return valueIsNull != null;
}

/**
 * Returns {@code true} when any of this block's positions is flagged in {@code valueIsNull}.
 */
@Override
public boolean hasNull()
{
    if (valueIsNull != null) {
        // Flags for this block start at positionOffset; stop at the first one that is set.
        int position = 0;
        while (position < positionCount) {
            if (valueIsNull[positionOffset + position]) {
                return true;
            }
            position++;
        }
    }
    // Either there is no flag array or no flag in the window was set.
    return false;
}

@Override
public boolean isNull(int position)
{
Expand Down
14 changes: 14 additions & 0 deletions core/trino-spi/src/main/java/io/trino/spi/block/IntArrayBlock.java
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,20 @@ public boolean mayHaveNull()
return valueIsNull != null;
}

/**
 * Returns {@code true} when any of this block's positions is flagged in {@code valueIsNull}.
 */
@Override
public boolean hasNull()
{
    // No null-flag array at all: nothing can be marked null.
    if (valueIsNull == null) {
        return false;
    }
    // Check flags arrayOffset .. arrayOffset + positionCount - 1, exiting on the first hit.
    int lastExclusive = arrayOffset + positionCount;
    for (int flagIndex = arrayOffset; flagIndex < lastExclusive; flagIndex++) {
        if (valueIsNull[flagIndex]) {
            return true;
        }
    }
    return false;
}

@Override
public boolean isNull(int position)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,12 @@ public boolean mayHaveNull()
return getBlock().mayHaveNull();
}

/**
 * Answers by delegating to the block returned by {@link #getBlock()}.
 */
@Override
public boolean hasNull()
{
    Block delegate = getBlock();
    return delegate.hasNull();
}

public Block getBlock()
{
return lazyData.getTopLevelBlock();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,20 @@ public boolean mayHaveNull()
return valueIsNull != null;
}

/**
 * Returns {@code true} when any of this block's positions is flagged in {@code valueIsNull}.
 */
@Override
public boolean hasNull()
{
    // Without a flag array no position is marked null.
    if (valueIsNull == null) {
        return false;
    }
    // Scan this block's slice of the flags, which begins at arrayOffset.
    int remaining = positionCount;
    int index = arrayOffset;
    while (remaining > 0) {
        if (valueIsNull[index]) {
            return true;
        }
        index++;
        remaining--;
    }
    return false;
}

@Override
public boolean isNull(int position)
{
Expand Down
14 changes: 14 additions & 0 deletions core/trino-spi/src/main/java/io/trino/spi/block/MapBlock.java
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,20 @@ public boolean mayHaveNull()
return mapIsNull != null;
}

/**
 * Returns {@code true} when any of this block's positions is flagged in {@code mapIsNull}.
 */
@Override
public boolean hasNull()
{
    // A missing flag array means no position is marked null.
    if (mapIsNull == null) {
        return false;
    }
    // Inspect the flags from startOffset up to startOffset + positionCount.
    int end = startOffset + positionCount;
    for (int index = startOffset; index < end; index++) {
        if (mapIsNull[index]) {
            return true;
        }
    }
    return false;
}

@Override
public int getPositionCount()
{
Expand Down
14 changes: 14 additions & 0 deletions core/trino-spi/src/main/java/io/trino/spi/block/RowBlock.java
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,20 @@ public boolean mayHaveNull()
return rowIsNull != null;
}

/**
 * Returns {@code true} when any of the first {@code positionCount} entries of
 * {@code rowIsNull} is set.
 */
@Override
public boolean hasNull()
{
    if (rowIsNull != null) {
        // Flags are indexed directly by position here; stop at the first one that is set.
        int position = 0;
        while (position < positionCount) {
            if (rowIsNull[position]) {
                return true;
            }
            position++;
        }
    }
    // Either there is no flag array or none of the inspected flags was set.
    return false;
}

boolean[] getRawRowIsNull()
{
return rowIsNull;
Expand Down
Loading