From 074a6c5bc14ce6783c6c525166c372f62500b289 Mon Sep 17 00:00:00 2001 From: Raunaq Morarka Date: Wed, 22 Mar 2023 12:19:59 +0530 Subject: [PATCH] Remove redundant boolean state from LongDecimalWithOverflowAndLongState --- .../DecimalAverageAggregation.java | 7 +- .../LongDecimalWithOverflowAndLongState.java | 13 ++- ...ecimalWithOverflowAndLongStateFactory.java | 110 ++++++++++++++++-- ...malWithOverflowAndLongStateSerializer.java | 5 +- .../LongDecimalWithOverflowStateFactory.java | 14 +-- .../TestDecimalAverageAggregation.java | 3 +- ...malWithOverflowAndLongStateSerializer.java | 6 +- .../java/io/trino/array/LongBigArray.java | 5 + 8 files changed, 132 insertions(+), 31 deletions(-) diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/DecimalAverageAggregation.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/DecimalAverageAggregation.java index eadf96c8fdde..c8dc95dcdbe6 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/DecimalAverageAggregation.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/DecimalAverageAggregation.java @@ -61,8 +61,6 @@ public static void inputShortDecimal( { state.addLong(1); // row counter - state.setNotNull(); - long[] decimal = state.getDecimalArray(); int offset = state.getDecimalArrayOffset(); @@ -89,8 +87,6 @@ public static void inputLongDecimal( { state.addLong(1); // row counter - state.setNotNull(); - long[] decimal = state.getDecimalArray(); int offset = state.getDecimalArrayOffset(); @@ -119,7 +115,7 @@ public static void combine(@AggregationState LongDecimalWithOverflowAndLongState long[] otherDecimal = otherState.getDecimalArray(); int otherOffset = otherState.getDecimalArrayOffset(); - if (state.isNotNull()) { + if (state.getLong() > 0) { long overflow = addWithOverflow( decimal[offset], decimal[offset + 1], @@ -130,7 +126,6 @@ public static void combine(@AggregationState LongDecimalWithOverflowAndLongState state.addOverflow(overflow + otherState.getOverflow()); } else { - state.setNotNull(); decimal[offset] = otherDecimal[otherOffset]; decimal[offset + 1] = otherDecimal[otherOffset + 1]; state.setOverflow(otherState.getOverflow()); diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/state/LongDecimalWithOverflowAndLongState.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/state/LongDecimalWithOverflowAndLongState.java index 16ed7b7f90dc..c6d3757ed0ea 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/state/LongDecimalWithOverflowAndLongState.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/state/LongDecimalWithOverflowAndLongState.java @@ -13,15 +13,26 @@ */ package io.trino.operator.aggregation.state; +import io.trino.spi.function.AccumulatorState; import io.trino.spi.function.AccumulatorStateMetadata; @AccumulatorStateMetadata(stateFactoryClass = LongDecimalWithOverflowAndLongStateFactory.class, stateSerializerClass = LongDecimalWithOverflowAndLongStateSerializer.class) public interface LongDecimalWithOverflowAndLongState - extends LongDecimalWithOverflowState + extends AccumulatorState { long getLong(); void setLong(long value); void addLong(long value); + + long[] getDecimalArray(); + + int getDecimalArrayOffset(); + + long getOverflow(); + + void setOverflow(long overflow); + + void addOverflow(long overflow); } diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/state/LongDecimalWithOverflowAndLongStateFactory.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/state/LongDecimalWithOverflowAndLongStateFactory.java index 9ba2767c5ef8..6084ac031b66 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/state/LongDecimalWithOverflowAndLongStateFactory.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/state/LongDecimalWithOverflowAndLongStateFactory.java @@ -17,7 +17,11 @@ import io.trino.spi.function.AccumulatorState; import io.trino.spi.function.AccumulatorStateFactory; +import javax.annotation.Nullable; + import static io.airlift.slice.SizeOf.instanceSize; +import static io.airlift.slice.SizeOf.sizeOf; +import static java.lang.System.arraycopy; public class LongDecimalWithOverflowAndLongStateFactory implements AccumulatorStateFactory @@ -35,17 +39,26 @@ public LongDecimalWithOverflowAndLongState createGroupedState() } public static class GroupedLongDecimalWithOverflowAndLongState - extends LongDecimalWithOverflowStateFactory.GroupedLongDecimalWithOverflowState + extends AbstractGroupedAccumulatorState implements LongDecimalWithOverflowAndLongState { private static final int INSTANCE_SIZE = instanceSize(GroupedLongDecimalWithOverflowAndLongState.class); private final LongBigArray longs = new LongBigArray(); + /** + * Stores 128-bit decimals as pairs of longs + */ + private final LongBigArray unscaledDecimals = new LongBigArray(); + @Nullable + private LongBigArray overflows; // lazily initialized on the first overflow @Override public void ensureCapacity(long size) { longs.ensureCapacity(size); - super.ensureCapacity(size); + unscaledDecimals.ensureCapacity(size * 2); + if (overflows != null) { + overflows.ensureCapacity(size); + } } @Override @@ -66,27 +79,80 @@ public void addLong(long value) longs.add(getGroupId(), value); } + @Override + public long[] getDecimalArray() + { + return unscaledDecimals.getSegment(getGroupId() * 2); + } + + @Override + public int getDecimalArrayOffset() + { + return unscaledDecimals.getOffset(getGroupId() * 2); + } + + @Override + public long getOverflow() + { + if (overflows == null) { + return 0; + } + return overflows.get(getGroupId()); + } + + @Override + public void setOverflow(long overflow) + { + // setOverflow(0) must overwrite any existing overflow value + if (overflow == 0 && overflows == null) { + return; + } + long groupId = getGroupId(); + if (overflows == null) { + overflows = new LongBigArray(); + overflows.ensureCapacity(longs.getCapacity()); + } + overflows.set(groupId, overflow); + } + + @Override + public void addOverflow(long overflow) + { + if (overflow != 0) { + long groupId = getGroupId(); + if (overflows == null) { + overflows = new LongBigArray(); + overflows.ensureCapacity(longs.getCapacity()); + } + overflows.add(groupId, overflow); + } + } + @Override public long getEstimatedSize() { - return INSTANCE_SIZE + longs.sizeOf() + isNotNull.sizeOf() + unscaledDecimals.sizeOf() + (overflows == null ? 0 : overflows.sizeOf()); + return INSTANCE_SIZE + longs.sizeOf() + unscaledDecimals.sizeOf() + (overflows == null ? 0 : overflows.sizeOf()); } } public static class SingleLongDecimalWithOverflowAndLongState - extends LongDecimalWithOverflowStateFactory.SingleLongDecimalWithOverflowState implements LongDecimalWithOverflowAndLongState { private static final int INSTANCE_SIZE = instanceSize(SingleLongDecimalWithOverflowAndLongState.class); + private static final int SIZE = (int) sizeOf(new long[2]); - protected long longValue; + private final long[] unscaledDecimal = new long[2]; + private long longValue; + private long overflow; public SingleLongDecimalWithOverflowAndLongState() {} // for copying - private SingleLongDecimalWithOverflowAndLongState(long longValue) + private SingleLongDecimalWithOverflowAndLongState(long[] unscaledDecimal, long longValue, long overflow) { + arraycopy(unscaledDecimal, 0, this.unscaledDecimal, 0, 2); this.longValue = longValue; + this.overflow = overflow; } @Override @@ -107,6 +173,36 @@ public void addLong(long value) longValue += value; } + @Override + public long[] getDecimalArray() + { + return unscaledDecimal; + } + + @Override + public int getDecimalArrayOffset() + { + return 0; + } + + @Override + public long getOverflow() + { + return overflow; + } + + @Override + public void setOverflow(long overflow) + { + this.overflow = overflow; + } + + @Override + public void addOverflow(long overflow) + { + this.overflow += overflow; + } + @Override public long getEstimatedSize() { @@ -116,7 +212,7 @@ public long getEstimatedSize() @Override public AccumulatorState copy() { - return new SingleLongDecimalWithOverflowAndLongState(longValue); + return new SingleLongDecimalWithOverflowAndLongState(unscaledDecimal, longValue, overflow); } } } diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/state/LongDecimalWithOverflowAndLongStateSerializer.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/state/LongDecimalWithOverflowAndLongStateSerializer.java index 7556b4a207a8..2fb2579bed1d 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/state/LongDecimalWithOverflowAndLongStateSerializer.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/state/LongDecimalWithOverflowAndLongStateSerializer.java @@ -34,8 +34,8 @@ public Type getSerializedType() @Override public void serialize(LongDecimalWithOverflowAndLongState state, BlockBuilder out) { - if (state.isNotNull()) { - long count = state.getLong(); + long count = state.getLong(); + if (count > 0) { long overflow = state.getOverflow(); long[] decimal = state.getDecimalArray(); int offset = state.getDecimalArrayOffset(); @@ -97,7 +97,6 @@ public void deserialize(Block block, int index, LongDecimalWithOverflowAndLongSt decimal[offset] = high; state.setOverflow(overflow); state.setLong(count); - state.setNotNull(); } } } diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/state/LongDecimalWithOverflowStateFactory.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/state/LongDecimalWithOverflowStateFactory.java index 0bee25f2f574..9bc449cec9f0 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/state/LongDecimalWithOverflowStateFactory.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/state/LongDecimalWithOverflowStateFactory.java @@ -44,13 +44,13 @@ public static class GroupedLongDecimalWithOverflowState implements LongDecimalWithOverflowState { private static final int INSTANCE_SIZE = instanceSize(GroupedLongDecimalWithOverflowState.class); - protected final BooleanBigArray isNotNull = new BooleanBigArray(); + private final BooleanBigArray isNotNull = new BooleanBigArray(); /** * Stores 128-bit decimals as pairs of longs */ - protected final LongBigArray unscaledDecimals = new LongBigArray(); + private final LongBigArray unscaledDecimals = new LongBigArray(); @Nullable - protected LongBigArray overflows; // lazily initialized on the first overflow + private LongBigArray overflows; // lazily initialized on the first overflow @Override public void ensureCapacity(long size) @@ -134,11 +134,11 @@ public static class SingleLongDecimalWithOverflowState implements LongDecimalWithOverflowState { private static final int INSTANCE_SIZE = instanceSize(SingleLongDecimalWithOverflowState.class); - protected static final int SIZE = (int) sizeOf(new long[2]); + private static final int SIZE = (int) sizeOf(new long[2]); - protected final long[] unscaledDecimal = new long[2]; - protected boolean isNotNull; - protected long overflow; + private final long[] unscaledDecimal = new long[2]; + private boolean isNotNull; + private long overflow; public SingleLongDecimalWithOverflowState() {} diff --git a/core/trino-main/src/test/java/io/trino/operator/aggregation/TestDecimalAverageAggregation.java b/core/trino-main/src/test/java/io/trino/operator/aggregation/TestDecimalAverageAggregation.java index f09d2a12922e..d43607064b08 100644 --- a/core/trino-main/src/test/java/io/trino/operator/aggregation/TestDecimalAverageAggregation.java +++ b/core/trino-main/src/test/java/io/trino/operator/aggregation/TestDecimalAverageAggregation.java @@ -16,7 +16,6 @@ import com.google.common.collect.ImmutableList; import io.trino.operator.aggregation.state.LongDecimalWithOverflowAndLongState; import io.trino.operator.aggregation.state.LongDecimalWithOverflowAndLongStateFactory; -import io.trino.operator.aggregation.state.LongDecimalWithOverflowState; import io.trino.spi.block.BlockBuilder; import io.trino.spi.type.DecimalType; import io.trino.spi.type.Decimals; @@ -238,7 +237,7 @@ private static void addToState(DecimalType type, LongDecimalWithOverflowAndLongS } } - private Int128 getDecimal(LongDecimalWithOverflowState state) + private Int128 getDecimal(LongDecimalWithOverflowAndLongState state) { long[] decimal = state.getDecimalArray(); int offset = state.getDecimalArrayOffset(); diff --git a/core/trino-main/src/test/java/io/trino/operator/aggregation/state/TestLongDecimalWithOverflowAndLongStateSerializer.java b/core/trino-main/src/test/java/io/trino/operator/aggregation/state/TestLongDecimalWithOverflowAndLongStateSerializer.java index 31949aec1559..1422d41d6cdc 100644 --- a/core/trino-main/src/test/java/io/trino/operator/aggregation/state/TestLongDecimalWithOverflowAndLongStateSerializer.java +++ b/core/trino-main/src/test/java/io/trino/operator/aggregation/state/TestLongDecimalWithOverflowAndLongStateSerializer.java @@ -20,8 +20,6 @@ import org.testng.annotations.Test; import static org.testng.Assert.assertEquals; -import static org.testng.Assert.assertFalse; -import static org.testng.Assert.assertTrue; public class TestLongDecimalWithOverflowAndLongStateSerializer { @@ -35,11 +33,9 @@ public void testSerde(long low, long high, long overflow, long count, int expect state.getDecimalArray()[1] = low; state.setOverflow(overflow); state.setLong(count); - state.setNotNull(); LongDecimalWithOverflowAndLongState outState = roundTrip(state, expectedLength); - assertTrue(outState.isNotNull()); assertEquals(outState.getDecimalArray()[0], high); assertEquals(outState.getDecimalArray()[1], low); assertEquals(outState.getOverflow(), overflow); @@ -54,7 +50,7 @@ public void testNullSerde() LongDecimalWithOverflowAndLongState outState = roundTrip(state, 0); - assertFalse(outState.isNotNull()); + assertEquals(outState.getLong(), 0); } private LongDecimalWithOverflowAndLongState roundTrip(LongDecimalWithOverflowAndLongState state, int expectedLength) diff --git a/lib/trino-array/src/main/java/io/trino/array/LongBigArray.java b/lib/trino-array/src/main/java/io/trino/array/LongBigArray.java index d98a9c391f60..1662de05e071 100644 --- a/lib/trino-array/src/main/java/io/trino/array/LongBigArray.java +++ b/lib/trino-array/src/main/java/io/trino/array/LongBigArray.java @@ -128,6 +128,11 @@ public void ensureCapacity(long length) grow(length); } + public long getCapacity() + { + return capacity; + } + /** * Copies this array, beginning at the specified sourceIndex, to the specified destinationIndex of * the destination array. A subsequence of this array's components are copied to the destination