diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/state/StateCompiler.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/state/StateCompiler.java index ebca1121d836..97f62c0b5dec 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/state/StateCompiler.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/state/StateCompiler.java @@ -35,7 +35,6 @@ import io.trino.array.LongBigArray; import io.trino.array.ObjectBigArray; import io.trino.array.SliceBigArray; -import io.trino.operator.aggregation.GroupedAccumulator; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; import io.trino.spi.function.AccumulatorState; @@ -371,7 +370,8 @@ public static AccumulatorStateFactory generateSt } } - DynamicClassLoader classLoader = new DynamicClassLoader(clazz.getClassLoader()); + // grouped aggregation state fields use engine classes, so generated class must be able to see both plugin and system classes + DynamicClassLoader classLoader = new DynamicClassLoader(clazz.getClassLoader(), StateCompiler.class.getClassLoader()); Class singleStateClass = generateSingleStateClass(clazz, fieldTypes, classLoader); Class groupedStateClass = generateGroupedStateClass(clazz, fieldTypes, classLoader); @@ -523,8 +523,7 @@ private static Class generateGroupedStateClass(Class clazz, a(PUBLIC, FINAL), makeClassName("Grouped" + clazz.getSimpleName()), type(AbstractGroupedAccumulatorState.class), - type(clazz), - type(GroupedAccumulator.class)); + type(clazz)); FieldDefinition instanceSize = generateInstanceSize(definition); diff --git a/core/trino-main/src/main/java/io/trino/server/PluginManager.java b/core/trino-main/src/main/java/io/trino/server/PluginManager.java index 9391d21f6b18..c357b4c13128 100644 --- a/core/trino-main/src/main/java/io/trino/server/PluginManager.java +++ b/core/trino-main/src/main/java/io/trino/server/PluginManager.java @@ -113,7 +113,7 @@ public void loadPlugins() return; } - pluginsProvider.loadPlugins(this::loadPlugin, this::createClassLoader); + pluginsProvider.loadPlugins(this::loadPlugin, PluginManager::createClassLoader); metadataManager.verifyTypes(); @@ -228,9 +228,9 @@ private void installPluginInternal(Plugin plugin, Supplier duplicat } } - private PluginClassLoader createClassLoader(List urls) + public static PluginClassLoader createClassLoader(List urls) { - ClassLoader parent = getClass().getClassLoader(); + ClassLoader parent = PluginManager.class.getClassLoader(); return new PluginClassLoader(urls, parent, SPI_PACKAGES); } diff --git a/core/trino-main/src/test/java/io/trino/operator/aggregation/LongTimestampAggregation.java b/core/trino-main/src/test/java/io/trino/operator/aggregation/LongTimestampAggregation.java new file mode 100644 index 000000000000..c1cb0f57dd62 --- /dev/null +++ b/core/trino-main/src/test/java/io/trino/operator/aggregation/LongTimestampAggregation.java @@ -0,0 +1,39 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.operator.aggregation; + +import io.trino.spi.block.BlockBuilder; +import io.trino.spi.type.LongTimestamp; + +import static io.trino.spi.type.BigintType.BIGINT; + +public final class LongTimestampAggregation +{ + private LongTimestampAggregation() {} + + public static void input(LongTimestampAggregationState state, LongTimestamp value) + { + state.setValue(state.getValue() + 1); + } + + public static void combine(LongTimestampAggregationState stateA, LongTimestampAggregationState stateB) + { + stateA.setValue(stateA.getValue() + stateB.getValue()); + } + + public static void output(LongTimestampAggregationState state, BlockBuilder blockBuilder) + { + BIGINT.writeLong(blockBuilder, state.getValue()); + } +} diff --git a/core/trino-main/src/test/java/io/trino/operator/aggregation/LongTimestampAggregationState.java b/core/trino-main/src/test/java/io/trino/operator/aggregation/LongTimestampAggregationState.java new file mode 100644 index 000000000000..0c09df8645dc --- /dev/null +++ b/core/trino-main/src/test/java/io/trino/operator/aggregation/LongTimestampAggregationState.java @@ -0,0 +1,24 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.operator.aggregation; + +import io.trino.spi.function.AccumulatorState; + +public interface LongTimestampAggregationState + extends AccumulatorState +{ + long getValue(); + + void setValue(long value); +} diff --git a/core/trino-main/src/test/java/io/trino/operator/aggregation/TestAccumulatorCompiler.java b/core/trino-main/src/test/java/io/trino/operator/aggregation/TestAccumulatorCompiler.java index e4f3a3becdc5..d65ea647cb52 100644 --- a/core/trino-main/src/test/java/io/trino/operator/aggregation/TestAccumulatorCompiler.java +++ b/core/trino-main/src/test/java/io/trino/operator/aggregation/TestAccumulatorCompiler.java @@ -14,10 +14,13 @@ package io.trino.operator.aggregation; import com.google.common.collect.ImmutableList; +import io.airlift.bytecode.DynamicClassLoader; import io.trino.metadata.BoundSignature; import io.trino.metadata.FunctionNullability; -import io.trino.operator.aggregation.TestAccumulatorCompiler.LongTimestampAggregation.State; import io.trino.operator.aggregation.state.StateCompiler; +import io.trino.server.PluginManager; +import io.trino.spi.Page; +import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; import io.trino.spi.function.AccumulatorState; import io.trino.spi.function.AccumulatorStateFactory; @@ -25,6 +28,7 @@ import io.trino.spi.type.LongTimestamp; import io.trino.spi.type.RealType; import io.trino.spi.type.TimestampType; +import io.trino.sql.gen.IsolatedClass; import org.testng.annotations.Test; import java.lang.invoke.MethodHandle; @@ -33,6 +37,8 @@ import static io.trino.operator.aggregation.AggregationFunctionAdapter.AggregationParameterKind.INPUT_CHANNEL; import static io.trino.operator.aggregation.AggregationFunctionAdapter.AggregationParameterKind.STATE; import static io.trino.operator.aggregation.AggregationFunctionAdapter.normalizeInputMethod; +import static io.trino.spi.type.BigintType.BIGINT; +import static io.trino.spi.type.TimestampType.TIMESTAMP_PICOS; import static io.trino.util.Reflection.methodHandle; import static org.assertj.core.api.Assertions.assertThat; @@ -40,19 +46,45 @@ public class TestAccumulatorCompiler { @Test public void testAccumulatorCompilerForTypeSpecificObjectParameter() + { + TimestampType parameterType = TimestampType.TIMESTAMP_NANOS; + assertThat(parameterType.getJavaType()).isEqualTo(LongTimestamp.class); + assertGenerateAccumulator(LongTimestampAggregation.class, LongTimestampAggregationState.class); + } + + @Test + public void testAccumulatorCompilerForTypeSpecificObjectParameterSeparateClassLoader() + throws Exception { TimestampType parameterType = TimestampType.TIMESTAMP_NANOS; assertThat(parameterType.getJavaType()).isEqualTo(LongTimestamp.class); - Class stateInterface = State.class; - AccumulatorStateSerializer stateSerializer = StateCompiler.generateStateSerializer(stateInterface); - AccumulatorStateFactory stateFactory = StateCompiler.generateStateFactory(stateInterface); + ClassLoader pluginClassLoader = PluginManager.createClassLoader(ImmutableList.of()); + DynamicClassLoader classLoader = new DynamicClassLoader(pluginClassLoader); + Class stateInterface = IsolatedClass.isolateClass( + classLoader, + AccumulatorState.class, + LongTimestampAggregationState.class, + LongTimestampAggregation.class); + assertThat(stateInterface.getCanonicalName()).isEqualTo(LongTimestampAggregationState.class.getCanonicalName()); + assertThat(stateInterface).isNotSameAs(LongTimestampAggregationState.class); + Class aggregation = classLoader.loadClass(LongTimestampAggregation.class.getCanonicalName()); + assertThat(aggregation.getCanonicalName()).isEqualTo(LongTimestampAggregation.class.getCanonicalName()); + assertThat(aggregation).isNotSameAs(LongTimestampAggregation.class); + + assertGenerateAccumulator(aggregation, stateInterface); + } - BoundSignature signature = new BoundSignature("longTimestampAggregation", RealType.REAL, ImmutableList.of(TimestampType.TIMESTAMP_PICOS)); - MethodHandle inputFunction = methodHandle(LongTimestampAggregation.class, "input", State.class, LongTimestamp.class); + private static void assertGenerateAccumulator(Class aggregation, Class stateInterface) + { + AccumulatorStateSerializer stateSerializer = StateCompiler.generateStateSerializer(stateInterface); + AccumulatorStateFactory stateFactory = StateCompiler.generateStateFactory(stateInterface); + + BoundSignature signature = new BoundSignature("longTimestampAggregation", RealType.REAL, ImmutableList.of(TIMESTAMP_PICOS)); + MethodHandle inputFunction = methodHandle(aggregation, "input", stateInterface, LongTimestamp.class); inputFunction = normalizeInputMethod(inputFunction, signature, STATE, INPUT_CHANNEL); - MethodHandle combineFunction = methodHandle(LongTimestampAggregation.class, "combine", State.class, State.class); - MethodHandle outputFunction = methodHandle(LongTimestampAggregation.class, "output", State.class, BlockBuilder.class); + MethodHandle combineFunction = methodHandle(aggregation, "combine", stateInterface, stateInterface); + MethodHandle outputFunction = methodHandle(aggregation, "output", stateInterface, BlockBuilder.class); AggregationMetadata metadata = new AggregationMetadata( inputFunction, Optional.empty(), @@ -65,23 +97,30 @@ public void testAccumulatorCompilerForTypeSpecificObjectParameter() FunctionNullability functionNullability = new FunctionNullability(false, ImmutableList.of(false)); // test if we can compile aggregation - assertThat(AccumulatorCompiler.generateAccumulatorFactory(signature, metadata, functionNullability, ImmutableList.of())).isNotNull(); + AccumulatorFactory accumulatorFactory = AccumulatorCompiler.generateAccumulatorFactory(signature, metadata, functionNullability, ImmutableList.of()); + assertThat(accumulatorFactory).isNotNull(); assertThat(AccumulatorCompiler.generateWindowAccumulatorClass(signature, metadata, functionNullability)).isNotNull(); - // TODO test if aggregation actually works... + TestingAggregationFunction aggregationFunction = new TestingAggregationFunction( + ImmutableList.of(TIMESTAMP_PICOS), + ImmutableList.of(BIGINT), + BIGINT, + accumulatorFactory); + assertThat(AggregationTestUtils.aggregation(aggregationFunction, createPage(1234))).isEqualTo(1234L); } - public static final class LongTimestampAggregation + private static Page createPage(int count) { - private LongTimestampAggregation() {} - - public interface State - extends AccumulatorState {} - - public static void input(State state, LongTimestamp value) {} - - public static void combine(State stateA, State stateB) {} + Block timestampSequenceBlock = createTimestampSequenceBlock(count); + return new Page(timestampSequenceBlock.getPositionCount(), timestampSequenceBlock); + } - public static void output(State state, BlockBuilder blockBuilder) {} + private static Block createTimestampSequenceBlock(int count) + { + BlockBuilder builder = TIMESTAMP_PICOS.createFixedSizeBlockBuilder(count); + for (int i = 0; i < count; i++) { + TIMESTAMP_PICOS.writeObject(builder, new LongTimestamp(i, i)); + } + return builder.build(); } } diff --git a/core/trino-main/src/test/java/io/trino/operator/aggregation/TestingAggregationFunction.java b/core/trino-main/src/test/java/io/trino/operator/aggregation/TestingAggregationFunction.java index c12115d60870..532c59584046 100644 --- a/core/trino-main/src/test/java/io/trino/operator/aggregation/TestingAggregationFunction.java +++ b/core/trino-main/src/test/java/io/trino/operator/aggregation/TestingAggregationFunction.java @@ -30,6 +30,7 @@ import static com.google.common.collect.Iterables.getOnlyElement; import static io.trino.SessionTestUtils.TEST_SESSION; import static io.trino.operator.aggregation.AccumulatorCompiler.generateAccumulatorFactory; +import static java.util.Objects.requireNonNull; public class TestingAggregationFunction { @@ -59,6 +60,21 @@ public TestingAggregationFunction(BoundSignature signature, FunctionNullability TEST_SESSION); } + public TestingAggregationFunction(List parameterTypes, List intermediateTypes, Type finalType, AccumulatorFactory factory) + { + this.parameterTypes = ImmutableList.copyOf(requireNonNull(parameterTypes, "parameterTypes is null")); + requireNonNull(intermediateTypes, "intermediateTypes is null"); + this.intermediateType = (intermediateTypes.size() == 1) ? getOnlyElement(intermediateTypes) : RowType.anonymous(intermediateTypes); + this.finalType = requireNonNull(finalType, "finalType is null"); + this.factory = requireNonNull(factory, "factory is null"); + distinctFactory = new DistinctAccumulatorFactory( + factory, + parameterTypes, + new JoinCompiler(TYPE_OPERATORS), + new BlockTypeOperators(TYPE_OPERATORS), + TEST_SESSION); + } + public int getParameterCount() { return parameterTypes.size();