Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ public final class SystemSessionProperties
public static final String ADAPTIVE_PARTIAL_AGGREGATION_ENABLED = "adaptive_partial_aggregation_enabled";
public static final String ADAPTIVE_PARTIAL_AGGREGATION_MIN_ROWS = "adaptive_partial_aggregation_min_rows";
public static final String ADAPTIVE_PARTIAL_AGGREGATION_UNIQUE_ROWS_RATIO_THRESHOLD = "adaptive_partial_aggregation_unique_rows_ratio_threshold";
public static final String JOIN_PARTITIONED_BUILD_MIN_ROW_COUNT = "join_partitioned_build_min_row_count";

private final List<PropertyMetadata<?>> sessionProperties;

Expand Down Expand Up @@ -823,6 +824,12 @@ public SystemSessionProperties(
ADAPTIVE_PARTIAL_AGGREGATION_UNIQUE_ROWS_RATIO_THRESHOLD,
"Ratio between aggregation output and input rows above which partial aggregation might be adaptively turned off",
optimizerConfig.getAdaptivePartialAggregationUniqueRowsRatioThreshold(),
false),
longProperty(
JOIN_PARTITIONED_BUILD_MIN_ROW_COUNT,
"Minimum number of join build side rows required to use partitioned join lookup",
optimizerConfig.getJoinPartitionedBuildMinRowCount(),
value -> validateNonNegativeLongValue(value, JOIN_PARTITIONED_BUILD_MIN_ROW_COUNT),
false));
}

Expand Down Expand Up @@ -1204,6 +1211,13 @@ private static Integer validateIntegerValue(Object value, String property, int l
return intValue;
}

private static void validateNonNegativeLongValue(Long value, String property)
{
if (value < 0) {
throw new TrinoException(INVALID_SESSION_PROPERTY, format("%s must be equal or greater than 0", property));
}
}

private static double validateDoubleRange(Object value, String property, double lowerBoundIncluded, double upperBoundIncluded)
{
double doubleValue = (double) value;
Expand Down Expand Up @@ -1479,4 +1493,9 @@ public static double getAdaptivePartialAggregationUniqueRowsRatioThreshold(Sessi
{
return session.getSystemProperty(ADAPTIVE_PARTIAL_AGGREGATION_UNIQUE_ROWS_RATIO_THRESHOLD, Double.class);
}

public static long getJoinPartitionedBuildMinRowCount(Session session)
{
return session.getSystemProperty(JOIN_PARTITIONED_BUILD_MIN_ROW_COUNT, Long.class);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,11 @@

import com.google.common.base.Stopwatch;
import com.google.common.base.Ticker;
import com.google.common.collect.HashMultimap;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Ordering;
import com.google.common.collect.SetMultimap;
import com.google.common.collect.Streams;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
Expand Down Expand Up @@ -47,7 +48,6 @@
import java.time.Duration;
import java.util.Deque;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -399,8 +399,8 @@ public void release()
node.cancel(true);
if (node.isDone() && !node.isCancelled()) {
deallocateMemory(getFutureValue(node));
wakeupProcessPendingAcquires();
checkState(fulfilledAcquires.remove(this), "node lease %s not found in fulfilledAcquires %s", this, fulfilledAcquires);
wakeupProcessPendingAcquires();
}
}
else {
Expand Down Expand Up @@ -456,16 +456,10 @@ public BinPackingSimulation(
realtimeTasksMemoryPerNode.put(node.getNodeIdentifier(), memoryPoolInfo.getTaskMemoryReservations());
}

Map<String, Set<BinPackingNodeLease>> fulfilledAcquiresByNode = new HashMap<>();
SetMultimap<String, BinPackingNodeLease> fulfilledAcquiresByNode = HashMultimap.create();
for (BinPackingNodeLease fulfilledAcquire : fulfilledAcquires) {
InternalNode node = fulfilledAcquire.getAssignedNode();
fulfilledAcquiresByNode.compute(node.getNodeIdentifier(), (key, set) -> {
if (set == null) {
set = new HashSet<>();
}
set.add(fulfilledAcquire);
return set;
});
fulfilledAcquiresByNode.put(node.getNodeIdentifier(), fulfilledAcquire);
}

nodesRemainingMemory = new HashMap<>();
Expand All @@ -488,7 +482,7 @@ public BinPackingSimulation(
}

Map<String, Long> realtimeNodeMemory = realtimeTasksMemoryPerNode.get(node.getNodeIdentifier());
Set<BinPackingNodeLease> nodeFulfilledAcquires = fulfilledAcquiresByNode.getOrDefault(node.getNodeIdentifier(), ImmutableSet.of());
Set<BinPackingNodeLease> nodeFulfilledAcquires = fulfilledAcquiresByNode.get(node.getNodeIdentifier());

long nodeUsedMemoryRuntimeAdjusted = 0;
for (BinPackingNodeLease lease : nodeFulfilledAcquires) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,27 +20,35 @@
import io.trino.metadata.Signature;
import io.trino.operator.ParametricImplementationsGroup;
import io.trino.operator.annotations.FunctionsParserHelper;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.function.AccumulatorState;
import io.trino.spi.function.AggregationFunction;
import io.trino.spi.function.CombineFunction;
import io.trino.spi.function.FunctionDependency;
import io.trino.spi.function.InputFunction;
import io.trino.spi.function.LiteralParameter;
import io.trino.spi.function.OperatorDependency;
import io.trino.spi.function.OutputFunction;
import io.trino.spi.function.RemoveInputFunction;
import io.trino.spi.function.TypeParameter;
import io.trino.spi.type.TypeSignature;

import java.lang.annotation.Annotation;
import java.lang.reflect.AnnotatedElement;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Optional;
import java.util.stream.IntStream;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Strings.emptyToNull;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.Iterables.getOnlyElement;
import static io.trino.operator.aggregation.AggregationImplementation.Parser.parseImplementation;
import static io.trino.operator.annotations.FunctionsParserHelper.parseDescription;
import static java.util.Collections.nCopies;
import static java.util.Objects.requireNonNull;

public final class AggregationFromAnnotationsParser
Expand Down Expand Up @@ -174,38 +182,70 @@ private static List<String> getAliases(AggregationFunction aggregationAnnotation

private static Optional<Method> getCombineFunction(Class<?> clazz, Class<?> stateClass)
{
// Only include methods that match this state class
List<Method> combineFunctions = FunctionsParserHelper.findPublicStaticMethodsWithAnnotation(clazz, CombineFunction.class).stream()
.filter(method -> method.getParameterTypes()[AggregationImplementation.Parser.findAggregationStateParamId(method, 0)] == stateClass)
.filter(method -> method.getParameterTypes()[AggregationImplementation.Parser.findAggregationStateParamId(method, 1)] == stateClass)
.collect(toImmutableList());

List<Method> combineFunctions = FunctionsParserHelper.findPublicStaticMethodsWithAnnotation(clazz, CombineFunction.class);
for (Method combineFunction : combineFunctions) {
// verify parameter types
List<Class<?>> parameterTypes = getNonDependencyParameterTypes(combineFunction);
List<Class<?>> expectedParameterTypes = nCopies(2, stateClass);
checkArgument(parameterTypes.equals(expectedParameterTypes), "Expected combine function non-dependency parameters to be %s: %s", expectedParameterTypes, combineFunction);
}
checkArgument(combineFunctions.size() <= 1, "There must be only one @CombineFunction in class %s for the @AggregationState %s", clazz.toGenericString(), stateClass.toGenericString());
return combineFunctions.stream().findFirst();
}

private static List<Method> getOutputFunctions(Class<?> clazz, Class<?> stateClass)
{
// Only include methods that match this state class
List<Method> outputFunctions = FunctionsParserHelper.findPublicStaticMethodsWithAnnotation(clazz, OutputFunction.class).stream()
.filter(method -> method.getParameterTypes()[AggregationImplementation.Parser.findAggregationStateParamId(method)] == stateClass)
.collect(toImmutableList());

List<Method> outputFunctions = FunctionsParserHelper.findPublicStaticMethodsWithAnnotation(clazz, OutputFunction.class);
for (Method outputFunction : outputFunctions) {
// verify parameter types
List<Class<?>> parameterTypes = getNonDependencyParameterTypes(outputFunction);
List<Class<?>> expectedParameterTypes = ImmutableList.<Class<?>>builder()
.add(stateClass)
.add(BlockBuilder.class)
.build();
checkArgument(parameterTypes.equals(expectedParameterTypes),
"Expected output function non-dependency parameters to be %s: %s",
expectedParameterTypes.stream().map(Class::getSimpleName).collect(toImmutableList()),
outputFunction);
}
checkArgument(!outputFunctions.isEmpty(), "Aggregation has no output functions");
return outputFunctions;
}

private static List<Method> getInputFunctions(Class<?> clazz, Class<?> stateClass)
{
// Only include methods that match this state class
List<Method> inputFunctions = FunctionsParserHelper.findPublicStaticMethodsWithAnnotation(clazz, InputFunction.class).stream()
.filter(method -> (method.getParameterTypes()[AggregationImplementation.Parser.findAggregationStateParamId(method)] == stateClass))
.collect(toImmutableList());
List<Method> inputFunctions = FunctionsParserHelper.findPublicStaticMethodsWithAnnotation(clazz, InputFunction.class);
for (Method inputFunction : inputFunctions) {
// verify state parameter is first non-dependency parameter
Class<?> actualStateType = getNonDependencyParameterTypes(inputFunction).get(0);
checkArgument(stateClass.equals(actualStateType),
"Expected input function non-dependency parameters to begin with state type %s: %s",
stateClass.getSimpleName(),
inputFunction);
}

checkArgument(!inputFunctions.isEmpty(), "Aggregation has no input functions");
return inputFunctions;
}

private static IntStream getNonDependencyParameters(Method function)
{
Annotation[][] parameterAnnotations = function.getParameterAnnotations();
return IntStream.range(0, function.getParameterCount())
.filter(i -> Arrays.stream(parameterAnnotations[i]).noneMatch(TypeParameter.class::isInstance))
.filter(i -> Arrays.stream(parameterAnnotations[i]).noneMatch(LiteralParameter.class::isInstance))
.filter(i -> Arrays.stream(parameterAnnotations[i]).noneMatch(OperatorDependency.class::isInstance))
.filter(i -> Arrays.stream(parameterAnnotations[i]).noneMatch(FunctionDependency.class::isInstance));
}

private static List<Class<?>> getNonDependencyParameterTypes(Method function)
{
Class<?>[] parameterTypes = function.getParameterTypes();
return getNonDependencyParameters(function)
.mapToObj(index -> parameterTypes[index])
.collect(toImmutableList());
}

private static Optional<Method> getRemoveInputFunction(Class<?> clazz, Method inputFunction)
{
// Only include methods which take the same parameters as the corresponding input function
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ public class OptimizerConfig
private boolean adaptivePartialAggregationEnabled = true;
private long adaptivePartialAggregationMinRows = 100_000;
private double adaptivePartialAggregationUniqueRowsRatioThreshold = 0.8;
private long joinPartitionedBuildMinRowCount = 1_000_000L;

public enum JoinReorderingStrategy
{
Expand Down Expand Up @@ -713,4 +714,18 @@ public OptimizerConfig setAdaptivePartialAggregationUniqueRowsRatioThreshold(dou
this.adaptivePartialAggregationUniqueRowsRatioThreshold = adaptivePartialAggregationUniqueRowsRatioThreshold;
return this;
}

@Min(0)
public long getJoinPartitionedBuildMinRowCount()
{
return joinPartitionedBuildMinRowCount;
}

@Config("optimizer.join-partitioned-build-min-row-count")
@ConfigDescription("Minimum number of join build side rows required to use partitioned join lookup")
public OptimizerConfig setJoinPartitionedBuildMinRowCount(long joinPartitionedBuildMinRowCount)
{
this.joinPartitionedBuildMinRowCount = joinPartitionedBuildMinRowCount;
return this;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,7 @@
import io.trino.sql.planner.iterative.rule.UnwrapRowSubscript;
import io.trino.sql.planner.iterative.rule.UnwrapSingleColumnRowInApply;
import io.trino.sql.planner.iterative.rule.UnwrapTimestampToDateCastInComparison;
import io.trino.sql.planner.iterative.rule.UseNonPartitionedJoinLookupSource;
import io.trino.sql.planner.optimizations.AddExchanges;
import io.trino.sql.planner.optimizations.AddLocalExchanges;
import io.trino.sql.planner.optimizations.BeginTableWrite;
Expand Down Expand Up @@ -919,6 +920,13 @@ public PlanOptimizers(

// Optimizers above this don't understand local exchanges, so be careful moving this.
builder.add(new AddLocalExchanges(plannerContext, typeAnalyzer));
// UseNonPartitionedJoinLookupSource needs to run after AddLocalExchanges since it operates on ExchangeNodes added by this optimizer.
builder.add(new IterativeOptimizer(
plannerContext,
ruleStats,
statsCalculator,
costCalculator,
ImmutableSet.of(new UseNonPartitionedJoinLookupSource())));

// Optimizers above this do not need to care about aggregations with the type other than SINGLE
// This optimizer must be run after all exchange-related optimizers
Expand Down
Loading