Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
import io.trino.array.LongBigArray;
import io.trino.array.ObjectBigArray;
import io.trino.array.SliceBigArray;
import io.trino.operator.aggregation.GroupedAccumulator;
import io.trino.spi.block.Block;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.function.AccumulatorState;
Expand Down Expand Up @@ -371,7 +370,8 @@ public static <T extends AccumulatorState> AccumulatorStateFactory<T> generateSt
}
}

DynamicClassLoader classLoader = new DynamicClassLoader(clazz.getClassLoader());
// grouped aggregation state fields use engine classes, so generated class must be able to see both plugin and system classes
DynamicClassLoader classLoader = new DynamicClassLoader(clazz.getClassLoader(), StateCompiler.class.getClassLoader());
Class<? extends T> singleStateClass = generateSingleStateClass(clazz, fieldTypes, classLoader);
Class<? extends T> groupedStateClass = generateGroupedStateClass(clazz, fieldTypes, classLoader);

Expand Down Expand Up @@ -523,8 +523,7 @@ private static <T> Class<? extends T> generateGroupedStateClass(Class<T> clazz,
a(PUBLIC, FINAL),
makeClassName("Grouped" + clazz.getSimpleName()),
type(AbstractGroupedAccumulatorState.class),
type(clazz),
type(GroupedAccumulator.class));
type(clazz));

FieldDefinition instanceSize = generateInstanceSize(definition);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ public void loadPlugins()
return;
}

pluginsProvider.loadPlugins(this::loadPlugin, this::createClassLoader);
pluginsProvider.loadPlugins(this::loadPlugin, PluginManager::createClassLoader);

metadataManager.verifyTypes();

Expand Down Expand Up @@ -228,9 +228,9 @@ private void installPluginInternal(Plugin plugin, Supplier<ClassLoader> duplicat
}
}

private PluginClassLoader createClassLoader(List<URL> urls)
public static PluginClassLoader createClassLoader(List<URL> urls)
{
ClassLoader parent = getClass().getClassLoader();
ClassLoader parent = PluginManager.class.getClassLoader();
return new PluginClassLoader(urls, parent, SPI_PACKAGES);
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.operator.aggregation;

import io.trino.spi.block.BlockBuilder;
import io.trino.spi.type.LongTimestamp;

import static io.trino.spi.type.BigintType.BIGINT;

public final class LongTimestampAggregation
{
private LongTimestampAggregation() {}

public static void input(LongTimestampAggregationState state, LongTimestamp value)
{
state.setValue(state.getValue() + 1);
}

public static void combine(LongTimestampAggregationState stateA, LongTimestampAggregationState stateB)
{
stateA.setValue(stateA.getValue() + stateB.getValue());
}

public static void output(LongTimestampAggregationState state, BlockBuilder blockBuilder)
{
BIGINT.writeLong(blockBuilder, state.getValue());
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.operator.aggregation;

import io.trino.spi.function.AccumulatorState;

public interface LongTimestampAggregationState
extends AccumulatorState
{
long getValue();

void setValue(long value);
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,21 @@
package io.trino.operator.aggregation;

import com.google.common.collect.ImmutableList;
import io.airlift.bytecode.DynamicClassLoader;
import io.trino.metadata.BoundSignature;
import io.trino.metadata.FunctionNullability;
import io.trino.operator.aggregation.TestAccumulatorCompiler.LongTimestampAggregation.State;
import io.trino.operator.aggregation.state.StateCompiler;
import io.trino.server.PluginManager;
import io.trino.spi.Page;
import io.trino.spi.block.Block;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.function.AccumulatorState;
import io.trino.spi.function.AccumulatorStateFactory;
import io.trino.spi.function.AccumulatorStateSerializer;
import io.trino.spi.type.LongTimestamp;
import io.trino.spi.type.RealType;
import io.trino.spi.type.TimestampType;
import io.trino.sql.gen.IsolatedClass;
import org.testng.annotations.Test;

import java.lang.invoke.MethodHandle;
Expand All @@ -33,26 +37,54 @@
import static io.trino.operator.aggregation.AggregationFunctionAdapter.AggregationParameterKind.INPUT_CHANNEL;
import static io.trino.operator.aggregation.AggregationFunctionAdapter.AggregationParameterKind.STATE;
import static io.trino.operator.aggregation.AggregationFunctionAdapter.normalizeInputMethod;
import static io.trino.spi.type.BigintType.BIGINT;
import static io.trino.spi.type.TimestampType.TIMESTAMP_PICOS;
import static io.trino.util.Reflection.methodHandle;
import static org.assertj.core.api.Assertions.assertThat;

public class TestAccumulatorCompiler
{
@Test
public void testAccumulatorCompilerForTypeSpecificObjectParameter()
{
TimestampType parameterType = TimestampType.TIMESTAMP_NANOS;
assertThat(parameterType.getJavaType()).isEqualTo(LongTimestamp.class);
assertGenerateAccumulator(LongTimestampAggregation.class, LongTimestampAggregationState.class);
}

@Test
public void testAccumulatorCompilerForTypeSpecificObjectParameterSeparateClassLoader()
throws Exception
{
TimestampType parameterType = TimestampType.TIMESTAMP_NANOS;
assertThat(parameterType.getJavaType()).isEqualTo(LongTimestamp.class);

Class<State> stateInterface = State.class;
AccumulatorStateSerializer<State> stateSerializer = StateCompiler.generateStateSerializer(stateInterface);
AccumulatorStateFactory<State> stateFactory = StateCompiler.generateStateFactory(stateInterface);
ClassLoader pluginClassLoader = PluginManager.createClassLoader(ImmutableList.of());
DynamicClassLoader classLoader = new DynamicClassLoader(pluginClassLoader);
Class<? extends AccumulatorState> stateInterface = IsolatedClass.isolateClass(
classLoader,
AccumulatorState.class,
LongTimestampAggregationState.class,
LongTimestampAggregation.class);
assertThat(stateInterface.getCanonicalName()).isEqualTo(LongTimestampAggregationState.class.getCanonicalName());
assertThat(stateInterface).isNotSameAs(LongTimestampAggregationState.class);
Class<?> aggregation = classLoader.loadClass(LongTimestampAggregation.class.getCanonicalName());
assertThat(aggregation.getCanonicalName()).isEqualTo(LongTimestampAggregation.class.getCanonicalName());
assertThat(aggregation).isNotSameAs(LongTimestampAggregation.class);

assertGenerateAccumulator(aggregation, stateInterface);
}

BoundSignature signature = new BoundSignature("longTimestampAggregation", RealType.REAL, ImmutableList.of(TimestampType.TIMESTAMP_PICOS));
MethodHandle inputFunction = methodHandle(LongTimestampAggregation.class, "input", State.class, LongTimestamp.class);
private static <S extends AccumulatorState, A> void assertGenerateAccumulator(Class<A> aggregation, Class<S> stateInterface)
{
AccumulatorStateSerializer<S> stateSerializer = StateCompiler.generateStateSerializer(stateInterface);
AccumulatorStateFactory<S> stateFactory = StateCompiler.generateStateFactory(stateInterface);

BoundSignature signature = new BoundSignature("longTimestampAggregation", RealType.REAL, ImmutableList.of(TIMESTAMP_PICOS));
MethodHandle inputFunction = methodHandle(aggregation, "input", stateInterface, LongTimestamp.class);
inputFunction = normalizeInputMethod(inputFunction, signature, STATE, INPUT_CHANNEL);
MethodHandle combineFunction = methodHandle(LongTimestampAggregation.class, "combine", State.class, State.class);
MethodHandle outputFunction = methodHandle(LongTimestampAggregation.class, "output", State.class, BlockBuilder.class);
MethodHandle combineFunction = methodHandle(aggregation, "combine", stateInterface, stateInterface);
MethodHandle outputFunction = methodHandle(aggregation, "output", stateInterface, BlockBuilder.class);
AggregationMetadata metadata = new AggregationMetadata(
inputFunction,
Optional.empty(),
Expand All @@ -65,23 +97,30 @@ public void testAccumulatorCompilerForTypeSpecificObjectParameter()
FunctionNullability functionNullability = new FunctionNullability(false, ImmutableList.of(false));

// test if we can compile aggregation
assertThat(AccumulatorCompiler.generateAccumulatorFactory(signature, metadata, functionNullability, ImmutableList.of())).isNotNull();
AccumulatorFactory accumulatorFactory = AccumulatorCompiler.generateAccumulatorFactory(signature, metadata, functionNullability, ImmutableList.of());
assertThat(accumulatorFactory).isNotNull();
assertThat(AccumulatorCompiler.generateWindowAccumulatorClass(signature, metadata, functionNullability)).isNotNull();

// TODO test if aggregation actually works...
TestingAggregationFunction aggregationFunction = new TestingAggregationFunction(
ImmutableList.of(TIMESTAMP_PICOS),
ImmutableList.of(BIGINT),
BIGINT,
accumulatorFactory);
assertThat(AggregationTestUtils.aggregation(aggregationFunction, createPage(1234))).isEqualTo(1234L);
}

public static final class LongTimestampAggregation
private static Page createPage(int count)
{
private LongTimestampAggregation() {}

public interface State
extends AccumulatorState {}

public static void input(State state, LongTimestamp value) {}

public static void combine(State stateA, State stateB) {}
Block timestampSequenceBlock = createTimestampSequenceBlock(count);
return new Page(timestampSequenceBlock.getPositionCount(), timestampSequenceBlock);
}

public static void output(State state, BlockBuilder blockBuilder) {}
private static Block createTimestampSequenceBlock(int count)
{
BlockBuilder builder = TIMESTAMP_PICOS.createFixedSizeBlockBuilder(count);
for (int i = 0; i < count; i++) {
TIMESTAMP_PICOS.writeObject(builder, new LongTimestamp(i, i));
}
return builder.build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import static com.google.common.collect.Iterables.getOnlyElement;
import static io.trino.SessionTestUtils.TEST_SESSION;
import static io.trino.operator.aggregation.AccumulatorCompiler.generateAccumulatorFactory;
import static java.util.Objects.requireNonNull;

public class TestingAggregationFunction
{
Expand Down Expand Up @@ -59,6 +60,21 @@ public TestingAggregationFunction(BoundSignature signature, FunctionNullability
TEST_SESSION);
}

public TestingAggregationFunction(List<Type> parameterTypes, List<Type> intermediateTypes, Type finalType, AccumulatorFactory factory)
{
this.parameterTypes = ImmutableList.copyOf(requireNonNull(parameterTypes, "parameterTypes is null"));
requireNonNull(intermediateTypes, "intermediateTypes is null");
this.intermediateType = (intermediateTypes.size() == 1) ? getOnlyElement(intermediateTypes) : RowType.anonymous(intermediateTypes);
this.finalType = requireNonNull(finalType, "finalType is null");
this.factory = requireNonNull(factory, "factory is null");
distinctFactory = new DistinctAccumulatorFactory(
factory,
parameterTypes,
new JoinCompiler(TYPE_OPERATORS),
new BlockTypeOperators(TYPE_OPERATORS),
TEST_SESSION);
}

public int getParameterCount()
{
return parameterTypes.size();
Expand Down