diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxby/AbstractMinMaxBy.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxby/AbstractMinMaxBy.java index c0779be3caa4..ce2832d6a630 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxby/AbstractMinMaxBy.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxby/AbstractMinMaxBy.java @@ -28,12 +28,16 @@ import io.trino.operator.aggregation.state.BlockPositionState; import io.trino.operator.aggregation.state.BlockPositionStateSerializer; import io.trino.operator.aggregation.state.NullableBooleanState; +import io.trino.operator.aggregation.state.NullableBooleanStateSerializer; import io.trino.operator.aggregation.state.NullableDoubleState; +import io.trino.operator.aggregation.state.NullableDoubleStateSerializer; import io.trino.operator.aggregation.state.NullableLongState; +import io.trino.operator.aggregation.state.NullableLongStateSerializer; import io.trino.operator.aggregation.state.NullableState; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; import io.trino.spi.function.AccumulatorState; +import io.trino.spi.function.AccumulatorStateSerializer; import io.trino.spi.type.Type; import io.trino.spi.type.TypeSignature; import io.trino.util.MinMaxCompare; @@ -46,7 +50,6 @@ import static io.trino.metadata.Signature.orderableTypeParameter; import static io.trino.metadata.Signature.typeVariable; import static io.trino.operator.aggregation.state.StateCompiler.generateStateFactory; -import static io.trino.operator.aggregation.state.StateCompiler.generateStateSerializer; import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION; import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.NEVER_NULL; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL; @@ -126,14 +129,14 @@ private static AccumulatorStateDescriptor getAccumul new BlockPositionStateSerializer(type), generateStateFactory(BlockPositionState.class)); } - return getAccumulatorStateDescriptor(stateClass); + return getAccumulatorStateDescriptor(stateClass, type); } - private static AccumulatorStateDescriptor getAccumulatorStateDescriptor(Class stateClass) + private static AccumulatorStateDescriptor getAccumulatorStateDescriptor(Class stateClass, Type type) { return new AccumulatorStateDescriptor<>( stateClass, - generateStateSerializer(stateClass), + getStateSerializer(stateClass, type), generateStateFactory(stateClass)); } @@ -291,6 +294,24 @@ private static Class getStateClass(Type type) return BlockPositionState.class; } + @SuppressWarnings("unchecked") + private static AccumulatorStateSerializer getStateSerializer(Class state, Type type) + { + if (NullableLongState.class.equals(state)) { + return (AccumulatorStateSerializer) new NullableLongStateSerializer(type); + } + if (NullableDoubleState.class.equals(state)) { + return (AccumulatorStateSerializer) new NullableDoubleStateSerializer(type); + } + if (NullableBooleanState.class.equals(state)) { + return (AccumulatorStateSerializer) new NullableBooleanStateSerializer(type); + } + if (BlockPositionState.class.equals(state)) { + return (AccumulatorStateSerializer) new BlockPositionStateSerializer(type); + } + throw new IllegalArgumentException("Unsupported state class: " + state); + } + private static MethodHandle getSetStateValue(Type type, Class stateClass) throws ReflectiveOperationException { diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/state/NullableBooleanStateSerializer.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/state/NullableBooleanStateSerializer.java index 70f655673a47..7d8772441603 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/state/NullableBooleanStateSerializer.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/state/NullableBooleanStateSerializer.java @@ -13,20 +13,37 @@ */ package io.trino.operator.aggregation.state; +import io.trino.annotation.UsedByGeneratedCode; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; import io.trino.spi.function.AccumulatorStateSerializer; import io.trino.spi.type.Type; +import static com.google.common.base.Preconditions.checkArgument; import static io.trino.spi.type.BooleanType.BOOLEAN; +import static java.util.Objects.requireNonNull; public class NullableBooleanStateSerializer implements AccumulatorStateSerializer { + private final Type type; + + @UsedByGeneratedCode + public NullableBooleanStateSerializer() + { + this(BOOLEAN); + } + + public NullableBooleanStateSerializer(Type type) + { + this.type = requireNonNull(type, "type is null"); + checkArgument(type.getJavaType() == boolean.class, "Type must use boolean stack type: " + type); + } + @Override public Type getSerializedType() { - return BOOLEAN; + return type; } @Override diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/state/NullableDoubleStateSerializer.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/state/NullableDoubleStateSerializer.java index 62fee2cc383a..381616a6848b 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/state/NullableDoubleStateSerializer.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/state/NullableDoubleStateSerializer.java @@ -13,20 +13,37 @@ */ package io.trino.operator.aggregation.state; +import io.trino.annotation.UsedByGeneratedCode; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; import io.trino.spi.function.AccumulatorStateSerializer; import io.trino.spi.type.Type; +import static com.google.common.base.Preconditions.checkArgument; import static io.trino.spi.type.DoubleType.DOUBLE; +import static java.util.Objects.requireNonNull; public class NullableDoubleStateSerializer implements AccumulatorStateSerializer { + private final Type type; + + @UsedByGeneratedCode + public NullableDoubleStateSerializer() + { + this(DOUBLE); + } + + public NullableDoubleStateSerializer(Type type) + { + this.type = requireNonNull(type, "type is null"); + checkArgument(type.getJavaType() == double.class, "Type must use double stack type: " + type); + } + @Override public Type getSerializedType() { - return DOUBLE; + return type; } @Override @@ -36,7 +53,7 @@ public void serialize(NullableDoubleState state, BlockBuilder out) out.appendNull(); } else { - DOUBLE.writeDouble(out, state.getValue()); + type.writeDouble(out, state.getValue()); } } @@ -48,7 +65,7 @@ public void deserialize(Block block, int index, NullableDoubleState state) } else { state.setNull(false); - state.setValue(DOUBLE.getDouble(block, index)); + state.setValue(type.getDouble(block, index)); } } } diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/state/NullableLongStateSerializer.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/state/NullableLongStateSerializer.java index 0ca4681ef825..ea275305fdfb 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/state/NullableLongStateSerializer.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/state/NullableLongStateSerializer.java @@ -13,20 +13,37 @@ */ package io.trino.operator.aggregation.state; +import io.trino.annotation.UsedByGeneratedCode; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; import io.trino.spi.function.AccumulatorStateSerializer; import io.trino.spi.type.Type; +import static com.google.common.base.Preconditions.checkArgument; import static io.trino.spi.type.BigintType.BIGINT; +import static java.util.Objects.requireNonNull; public class NullableLongStateSerializer implements AccumulatorStateSerializer { + private final Type type; + + @UsedByGeneratedCode + public NullableLongStateSerializer() + { + this(BIGINT); + } + + public NullableLongStateSerializer(Type type) + { + this.type = requireNonNull(type, "type is null"); + checkArgument(type.getJavaType() == long.class, "Type must use long stack type: " + type); + } + @Override public Type getSerializedType() { - return BIGINT; + return type; } @Override @@ -36,7 +53,7 @@ public void serialize(NullableLongState state, BlockBuilder out) out.appendNull(); } else { - BIGINT.writeLong(out, state.getValue()); + type.writeLong(out, state.getValue()); } } @@ -48,7 +65,7 @@ public void deserialize(Block block, int index, NullableLongState state) } else { state.setNull(false); - state.setValue(BIGINT.getLong(block, index)); + state.setValue(type.getLong(block, index)); } } } diff --git a/testing/trino-tests/src/test/java/io/trino/tests/AbstractTestEngineOnlyQueries.java b/testing/trino-tests/src/test/java/io/trino/tests/AbstractTestEngineOnlyQueries.java index 4055c74b89f3..94595233e997 100644 --- a/testing/trino-tests/src/test/java/io/trino/tests/AbstractTestEngineOnlyQueries.java +++ b/testing/trino-tests/src/test/java/io/trino/tests/AbstractTestEngineOnlyQueries.java @@ -3409,6 +3409,9 @@ public void testAverageAll() public void testMaxBy() { assertQuery("SELECT MAX_BY(orderkey, totalprice) FROM orders", "SELECT orderkey FROM orders ORDER BY totalprice DESC LIMIT 1"); + assertQuery( + "SELECT clerk, max_by(orderstatus, shippriority) FROM orders WHERE orderstatus = 'O' GROUP BY 1", + "SELECT clerk, 'O' FROM orders GROUP BY clerk"); } @Test