Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions presto-docs/src/main/sphinx/functions/array.rst
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,20 @@ Array Functions

Returns the minimum value of input array.

.. function:: array_max_by(array(T), function(T, U)) -> T

Applies the provided function to each element, and returns the element that gives the maximum value.
``U`` can be any orderable type. ::

SELECT array_max_by(ARRAY ['a', 'bbb', 'cc'], x -> LENGTH(x)) -- 'bbb'

.. function:: array_min_by(array(T), function(T, U)) -> T

Applies the provided function to each element, and returns the element that gives the minimum value.
``U`` can be any orderable type. ::

SELECT array_min_by(ARRAY ['a', 'bbb', 'cc'], x -> LENGTH(x)) -- 'a'

.. function:: array_normalize(x, p) -> array

Normalizes array ``x`` by dividing each element by the p-norm of the array.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import com.facebook.presto.spi.function.Description;
import com.facebook.presto.spi.function.SqlInvokedScalarFunction;
import com.facebook.presto.spi.function.SqlParameter;
import com.facebook.presto.spi.function.SqlParameters;
import com.facebook.presto.spi.function.SqlType;
import com.facebook.presto.spi.function.TypeParameter;

Expand Down Expand Up @@ -59,7 +60,7 @@ public static String arrayAverage()
@TypeParameter("T")
@SqlParameter(name = "input", type = "array(T)")
@SqlType("map(T, int)")
public static String arrayFrequencyBigint()
public static String arrayFrequency()
{
return "RETURN reduce(" +
"input," +
Expand Down Expand Up @@ -89,4 +90,30 @@ public static String arrayHasDuplicatesVarchar()
{
return "RETURN cardinality(array_duplicates(input)) > 0";
}

@SqlInvokedScalarFunction(value = "array_max_by", deterministic = true, calledOnNullInput = true)
@Description("Get the maximum value of array, by using a specific transformation function")
@TypeParameter("T")
@TypeParameter("U")
@SqlParameters({@SqlParameter(name = "input", type = "array(T)"), @SqlParameter(name = "f", type = "function(T, U)")})
@SqlType("T")
public static String arrayMaxBy()
{
return "RETURN input[" +
"array_max(zip_with(transform(input, f), sequence(1, cardinality(input)), (x, y)->IF(x IS NULL, NULL, (x, y))))[2]" +
"]";
}

@SqlInvokedScalarFunction(value = "array_min_by", deterministic = true, calledOnNullInput = true)
@Description("Get the minimum value of array, by using a specific transformation function")
@TypeParameter("T")
@TypeParameter("U")
@SqlParameters({@SqlParameter(name = "input", type = "array(T)"), @SqlParameter(name = "f", type = "function(T, U)")})
@SqlType("T")
public static String arrayMinBy()
{
return "RETURN input[" +
"array_min(zip_with(transform(input, f), sequence(1, cardinality(input)), (x, y)->IF(x IS NULL, NULL, (x, y))))[2]" +
"]";
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@
import static com.facebook.presto.common.type.DoubleType.DOUBLE;
import static com.facebook.presto.common.type.IntegerType.INTEGER;
import static com.facebook.presto.common.type.VarcharType.VARCHAR;
import static com.facebook.presto.common.type.VarcharType.createVarcharType;
import static com.facebook.presto.util.StructuralTestUtil.mapType;
import static java.util.Arrays.asList;
import static java.util.Collections.singletonList;

public class TestArraySqlFunctions
Expand Down Expand Up @@ -181,4 +184,34 @@ public void testArrayDuplicates()
assertInvalidFunction("array_duplicates(array[(1, null), (null, 2), (null, 1)])", StandardErrorCode.NOT_SUPPORTED, "ROW comparison not supported for fields with null elements");
assertInvalidFunction("array_duplicates(array[(1, null), (null, 2), (null, null)])", StandardErrorCode.NOT_SUPPORTED, "map key cannot be null or contain nulls");
}

@Test
public void testArrayMaxBy()
{
assertFunction("ARRAY_MAX_BY(ARRAY [double'1.0', double'2.0'], i -> i)", DOUBLE, 2.0d);
assertFunction("ARRAY_MAX_BY(ARRAY [double'-3.0', double'2.0'], i -> i*i)", DOUBLE, -3.0d);
assertFunction("ARRAY_MAX_BY(ARRAY ['a', 'bb', 'c'], x -> LENGTH(x))", createVarcharType(2), "bb");
assertFunction("ARRAY_MAX_BY(ARRAY [1, 2, 3], x -> 1-x)", INTEGER, 1);
assertFunction("ARRAY_MAX_BY(ARRAY [ARRAY['a'], ARRAY['b', 'b'], ARRAY['c']], x -> CARDINALITY(x))", new ArrayType(createVarcharType(1)), asList("b", "b"));
assertFunction("ARRAY_MAX_BY(ARRAY [MAP(ARRAY['foo', 'bar'], ARRAY[1, 2]), MAP(ARRAY['foo', 'bar'], ARRAY[0, 3])], x -> x['foo'])", mapType(createVarcharType(3), INTEGER), ImmutableMap.of("foo", 1, "bar", 2));
assertFunction("ARRAY_MAX_BY(ARRAY [CAST(ROW(0, 2.0) AS ROW(x BIGINT, y DOUBLE)), CAST(ROW(1, 3.0) AS ROW(x BIGINT, y DOUBLE))], r -> r.y).x", BIGINT, 1L);
assertFunction("ARRAY_MAX_BY(ARRAY [null, double'1.0', double'2.0'], i -> i)", DOUBLE, null);
assertFunction("ARRAY_MAX_BY(ARRAY [cast(null as double), cast(null as double)], i -> i)", DOUBLE, null);
assertFunction("ARRAY_MAX_BY(cast(null as array(double)), i -> i)", DOUBLE, null);
}

@Test
public void testArrayMinBy()
{
assertFunction("ARRAY_MIN_BY(ARRAY [double'1.0', double'2.0'], i -> i)", DOUBLE, 1.0d);
assertFunction("ARRAY_MIN_BY(ARRAY [double'-3.0', double'2.0'], i -> i*i)", DOUBLE, 2.0d);
assertFunction("ARRAY_MIN_BY(ARRAY ['a', 'bb', 'c'], x -> LENGTH(x))", createVarcharType(2), "a");
assertFunction("ARRAY_MIN_BY(ARRAY [1, 2, 3], x -> 1-x)", INTEGER, 3);
assertFunction("ARRAY_MIN_BY(ARRAY [ARRAY['a'], ARRAY['b', 'b'], ARRAY['c']], x -> CARDINALITY(x))", new ArrayType(createVarcharType(1)), singletonList("a"));
assertFunction("ARRAY_MIN_BY(ARRAY [MAP(ARRAY['foo', 'bar'], ARRAY[1, 2]), MAP(ARRAY['foo', 'bar'], ARRAY[0, 3])], x -> x['foo'])", mapType(createVarcharType(3), INTEGER), ImmutableMap.of("foo", 0, "bar", 3));
assertFunction("ARRAY_MIN_BY(ARRAY [CAST(ROW(0, 2.0) AS ROW(x BIGINT, y DOUBLE)), CAST(ROW(1, 3.0) AS ROW(x BIGINT, y DOUBLE))], r -> r.y).x", BIGINT, 0L);
assertFunction("ARRAY_MIN_BY(ARRAY [null, double'1.0', double'2.0'], i -> i)", DOUBLE, null);
assertFunction("ARRAY_MIN_BY(ARRAY [cast(null as double), cast(null as double)], i -> i)", DOUBLE, null);
assertFunction("ARRAY_MIN_BY(cast(null as array(double)), i -> i)", DOUBLE, null);
}
}