Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,8 @@
import static io.prestosql.spi.function.InvocationConvention.simpleConvention;
import static io.prestosql.spi.function.OperatorType.COMPARISON;
import static io.prestosql.util.Failures.internalError;
import static io.prestosql.util.MinMaxCompare.getMinMaxCompare;
import static io.prestosql.util.Reflection.methodHandle;
import static java.lang.invoke.MethodHandles.filterReturnValue;

public abstract class AbstractMinMaxAggregationFunction
extends SqlAggregationFunction
Expand All @@ -79,10 +79,7 @@ public abstract class AbstractMinMaxAggregationFunction
private static final MethodHandle BOOLEAN_COMBINE_FUNCTION = methodHandle(AbstractMinMaxAggregationFunction.class, "combine", MethodHandle.class, NullableBooleanState.class, NullableBooleanState.class);
private static final MethodHandle BLOCK_POSITION_COMBINE_FUNCTION = methodHandle(AbstractMinMaxAggregationFunction.class, "combine", MethodHandle.class, BlockPositionState.class, BlockPositionState.class);

private static final MethodHandle MIN_FUNCTION = methodHandle(AbstractMinMaxAggregationFunction.class, "min", long.class);
private static final MethodHandle MAX_FUNCTION = methodHandle(AbstractMinMaxAggregationFunction.class, "max", long.class);

private final MethodHandle comparisonResultAdapter;
private final boolean min;

protected AbstractMinMaxAggregationFunction(String name, boolean min, String description)
{
Expand All @@ -103,7 +100,7 @@ protected AbstractMinMaxAggregationFunction(String name, boolean min, String des
AGGREGATE),
true,
false);
this.comparisonResultAdapter = min ? MIN_FUNCTION : MAX_FUNCTION;
this.min = min;
}

@Override
Expand Down Expand Up @@ -142,8 +139,9 @@ public InternalAggregationFunction specialize(FunctionBinding functionBinding, F
else {
invocationConvention = simpleConvention(FAIL_ON_NULL, BLOCK_POSITION, BLOCK_POSITION);
}
MethodHandle compareMethodHandle = functionDependencies.getOperatorInvoker(COMPARISON, ImmutableList.of(type, type), Optional.of(invocationConvention)).getMethodHandle();
compareMethodHandle = filterReturnValue(compareMethodHandle, comparisonResultAdapter);

MethodHandle compareMethodHandle = getMinMaxCompare(functionDependencies, type, Optional.of(invocationConvention), min);

return generateAggregation(type, compareMethodHandle);
}

Expand Down Expand Up @@ -341,16 +339,4 @@ private static void compareAndUpdateState(MethodHandle methodHandle, BlockPositi
throw internalError(t);
}
}

@UsedByGeneratedCode
public static boolean min(long comparisonResult)
{
return comparisonResult < 0;
}

@UsedByGeneratedCode
public static boolean max(long comparisonResult)
{
return comparisonResult > 0;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,17 @@

import io.prestosql.type.BlockTypeOperators;

import static io.prestosql.util.MinMaxCompare.getMaxCompare;

public class MaxNAggregationFunction
extends AbstractMinMaxNAggregationFunction
{
private static final String NAME = "max";

public MaxNAggregationFunction(BlockTypeOperators blockTypeOperators)
{
super(NAME, blockTypeOperators::getComparisonOperator, "Returns the maximum values of the argument");
super(NAME,
type -> getMaxCompare(blockTypeOperators, type),
"Returns the maximum values of the argument");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
import io.airlift.bytecode.Parameter;
import io.airlift.bytecode.control.IfStatement;
import io.airlift.bytecode.expression.BytecodeExpression;
import io.prestosql.annotation.UsedByGeneratedCode;
import io.prestosql.metadata.FunctionArgumentDefinition;
import io.prestosql.metadata.FunctionBinding;
import io.prestosql.metadata.FunctionDependencies;
Expand All @@ -46,6 +45,7 @@
import io.prestosql.spi.type.TypeSignature;
import io.prestosql.sql.gen.CallSiteBinder;
import io.prestosql.sql.gen.SqlTypeBytecodeExpression;
import io.prestosql.util.MinMaxCompare;

import java.lang.invoke.MethodHandle;
import java.lang.reflect.Method;
Expand Down Expand Up @@ -83,16 +83,12 @@
import static io.prestosql.util.CompilerUtils.defineClass;
import static io.prestosql.util.CompilerUtils.makeClassName;
import static io.prestosql.util.Reflection.methodHandle;
import static java.lang.invoke.MethodHandles.filterReturnValue;
import static java.util.Arrays.stream;

public abstract class AbstractMinMaxBy
extends SqlAggregationFunction
{
private static final MethodHandle MIN_FUNCTION = methodHandle(AbstractMinMaxBy.class, "min", long.class);
private static final MethodHandle MAX_FUNCTION = methodHandle(AbstractMinMaxBy.class, "max", long.class);

private final MethodHandle comparisonResultAdapter;
private final boolean min;

protected AbstractMinMaxBy(boolean min, String description)
{
Expand All @@ -115,7 +111,7 @@ protected AbstractMinMaxBy(boolean min, String description)
AGGREGATE),
true,
false);
this.comparisonResultAdapter = min ? MIN_FUNCTION : MAX_FUNCTION;
this.min = min;
}

@Override
Expand Down Expand Up @@ -182,8 +178,7 @@ private InternalAggregationFunction generateAggregation(Type valueType, Type key
List<Type> inputTypes = ImmutableList.of(valueType, keyType);

CallSiteBinder binder = new CallSiteBinder();
MethodHandle compareMethod = functionDependencies.getOperatorInvoker(COMPARISON, ImmutableList.of(keyType, keyType), Optional.empty()).getMethodHandle();
compareMethod = filterReturnValue(compareMethod, comparisonResultAdapter);
MethodHandle compareMethod = MinMaxCompare.getMinMaxCompare(functionDependencies, keyType, Optional.empty(), min);

ClassDefinition definition = new ClassDefinition(
a(PUBLIC, FINAL),
Expand Down Expand Up @@ -339,16 +334,4 @@ private static Method getMethod(Class<?> stateClass, String name)
.findFirst()
.orElseThrow(() -> new IllegalArgumentException("State class does not have a method named " + name));
}

@UsedByGeneratedCode
public static boolean min(long comparisonResult)
{
return comparisonResult < 0;
}

@UsedByGeneratedCode
public static boolean max(long comparisonResult)
{
return comparisonResult > 0;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,17 @@

import io.prestosql.type.BlockTypeOperators;

import static io.prestosql.util.MinMaxCompare.getMaxCompare;

public class MaxByNAggregationFunction
extends AbstractMinMaxByNAggregationFunction
{
private static final String NAME = "max_by";

public MaxByNAggregationFunction(BlockTypeOperators blockTypeOperators)
{
super(NAME, blockTypeOperators::getComparisonOperator, "Returns the values of the first argument associated with the maximum values of the second argument");
super(NAME,
type -> getMaxCompare(blockTypeOperators, type),
"Returns the values of the first argument associated with the maximum values of the second argument");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
import io.airlift.bytecode.control.IfStatement;
import io.airlift.bytecode.expression.BytecodeExpression;
import io.airlift.bytecode.instruction.LabelNode;
import io.prestosql.annotation.UsedByGeneratedCode;
import io.prestosql.metadata.FunctionArgumentDefinition;
import io.prestosql.metadata.FunctionBinding;
import io.prestosql.metadata.FunctionDependencies;
Expand Down Expand Up @@ -65,19 +64,16 @@
import static io.prestosql.util.CompilerUtils.defineClass;
import static io.prestosql.util.CompilerUtils.makeClassName;
import static io.prestosql.util.Failures.checkCondition;
import static io.prestosql.util.MinMaxCompare.getMinMaxCompare;
import static io.prestosql.util.Reflection.methodHandle;
import static java.lang.invoke.MethodHandles.filterReturnValue;
import static java.lang.invoke.MethodType.methodType;
import static java.util.Collections.nCopies;
import static java.util.stream.Collectors.joining;

public abstract class AbstractGreatestLeast
extends SqlScalarFunction
{
private static final MethodHandle MIN_FUNCTION = methodHandle(AbstractGreatestLeast.class, "min", long.class);
private static final MethodHandle MAX_FUNCTION = methodHandle(AbstractGreatestLeast.class, "max", long.class);

private final MethodHandle comparisonResultAdapter;
private final boolean min;

protected AbstractGreatestLeast(boolean min, String description)
{
Expand All @@ -95,7 +91,7 @@ protected AbstractGreatestLeast(boolean min, String description)
true,
description,
SCALAR));
this.comparisonResultAdapter = min ? MIN_FUNCTION : MAX_FUNCTION;
this.min = min;
}

@Override
Expand All @@ -112,8 +108,7 @@ public ScalarFunctionImplementation specialize(FunctionBinding functionBinding,
Type type = functionBinding.getTypeVariable("E");
checkArgument(type.isOrderable(), "Type must be orderable");

MethodHandle compareMethod = functionDependencies.getOperatorInvoker(COMPARISON, ImmutableList.of(type, type), Optional.empty()).getMethodHandle();
compareMethod = filterReturnValue(compareMethod, comparisonResultAdapter);
MethodHandle compareMethod = getMinMaxCompare(functionDependencies, type, Optional.empty(), min);

List<Class<?>> javaTypes = IntStream.range(0, functionBinding.getArity())
.mapToObj(i -> wrap(type.getJavaType()))
Expand Down Expand Up @@ -192,16 +187,4 @@ private Class<?> generate(List<Class<?>> javaTypes, MethodHandle compareMethod)

return defineClass(definition, Object.class, binder.getBindings(), new DynamicClassLoader(getClass().getClassLoader()));
}

@UsedByGeneratedCode
public static boolean min(long comparisonResult)
{
return comparisonResult < 0;
}

@UsedByGeneratedCode
public static boolean max(long comparisonResult)
{
return comparisonResult > 0;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,10 @@
import static io.prestosql.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION;
import static io.prestosql.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL;
import static io.prestosql.spi.function.OperatorType.COMPARISON;
import static io.prestosql.spi.type.DoubleType.DOUBLE;
import static io.prestosql.spi.type.RealType.REAL;
import static io.prestosql.util.Failures.internalError;
import static java.lang.Float.intBitsToFloat;

@ScalarFunction("array_max")
@Description("Get maximum value of array")
Expand Down Expand Up @@ -126,4 +129,58 @@ private static int findMaxArrayElement(MethodHandle compareMethodHandle, Block b
throw internalError(t);
}
}

@SqlType("double")
@SqlNullable
public static Double doubleTypeArrayMax(@SqlType("array(double)") Block block)
{
if (block.getPositionCount() == 0) {
return null;
}
int selectedPosition = -1;
for (int position = 0; position < block.getPositionCount(); position++) {
if (block.isNull(position)) {
return null;
}
if (selectedPosition < 0 || doubleGreater(DOUBLE.getDouble(block, position), DOUBLE.getDouble(block, selectedPosition))) {
selectedPosition = position;
}
}
return DOUBLE.getDouble(block, selectedPosition);
}

private static boolean doubleGreater(double left, double right)
{
return (left > right) || Double.isNaN(right);
}

@SqlType("real")
@SqlNullable
public static Long realTypeArrayMax(@SqlType("array(real)") Block block)
{
if (block.getPositionCount() == 0) {
return null;
}
int selectedPosition = -1;
for (int position = 0; position < block.getPositionCount(); position++) {
if (block.isNull(position)) {
return null;
}
if (selectedPosition < 0 || floatGreater(getReal(block, position), getReal(block, selectedPosition))) {
selectedPosition = position;
}
}
return REAL.getLong(block, selectedPosition);
}

@SuppressWarnings("NumericCastThatLosesPrecision")
private static float getReal(Block block, int position)
{
return intBitsToFloat((int) REAL.getLong(block, position));
}

private static boolean floatGreater(float left, float right)
{
return (left > right) || Float.isNaN(right);
}
}
Loading