diff --git a/core/trino-main/src/main/java/io/trino/metadata/MetadataManager.java b/core/trino-main/src/main/java/io/trino/metadata/MetadataManager.java index 78377e0b9483..b65adbca7665 100644 --- a/core/trino-main/src/main/java/io/trino/metadata/MetadataManager.java +++ b/core/trino-main/src/main/java/io/trino/metadata/MetadataManager.java @@ -15,11 +15,15 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Joiner; +import com.google.common.cache.CacheBuilder; +import com.google.common.cache.CacheLoader; +import com.google.common.cache.LoadingCache; import com.google.common.collect.HashMultimap; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Multimap; +import com.google.common.util.concurrent.UncheckedExecutionException; import io.airlift.slice.Slice; import io.trino.Session; import io.trino.client.NodeVersion; @@ -201,6 +205,9 @@ public final class MetadataManager private final ResolvedFunctionDecoder functionDecoder; + private final LoadingCache operatorCache; + private final LoadingCache coercionCache; + @Inject public MetadataManager( FeaturesConfig featuresConfig, @@ -248,6 +255,23 @@ public MetadataManager( verifyTypes(); functionDecoder = new ResolvedFunctionDecoder(this::getType); + + operatorCache = CacheBuilder.newBuilder() + .maximumSize(1000) + .build(CacheLoader.from(key -> { + String name = mangleOperatorName(key.getOperatorType()); + return resolveFunction(QualifiedName.of(name), fromTypes(key.getArgumentTypes())); + })); + + coercionCache = CacheBuilder.newBuilder() + .maximumSize(1000) + .build(CacheLoader.from(key -> { + String name = mangleOperatorName(key.getOperatorType()); + Type fromType = key.getFromType(); + Type toType = key.getToType(); + Signature signature = new Signature(name, toType.getTypeSignature(), ImmutableList.of(fromType.getTypeSignature())); + return resolve(functionResolver.resolveCoercion(functions.get(QualifiedName.of(name)), signature)); + })); } public static MetadataManager createTestMetadataManager() @@ -1884,11 +1908,15 @@ public ResolvedFunction resolveOperator(OperatorType operatorType, List argumentTypes; + + private OperatorCacheKey(OperatorType operatorType, List argumentTypes) + { + this.operatorType = requireNonNull(operatorType, "operatorType is null"); + this.argumentTypes = ImmutableList.copyOf(requireNonNull(argumentTypes, "argumentTypes is null")); + } + + public OperatorType getOperatorType() + { + return operatorType; + } + + public List getArgumentTypes() + { + return argumentTypes; + } + + @Override + public int hashCode() + { + return Objects.hash(operatorType, argumentTypes); + } + + @Override + public boolean equals(Object obj) + { + if (this == obj) { + return true; + } + if (!(obj instanceof OperatorCacheKey)) { + return false; + } + OperatorCacheKey other = (OperatorCacheKey) obj; + return Objects.equals(this.operatorType, other.operatorType) && + Objects.equals(this.argumentTypes, other.argumentTypes); + } + } + + private static class CoercionCacheKey + { + private final OperatorType operatorType; + private final Type fromType; + private final Type toType; + + private CoercionCacheKey(OperatorType operatorType, Type fromType, Type toType) + { + this.operatorType = requireNonNull(operatorType, "operatorType is null"); + this.fromType = requireNonNull(fromType, "fromType is null"); + this.toType = requireNonNull(toType, "toType is null"); + } + + public OperatorType getOperatorType() + { + return operatorType; + } + + public Type getFromType() + { + return fromType; + } + + public Type getToType() + { + return toType; + } + + @Override + public int hashCode() + { + return Objects.hash(operatorType, fromType, toType); + } + + @Override + public boolean equals(Object obj) + { + if (this == obj) { + return true; + } + if (!(obj instanceof CoercionCacheKey)) { + return false; + } + CoercionCacheKey other = (CoercionCacheKey) obj; + return Objects.equals(this.operatorType, other.operatorType) && + Objects.equals(this.fromType, other.fromType) && + Objects.equals(this.toType, other.toType); + } + } } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/BenchmarkPlanner.java b/core/trino-main/src/test/java/io/trino/sql/planner/BenchmarkPlanner.java index ceb417a4bdbc..9f89ddf58d18 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/BenchmarkPlanner.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/BenchmarkPlanner.java @@ -21,6 +21,7 @@ import io.trino.plugin.tpch.TpchConnectorFactory; import io.trino.testing.LocalQueryRunner; import io.trino.tpch.Customer; +import org.intellij.lang.annotations.Language; import org.openjdk.jmh.annotations.Benchmark; import org.openjdk.jmh.annotations.BenchmarkMode; import org.openjdk.jmh.annotations.Fork; @@ -53,7 +54,9 @@ import static io.trino.testing.TestingSession.testSessionBuilder; import static java.lang.String.format; import static java.util.Locale.ENGLISH; +import static java.util.stream.Collectors.joining; import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertNotNull; @SuppressWarnings("MethodMayBeStatic") @State(Scope.Benchmark) @@ -73,6 +76,8 @@ public static class BenchmarkData private LocalQueryRunner queryRunner; private List queries; + @Language("SQL") + private String largeInQuery; private Session session; @Setup @@ -93,6 +98,11 @@ public void setup() .filter(i -> i != 15) // q15 has two queries in it .map(i -> readResource(format("/io/trino/tpch/queries/q%d.sql", i))) .collect(toImmutableList()); + + largeInQuery = "SELECT * from orders where o_orderkey in " + + IntStream.range(0, 5000) + .mapToObj(Integer::toString) + .collect(joining(", ", "(", ")")); } @TearDown @@ -125,6 +135,16 @@ public List planQueries(BenchmarkData benchmarkData) }); } + @Benchmark + public Plan planLargeInQuery(BenchmarkData benchmarkData) + { + return benchmarkData.queryRunner.inTransaction(transactionSession -> { + LogicalPlanner.Stage stage = LogicalPlanner.Stage.valueOf(benchmarkData.stage.toUpperCase(ENGLISH)); + return benchmarkData.queryRunner.createPlan( + transactionSession, benchmarkData.largeInQuery, stage, false, WarningCollector.NOOP); + }); + } + @Test public void verify() { @@ -132,6 +152,7 @@ public void verify() data.setup(); BenchmarkPlanner benchmark = new BenchmarkPlanner(); assertEquals(benchmark.planQueries(data).size(), 21); + assertNotNull(benchmark.planLargeInQuery(data)); } public static void main(String[] args)