diff --git a/presto-benchmark/src/test/java/com/facebook/presto/benchmark/BenchmarkDecimalAggregation.java b/presto-benchmark/src/test/java/com/facebook/presto/benchmark/BenchmarkDecimalAggregation.java index 082357eb2ebab..973e3640dd990 100644 --- a/presto-benchmark/src/test/java/com/facebook/presto/benchmark/BenchmarkDecimalAggregation.java +++ b/presto-benchmark/src/test/java/com/facebook/presto/benchmark/BenchmarkDecimalAggregation.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.benchmark; +import com.facebook.presto.spi.Page; import org.openjdk.jmh.annotations.Benchmark; import org.openjdk.jmh.annotations.BenchmarkMode; import org.openjdk.jmh.annotations.Fork; @@ -28,6 +29,8 @@ import org.openjdk.jmh.runner.options.OptionsBuilder; import org.openjdk.jmh.runner.options.VerboseMode; +import java.util.List; + import static java.lang.String.format; import static java.util.concurrent.TimeUnit.MILLISECONDS; import static org.openjdk.jmh.annotations.Mode.AverageTime; @@ -55,6 +58,11 @@ public static class AggregationContext private final MemoryLocalQueryRunner queryRunner = new MemoryLocalQueryRunner(); + public final MemoryLocalQueryRunner getQueryRunner() + { + return queryRunner; + } + @Setup public void setUp() { @@ -62,17 +70,13 @@ public void setUp() "CREATE TABLE memory.default.orders AS SELECT orderstatus, cast(totalprice as %s) totalprice FROM tpch.sf1.orders", type)); } - - public void run() - { - queryRunner.execute(format("SELECT %s FROM orders GROUP BY orderstatus", project)); - } } @Benchmark - public void benchmarkBuildHash(AggregationContext context) + public List benchmarkBuildHash(AggregationContext context) { - context.run(); + return context.getQueryRunner() + .execute(String.format("SELECT %s FROM orders GROUP BY orderstatus", context.project)); } public static void main(String[] args) diff --git a/presto-benchmark/src/test/java/com/facebook/presto/benchmark/BenchmarkInequalityJoin.java 
b/presto-benchmark/src/test/java/com/facebook/presto/benchmark/BenchmarkInequalityJoin.java new file mode 100644 index 0000000000000..e6136949498f8 --- /dev/null +++ b/presto-benchmark/src/test/java/com/facebook/presto/benchmark/BenchmarkInequalityJoin.java @@ -0,0 +1,120 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.benchmark; + +import com.facebook.presto.SystemSessionProperties; +import com.facebook.presto.spi.Page; +import com.google.common.collect.ImmutableMap; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.Warmup; +import org.openjdk.jmh.runner.Runner; +import org.openjdk.jmh.runner.RunnerException; +import org.openjdk.jmh.runner.options.Options; +import org.openjdk.jmh.runner.options.OptionsBuilder; +import org.openjdk.jmh.runner.options.VerboseMode; + +import java.util.List; + +import static java.lang.String.format; +import static java.util.concurrent.TimeUnit.MILLISECONDS; +import static org.openjdk.jmh.annotations.Mode.AverageTime; +import static org.openjdk.jmh.annotations.Scope.Thread; + +/** + * This benchmark a case when there is almost like a cross 
join query + * but with very selective inequality join condition. In other words + * for each probe position there are lots of matching build positions + * which are filtered out by filtering function. + */ +@SuppressWarnings("MethodMayBeStatic") +@State(Thread) +@OutputTimeUnit(MILLISECONDS) +@BenchmarkMode(AverageTime) +@Fork(3) +@Warmup(iterations = 10) +@Measurement(iterations = 10) +public class BenchmarkInequalityJoin +{ + @State(Thread) + public static class Context + { + private MemoryLocalQueryRunner queryRunner; + + @Param({"true", "false"}) + private String fastInequalityJoin; + + // number of buckets. The smaller number of buckets, the longer position links chain + @Param({"100", "1000", "10000", "60000"}) + private int buckets; + + // How many positions out of 1000 will be actually joined + // 10 means 1 - 10/1000 = 99/100 positions will be filtered out + @Param({"10"}) + private int filterOutCoefficient; + + public MemoryLocalQueryRunner getQueryRunner() + { + return queryRunner; + } + + @Setup + public void setUp() + { + queryRunner = new MemoryLocalQueryRunner(ImmutableMap.of(SystemSessionProperties.FAST_INEQUALITY_JOIN, fastInequalityJoin)); + + // t1.val1 is in range [0, 1000) + // t1.bucket is in [0, 1000) + queryRunner.execute(format( + "CREATE TABLE memory.default.t1 AS SELECT " + + "orderkey %% %d bucket, " + + "(orderkey * 13) %% 1000 val1 " + + "FROM tpch.tiny.lineitem", + buckets)); + // t2.val2 is in range [0, 10) + // t2.bucket is in [0, 1000) + queryRunner.execute(format( + "CREATE TABLE memory.default.t2 AS SELECT " + + "orderkey %% %d bucket, " + + "(orderkey * 379) %% %d val2 " + + "FROM tpch.tiny.lineitem", + buckets, + filterOutCoefficient)); + } + } + + @Benchmark + public List benchmarkJoin(Context context) + { + return context.getQueryRunner() + .execute("SELECT count(*) FROM t1 JOIN t2 on (t1.bucket = t2.bucket) WHERE t1.val1 < t2.val2"); + } + + public static void main(String[] args) + throws RunnerException + { + Options 
options = new OptionsBuilder() + .verbosity(VerboseMode.NORMAL) + .include(".*" + BenchmarkInequalityJoin.class.getSimpleName() + ".*") + .build(); + + new Runner(options).run(); + } +} diff --git a/presto-benchmark/src/test/java/com/facebook/presto/benchmark/MemoryLocalQueryRunner.java b/presto-benchmark/src/test/java/com/facebook/presto/benchmark/MemoryLocalQueryRunner.java index 60e07ed2e8971..e83fe8357389b 100644 --- a/presto-benchmark/src/test/java/com/facebook/presto/benchmark/MemoryLocalQueryRunner.java +++ b/presto-benchmark/src/test/java/com/facebook/presto/benchmark/MemoryLocalQueryRunner.java @@ -21,45 +21,65 @@ import com.facebook.presto.operator.Driver; import com.facebook.presto.operator.TaskContext; import com.facebook.presto.plugin.memory.MemoryConnectorFactory; +import com.facebook.presto.spi.Page; import com.facebook.presto.spi.QueryId; import com.facebook.presto.spi.memory.MemoryPoolId; import com.facebook.presto.spiller.SpillSpaceTracker; import com.facebook.presto.testing.LocalQueryRunner; -import com.facebook.presto.testing.NullOutputOperator; +import com.facebook.presto.testing.PageConsumerOperator; import com.facebook.presto.tpch.TpchConnectorFactory; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import io.airlift.units.DataSize; import org.intellij.lang.annotations.Language; import java.util.List; +import java.util.Map; import java.util.concurrent.ExecutorService; import static com.facebook.presto.testing.TestingSession.testSessionBuilder; import static io.airlift.units.DataSize.Unit.GIGABYTE; -import static io.airlift.units.DataSize.Unit.MEGABYTE; public class MemoryLocalQueryRunner { - protected LocalQueryRunner localQueryRunner = createMemoryLocalQueryRunner(); + protected final LocalQueryRunner localQueryRunner; + protected final Session session; - public void execute(@Language("SQL") String query) + public MemoryLocalQueryRunner() + { + this(ImmutableMap.of()); + } + + public 
MemoryLocalQueryRunner(Map properties) + { + Session.SessionBuilder sessionBuilder = testSessionBuilder() + .setCatalog("memory") + .setSchema("default"); + properties.forEach(sessionBuilder::setSystemProperty); + + session = sessionBuilder.build(); + localQueryRunner = createMemoryLocalQueryRunner(session); + } + + public List execute(@Language("SQL") String query) { - Session session = testSessionBuilder() - .setSystemProperty("optimizer.optimize-hash-generation", "true") - .build(); ExecutorService executor = localQueryRunner.getExecutor(); - MemoryPool memoryPool = new MemoryPool(new MemoryPoolId("test"), new DataSize(1, GIGABYTE)); - MemoryPool systemMemoryPool = new MemoryPool(new MemoryPoolId("testSystem"), new DataSize(1, GIGABYTE)); - SpillSpaceTracker spillSpaceTracker = new SpillSpaceTracker(new DataSize(1, GIGABYTE)); + MemoryPool memoryPool = new MemoryPool(new MemoryPoolId("test"), new DataSize(2, GIGABYTE)); + MemoryPool systemMemoryPool = new MemoryPool(new MemoryPoolId("testSystem"), new DataSize(2, GIGABYTE)); - TaskContext taskContext = new QueryContext(new QueryId("test"), new DataSize(256, MEGABYTE), memoryPool, systemMemoryPool, executor, new DataSize(1, GIGABYTE), spillSpaceTracker) + SpillSpaceTracker spillSpaceTracker = new SpillSpaceTracker(new DataSize(1, GIGABYTE)); + TaskContext taskContext = new QueryContext(new QueryId("test"), new DataSize(1, GIGABYTE), memoryPool, systemMemoryPool, executor, new DataSize(4, GIGABYTE), spillSpaceTracker) .addTaskContext(new TaskStateMachine(new TaskId("query", 0, 0), executor), session, false, false); // Use NullOutputFactory to avoid coping out results to avoid affecting benchmark results - List drivers = localQueryRunner.createDrivers(query, new NullOutputOperator.NullOutputFactory(), taskContext); + ImmutableList.Builder output = ImmutableList.builder(); + List drivers = localQueryRunner.createDrivers( + query, + new PageConsumerOperator.PageConsumerOutputFactory(types -> output::add), + 
taskContext); boolean done = false; while (!done) { @@ -72,20 +92,20 @@ public void execute(@Language("SQL") String query) } done = !processed; } + + return output.build(); } - private static LocalQueryRunner createMemoryLocalQueryRunner() + private static LocalQueryRunner createMemoryLocalQueryRunner(Session session) { - Session.SessionBuilder sessionBuilder = testSessionBuilder() - .setCatalog("memory") - .setSchema("default"); - - Session session = sessionBuilder.build(); LocalQueryRunner localQueryRunner = LocalQueryRunner.queryRunnerWithInitialTransaction(session); // add tpch - localQueryRunner.createCatalog("tpch", new TpchConnectorFactory(1), ImmutableMap.of()); - localQueryRunner.createCatalog("memory", new MemoryConnectorFactory(), ImmutableMap.of()); + localQueryRunner.createCatalog("tpch", new TpchConnectorFactory(1), ImmutableMap.of()); + localQueryRunner.createCatalog( + "memory", + new MemoryConnectorFactory(), + ImmutableMap.of("memory.max-data-per-node", "4GB")); return localQueryRunner; } diff --git a/presto-cassandra/src/test/java/com/facebook/presto/cassandra/TestCassandraDistributed.java b/presto-cassandra/src/test/java/com/facebook/presto/cassandra/TestCassandraDistributed.java index 410217adb27f5..572a2018df1fe 100644 --- a/presto-cassandra/src/test/java/com/facebook/presto/cassandra/TestCassandraDistributed.java +++ b/presto-cassandra/src/test/java/com/facebook/presto/cassandra/TestCassandraDistributed.java @@ -39,6 +39,12 @@ protected boolean supportsViews() return false; } + @Override + public void testJoinWithLessThanOnDatesInJoinClause() + { + // Cassandra does not support DATE + } + @Override public void testGroupingSetMixedExpressionAndColumn() { diff --git a/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java b/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java index 45f134c37375b..b4c896c67dd38 100644 --- a/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java +++ 
b/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java @@ -62,6 +62,7 @@ public final class SystemSessionProperties public static final String INITIAL_SPLITS_PER_NODE = "initial_splits_per_node"; public static final String SPLIT_CONCURRENCY_ADJUSTMENT_INTERVAL = "split_concurrency_adjustment_interval"; public static final String OPTIMIZE_METADATA_QUERIES = "optimize_metadata_queries"; + public static final String FAST_INEQUALITY_JOIN = "fast_inequality_join"; public static final String QUERY_PRIORITY = "query_priority"; public static final String SPILL_ENABLED = "spill_enabled"; public static final String OPERATOR_MEMORY_LIMIT_BEFORE_SPILL = "operator_memory_limit_before_spill"; @@ -241,6 +242,11 @@ public SystemSessionProperties( "Experimental: Reorder joins to optimize plan", featuresConfig.isJoinReorderingEnabled(), false), + booleanSessionProperty( + FAST_INEQUALITY_JOIN, + "Experimental: Use faster handling of inequality join if it is possible", + featuresConfig.isFastInequalityJoins(), + false), booleanSessionProperty( COLOCATED_JOIN, "Experimental: Use a colocated join when possible", @@ -398,6 +404,11 @@ public static boolean planWithTableNodePartitioning(Session session) return session.getSystemProperty(PLAN_WITH_TABLE_NODE_PARTITIONING, Boolean.class); } + public static boolean isFastInequalityJoin(Session session) + { + return session.getSystemProperty(FAST_INEQUALITY_JOIN, Boolean.class); + } + public static boolean isJoinReorderingEnabled(Session session) { return session.getSystemProperty(REORDER_JOINS, Boolean.class); diff --git a/presto-main/src/main/java/com/facebook/presto/operator/ArrayPositionLinks.java b/presto-main/src/main/java/com/facebook/presto/operator/ArrayPositionLinks.java new file mode 100644 index 0000000000000..760300e38a9d9 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/operator/ArrayPositionLinks.java @@ -0,0 +1,81 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + 
* you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator; + +import com.facebook.presto.spi.Page; + +import java.util.Arrays; +import java.util.Optional; +import java.util.function.Function; + +import static io.airlift.slice.SizeOf.sizeOf; +import static java.util.Objects.requireNonNull; + +public final class ArrayPositionLinks + implements PositionLinks +{ + public static class Builder implements PositionLinks.Builder + { + private final int[] positionLinks; + + private Builder(int size) + { + positionLinks = new int[size]; + Arrays.fill(positionLinks, -1); + } + + @Override + public int link(int left, int right) + { + positionLinks[left] = right; + return left; + } + + @Override + public Function, PositionLinks> build() + { + return filterFunction -> new ArrayPositionLinks(positionLinks); + } + } + + private final int[] positionLinks; + + private ArrayPositionLinks(int[] positionLinks) + { + this.positionLinks = requireNonNull(positionLinks, "positionLinks is null"); + } + + public static Builder builder(int size) + { + return new Builder(size); + } + + @Override + public int start(int position, int probePosition, Page allProbeChannelsPage) + { + return position; + } + + @Override + public int next(int position, int probePosition, Page allProbeChannelsPage) + { + return positionLinks[position]; + } + + @Override + public long getSizeInBytes() + { + return sizeOf(positionLinks); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/operator/JoinFilterFunction.java 
b/presto-main/src/main/java/com/facebook/presto/operator/JoinFilterFunction.java index ad71a164f4f7b..bfa2f0230b736 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/JoinFilterFunction.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/JoinFilterFunction.java @@ -15,7 +15,14 @@ import com.facebook.presto.spi.Page; +import javax.annotation.concurrent.NotThreadSafe; + +import java.util.Optional; + +@NotThreadSafe public interface JoinFilterFunction { boolean filter(int leftAddress, int rightPosition, Page rightPage); + + Optional getSortChannel(); } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/JoinHash.java b/presto-main/src/main/java/com/facebook/presto/operator/JoinHash.java index a5c81595064e1..0537bfb6beefd 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/JoinHash.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/JoinHash.java @@ -34,10 +34,13 @@ public final class JoinHash @Nullable private final JoinFilterFunction filterFunction; - public JoinHash(PagesHash pagesHash, Optional filterFunction) + private final PositionLinks positionLinks; + + public JoinHash(PagesHash pagesHash, Optional filterFunction, PositionLinks positionLinks) { this.pagesHash = requireNonNull(pagesHash, "pagesHash is null"); this.filterFunction = requireNonNull(filterFunction, "filterFunction can not be null").orElse(null); + this.positionLinks = requireNonNull(positionLinks, "positionLinks is null"); } @Override @@ -55,25 +58,35 @@ public int getJoinPositionCount() @Override public long getInMemorySizeInBytes() { - return pagesHash.getInMemorySizeInBytes(); + return pagesHash.getInMemorySizeInBytes() + positionLinks.getSizeInBytes(); } @Override public long getJoinPosition(int position, Page hashChannelsPage, Page allChannelsPage) { - return pagesHash.getAddressIndex(position, hashChannelsPage, allChannelsPage); + int addressIndex = pagesHash.getAddressIndex(position, hashChannelsPage, 
allChannelsPage); + return startJoinPosition(addressIndex, position, allChannelsPage); } @Override public long getJoinPosition(int position, Page hashChannelsPage, Page allChannelsPage, long rawHash) { - return pagesHash.getAddressIndex(position, hashChannelsPage, allChannelsPage, rawHash); + int addressIndex = pagesHash.getAddressIndex(position, hashChannelsPage, allChannelsPage, rawHash); + return startJoinPosition(addressIndex, position, allChannelsPage); + } + + private long startJoinPosition(int currentJoinPosition, int probePosition, Page allProbeChannelsPage) + { + if (currentJoinPosition == -1) { + return -1; + } + return positionLinks.start(currentJoinPosition, probePosition, allProbeChannelsPage); } @Override public final long getNextJoinPosition(long currentJoinPosition, int probePosition, Page allProbeChannelsPage) { - return pagesHash.getNextAddressIndex(toIntExact(currentJoinPosition)); + return positionLinks.next(toIntExact(currentJoinPosition), probePosition, allProbeChannelsPage); } @Override diff --git a/presto-main/src/main/java/com/facebook/presto/operator/JoinHashSupplier.java b/presto-main/src/main/java/com/facebook/presto/operator/JoinHashSupplier.java index 575bee6fbd2c0..c8b322a610797 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/JoinHashSupplier.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/JoinHashSupplier.java @@ -13,43 +13,58 @@ */ package com.facebook.presto.operator; -import com.facebook.presto.spi.ConnectorSession; +import com.facebook.presto.Session; import com.facebook.presto.spi.block.Block; import com.facebook.presto.sql.gen.JoinFilterFunctionCompiler.JoinFilterFunctionFactory; +import it.unimi.dsi.fastutil.ints.IntComparator; import it.unimi.dsi.fastutil.longs.LongArrayList; import java.util.List; import java.util.Optional; +import java.util.function.Function; +import static com.facebook.presto.SystemSessionProperties.isFastInequalityJoin; +import static 
com.facebook.presto.operator.SyntheticAddress.decodePosition; +import static com.facebook.presto.operator.SyntheticAddress.decodeSliceIndex; import static java.util.Objects.requireNonNull; public class JoinHashSupplier implements LookupSourceSupplier { - private final ConnectorSession session; + private final Session session; private final PagesHash pagesHash; private final LongArrayList addresses; private final List> channels; + private final Function, PositionLinks> positionLinks; private final Optional filterFunctionFactory; public JoinHashSupplier( - ConnectorSession session, + Session session, PagesHashStrategy pagesHashStrategy, LongArrayList addresses, List> channels, Optional filterFunctionFactory) { - requireNonNull(session, "session is null"); + this.session = requireNonNull(session, "session is null"); + this.addresses = requireNonNull(addresses, "addresses is null"); + this.channels = requireNonNull(channels, "channels is null"); + this.filterFunctionFactory = requireNonNull(filterFunctionFactory, "filterFunctionFactory is null"); requireNonNull(pagesHashStrategy, "pagesHashStrategy is null"); - requireNonNull(addresses, "addresses is null"); - requireNonNull(channels, "channels is null"); - requireNonNull(filterFunctionFactory, "filterFunctionFactory is null"); - - this.session = session; - this.pagesHash = new PagesHash(addresses, pagesHashStrategy); - this.addresses = addresses; - this.channels = channels; - this.filterFunctionFactory = filterFunctionFactory; + + PositionLinks.Builder positionLinksBuilder; + if (filterFunctionFactory.isPresent() && + filterFunctionFactory.get().getSortChannel().isPresent() && + isFastInequalityJoin(session)) { + positionLinksBuilder = SortedPositionLinks.builder( + addresses.size(), + new PositionComparator(pagesHashStrategy, addresses)); + } + else { + positionLinksBuilder = ArrayPositionLinks.builder(addresses.size()); + } + + this.pagesHash = new PagesHash(addresses, pagesHashStrategy, positionLinksBuilder); + 
this.positionLinks = positionLinksBuilder.build(); } @Override @@ -67,7 +82,46 @@ public double getExpectedHashCollisions() @Override public JoinHash get() { - Optional filterFunction = filterFunctionFactory.map(factory -> factory.create(session, addresses, channels)); - return new JoinHash(pagesHash, filterFunction); + // We need to create new JoinFilterFunction per each thread using it, since those functions + // are not thread safe... + Optional filterFunction = + filterFunctionFactory.map(factory -> factory.create(session.toConnectorSession(), addresses, channels)); + return new JoinHash( + pagesHash, + filterFunction, + positionLinks.apply(filterFunction)); + } + + public static class PositionComparator + implements IntComparator + { + private final PagesHashStrategy pagesHashStrategy; + private final LongArrayList addresses; + + public PositionComparator(PagesHashStrategy pagesHashStrategy, LongArrayList addresses) + { + this.pagesHashStrategy = pagesHashStrategy; + this.addresses = addresses; + } + + @Override + public int compare(int leftPosition, int rightPosition) + { + long leftPageAddress = addresses.getLong(leftPosition); + int leftBlockIndex = decodeSliceIndex(leftPageAddress); + int leftBlockPosition = decodePosition(leftPageAddress); + + long rightPageAddress = addresses.getLong(rightPosition); + int rightBlockIndex = decodeSliceIndex(rightPageAddress); + int rightBlockPosition = decodePosition(rightPageAddress); + + return pagesHashStrategy.compare(leftBlockIndex, leftBlockPosition, rightBlockIndex, rightBlockPosition); + } + + @Override + public int compare(Integer leftPosition, Integer rightPosition) + { + return compare(leftPosition.intValue(), rightPosition.intValue()); + } } } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/PagesHash.java b/presto-main/src/main/java/com/facebook/presto/operator/PagesHash.java index 42206f5256b59..3eb9afe20fff0 100644 --- 
a/presto-main/src/main/java/com/facebook/presto/operator/PagesHash.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/PagesHash.java @@ -26,6 +26,7 @@ import static com.facebook.presto.util.HashCollisionsEstimator.estimateNumberOfHashCollisions; import static io.airlift.slice.SizeOf.sizeOf; import static io.airlift.units.DataSize.Unit.KILOBYTE; +import static java.lang.Math.toIntExact; import static java.util.Objects.requireNonNull; // This implementation assumes arrays used in the hash are always a power of 2 @@ -38,7 +39,6 @@ public final class PagesHash private final int channelCount; private final int mask; private final int[] key; - private final int[] positionLinks; private final long size; // Native array of hashes for faster collisions resolution compared @@ -48,7 +48,10 @@ public final class PagesHash private final long hashCollisions; private final double expectedHashCollisions; - public PagesHash(LongArrayList addresses, PagesHashStrategy pagesHashStrategy) + public PagesHash( + LongArrayList addresses, + PagesHashStrategy pagesHashStrategy, + PositionLinks.Builder positionLinks) { this.addresses = requireNonNull(addresses, "addresses is null"); this.pagesHashStrategy = requireNonNull(pagesHashStrategy, "pagesHashStrategy is null"); @@ -61,9 +64,6 @@ public PagesHash(LongArrayList addresses, PagesHashStrategy pagesHashStrategy) key = new int[hashSize]; Arrays.fill(key, -1); - this.positionLinks = new int[addresses.size()]; - Arrays.fill(positionLinks, -1); - positionToHashes = new byte[addresses.size()]; // We will process addresses in batches, to save memory on array of hashes. 
@@ -102,7 +102,7 @@ public PagesHash(LongArrayList addresses, PagesHashStrategy pagesHashStrategy) if (((byte) hash) == positionToHashes[currentKey] && positionEqualsPositionIgnoreNulls(currentKey, realPosition)) { // found a slot for this key // link the new key position to the current key position - positionLinks[realPosition] = currentKey; + realPosition = positionLinks.link(realPosition, currentKey); // key[pos] updated outside of this loop break; @@ -117,7 +117,7 @@ public PagesHash(LongArrayList addresses, PagesHashStrategy pagesHashStrategy) } size = sizeOf(addresses.elements()) + pagesHashStrategy.getSizeInBytes() + - sizeOf(key) + sizeOf(positionLinks) + sizeOf(positionToHashes); + sizeOf(key) + sizeOf(positionToHashes); hashCollisions = hashCollisionsLocal; expectedHashCollisions = estimateNumberOfHashCollisions(addresses.size(), hashSize); } @@ -129,7 +129,7 @@ public final int getChannelCount() public int getPositionCount() { - return positionLinks.length; + return addresses.size(); } public long getInMemorySizeInBytes() @@ -166,14 +166,9 @@ public int getAddressIndex(int rightPosition, Page hashChannelsPage, Page allCha return -1; } - public int getNextAddressIndex(int currentAddressIndex) + public void appendTo(long position, PageBuilder pageBuilder, int outputChannelOffset) { - return positionLinks[currentAddressIndex]; - } - - public void appendTo(int position, PageBuilder pageBuilder, int outputChannelOffset) - { - long pageAddress = addresses.getLong(position); + long pageAddress = addresses.getLong(toIntExact(position)); int blockIndex = decodeSliceIndex(pageAddress); int blockPosition = decodePosition(pageAddress); diff --git a/presto-main/src/main/java/com/facebook/presto/operator/PagesHashStrategy.java b/presto-main/src/main/java/com/facebook/presto/operator/PagesHashStrategy.java index 15eb3afe5520d..714db377b7105 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/PagesHashStrategy.java +++ 
b/presto-main/src/main/java/com/facebook/presto/operator/PagesHashStrategy.java @@ -92,4 +92,6 @@ public interface PagesHashStrategy * Checks if any of the hashed columns is null */ boolean isPositionNull(int blockIndex, int blockPosition); + + int compare(int leftBlockIndex, int leftBlockPosition, int rightBlockIndex, int rightBlockPosition); } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/PagesIndex.java b/presto-main/src/main/java/com/facebook/presto/operator/PagesIndex.java index f2ca9a464986e..89a0e0139dbab 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/PagesIndex.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/PagesIndex.java @@ -24,6 +24,7 @@ import com.facebook.presto.sql.gen.JoinCompiler.LookupSourceSupplierFactory; import com.facebook.presto.sql.gen.JoinFilterFunctionCompiler.JoinFilterFunctionFactory; import com.facebook.presto.sql.gen.OrderingCompiler; +import com.facebook.presto.sql.planner.SortExpressionExtractor.SortExpression; import com.google.common.collect.ImmutableList; import io.airlift.log.Logger; import io.airlift.slice.Slice; @@ -384,7 +385,13 @@ public PagesHashStrategy createPagesHashStrategy(List joinChannels, Opt } // if compilation fails, use interpreter - return new SimplePagesHashStrategy(types, outputChannels.orElse(rangeList(types.size())), ImmutableList.copyOf(channels), joinChannels, hashChannel); + return new SimplePagesHashStrategy( + types, + outputChannels.orElse(rangeList(types.size())), + ImmutableList.copyOf(channels), + joinChannels, + hashChannel, + Optional.empty()); } public LookupSourceSupplier createLookupSourceSupplier( @@ -410,9 +417,13 @@ public LookupSourceSupplier createLookupSourceSupplier( // OUTER joins into NestedLoopsJoin and remove "type == INNER" condition in LocalExecutionPlanner.visitJoin() try { - LookupSourceSupplierFactory lookupSourceFactory = joinCompiler.compileLookupSourceFactory(types, joinChannels, outputChannels); + Optional 
sortChannel = Optional.empty(); + if (filterFunctionFactory.isPresent()) { + sortChannel = filterFunctionFactory.get().getSortChannel(); + } + LookupSourceSupplierFactory lookupSourceFactory = joinCompiler.compileLookupSourceFactory(types, joinChannels, sortChannel, outputChannels); return lookupSourceFactory.createLookupSourceSupplier( - session.toConnectorSession(), + session, valueAddresses, channels, hashChannel, @@ -429,10 +440,11 @@ public LookupSourceSupplier createLookupSourceSupplier( outputChannels.orElse(rangeList(types.size())), channels, joinChannels, - hashChannel); + hashChannel, + filterFunctionFactory.map(JoinFilterFunctionFactory::getSortChannel).orElse(Optional.empty())); return new JoinHashSupplier( - session.toConnectorSession(), + session, hashStrategy, valueAddresses, channels, diff --git a/presto-main/src/main/java/com/facebook/presto/operator/PositionLinks.java b/presto-main/src/main/java/com/facebook/presto/operator/PositionLinks.java new file mode 100644 index 0000000000000..3baba275ab907 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/operator/PositionLinks.java @@ -0,0 +1,56 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.operator; + +import com.facebook.presto.spi.Page; + +import java.util.Optional; +import java.util.function.Function; + +/** + * This class is responsible for iterating over build rows, which have + * same values in hash columns as given probe row (but according to + * filterFunction can have non matching values on some other column). + */ +public interface PositionLinks +{ + long getSizeInBytes(); + + /** + * Initialize iteration over position links. Returns first potentially eligible + * join position starting from (and including) position argument. + * + * When there are no more position -1 is returned + */ + int start(int position, int probePosition, Page allProbeChannelsPage); + + /** + * Iterate over position links. When there are no more position -1 is returned. + */ + int next(int position, int probePosition, Page allProbeChannelsPage); + + interface Builder + { + /** + * @return value that should be used in future references to created position links + */ + int link(int left, int right); + + /** + * JoinFilterFunction has to be created and supplied for each thread using PositionLinks + * since JoinFilterFunction is not thread safe... 
+ */ + Function, PositionLinks> build(); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/operator/SimplePagesHashStrategy.java b/presto-main/src/main/java/com/facebook/presto/operator/SimplePagesHashStrategy.java index 1e03bfc1c79d8..6f315aa839830 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/SimplePagesHashStrategy.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/SimplePagesHashStrategy.java @@ -17,6 +17,7 @@ import com.facebook.presto.spi.PageBuilder; import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.type.Type; +import com.facebook.presto.sql.planner.SortExpressionExtractor.SortExpression; import com.facebook.presto.type.TypeUtils; import com.google.common.collect.ImmutableList; @@ -35,8 +36,15 @@ public class SimplePagesHashStrategy private final List> channels; private final List hashChannels; private final List precomputedHashChannel; - - public SimplePagesHashStrategy(List types, List outputChannels, List> channels, List hashChannels, Optional precomputedHashChannel) + private final Optional sortChannel; + + public SimplePagesHashStrategy( + List types, + List outputChannels, + List> channels, + List hashChannels, + Optional precomputedHashChannel, + Optional sortChannel) { this.types = ImmutableList.copyOf(requireNonNull(types, "types is null")); this.outputChannels = ImmutableList.copyOf(requireNonNull(outputChannels, "outputChannels is null")); @@ -50,6 +58,7 @@ public SimplePagesHashStrategy(List types, List outputChannels, L else { this.precomputedHashChannel = null; } + this.sortChannel = requireNonNull(sortChannel, "sortChannel is null"); } @Override @@ -209,4 +218,18 @@ public boolean isPositionNull(int blockIndex, int blockPosition) } return false; } + + @Override + public int compare(int leftBlockIndex, int leftBlockPosition, int rightBlockIndex, int rightBlockPosition) + { + if (!sortChannel.isPresent()) { + throw new UnsupportedOperationException(); + } + int channel 
= sortChannel.get().getChannel(); + + Block leftBlock = channels.get(channel).get(leftBlockIndex); + Block rightBlock = channels.get(channel).get(rightBlockIndex); + + return types.get(channel).compareTo(leftBlock, leftBlockPosition, rightBlock, rightBlockPosition); + } } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/SortedPositionLinks.java b/presto-main/src/main/java/com/facebook/presto/operator/SortedPositionLinks.java new file mode 100644 index 0000000000000..ca4447bd03c2a --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/operator/SortedPositionLinks.java @@ -0,0 +1,224 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator; + +import com.facebook.presto.spi.Page; +import it.unimi.dsi.fastutil.ints.Int2ObjectMap; +import it.unimi.dsi.fastutil.ints.Int2ObjectOpenHashMap; +import it.unimi.dsi.fastutil.ints.IntArrayList; +import it.unimi.dsi.fastutil.ints.IntComparator; + +import java.util.List; +import java.util.Optional; +import java.util.function.Function; + +import static com.google.common.base.Preconditions.checkState; +import static io.airlift.slice.SizeOf.sizeOf; +import static java.util.Objects.requireNonNull; + +/** + * This class assumes that lessThanFunction is a superset of the whole filtering + * condition used in a join. In other words, we can use SortedPositionLinks + * with following join condition: + * + * filterFunction_1(...) AND filterFunction_2(....) AND ... 
AND filterFunction_n(...) + * + * by passing any of the filterFunction_i to the SortedPositionLinks. We could not + * do that for a join condition like: + * + * filterFunction_1(...) OR filterFunction_2(....) OR ... OR filterFunction_n(...) + * + * To use lessThanFunction in this class, it must be an expression in the form of: + * + * f(probeColumn1, probeColumn2, ..., probeColumnN) COMPARE g(buildColumn1, ..., buildColumnN) + * + * where COMPARE is one of: < <= > >= + * + * That allows us to define an order of the elements in positionLinks (thus defining which + * element is smaller) using the g(...) function and to perform a binary search using + * the f(probePosition) value. + */ +public final class SortedPositionLinks + implements PositionLinks +{ + public static class Builder implements PositionLinks.Builder + { + private final Int2ObjectMap positionLinks; + private final int size; + private final IntComparator comparator; + + public Builder(int size, IntComparator comparator) + { + this.comparator = comparator; + this.size = size; + positionLinks = new Int2ObjectOpenHashMap<>(); + } + + @Override + public int link(int from, int to) + { + // make sure that the _from_ value is the smaller one + if (comparator.compare(from, to) > 0) { + // _from_ is larger, so just add it to the current chain of _to_ + List links = positionLinks.computeIfAbsent(to, key -> new IntArrayList()); + links.add(from); + return to; + } + else { + // _to_ is larger, so move the chain to _from_ + IntArrayList links = positionLinks.remove(to); + if (links == null) { + links = new IntArrayList(); + } + links.add(to); + checkState(positionLinks.putIfAbsent(from, links) == null, "sorted links is corrupted"); + return from; + } + } + + @Override + public Function, PositionLinks> build() + { + ArrayPositionLinks.Builder builder = ArrayPositionLinks.builder(size); + int[][] sortedPositionLinks = new int[size][]; + + for (Int2ObjectMap.Entry entry : positionLinks.int2ObjectEntrySet()) { + int key = entry.getIntKey(); +
IntArrayList positions = entry.getValue(); + positions.sort(comparator); + + sortedPositionLinks[key] = new int[positions.size()]; + for (int i = 0; i < positions.size(); i++) { + sortedPositionLinks[key][i] = positions.get(i); + } + + // ArrayPositionLinks.Builder::link builds position links from + // tail to head, so we must add them in descending order to have + // the smallest element as the head + for (int i = positions.size() - 2; i >= 0; i--) { + builder.link(positions.get(i), positions.get(i + 1)); + } + + // add link from starting position to position links chain + if (!positions.isEmpty()) { + builder.link(key, positions.get(0)); + } + } + + return lessThanFunction -> { + checkState(lessThanFunction.isPresent(), "Using SortedPositionLinks without lessThanFunction"); + return new SortedPositionLinks( + builder.build().apply(lessThanFunction), + sortedPositionLinks, + lessThanFunction.get()); + }; + } + } + + private final PositionLinks positionLinks; + private final int[][] sortedPositionLinks; + private final JoinFilterFunction lessThanFunction; + private final long sizeInBytes; + + private SortedPositionLinks(PositionLinks positionLinks, int[][] sortedPositionLinks, JoinFilterFunction lessThanFunction) + { + this.positionLinks = requireNonNull(positionLinks, "positionLinks is null"); + this.sortedPositionLinks = requireNonNull(sortedPositionLinks, "sortedPositionLinks is null"); + this.lessThanFunction = requireNonNull(lessThanFunction, "lessThanFunction is null"); + this.sizeInBytes = positionLinks.getSizeInBytes() + sizeOf(sortedPositionLinks); + } + + public static Builder builder(int size, IntComparator comparator) + { + return new Builder(size, comparator); + } + + @Override + public int next(int position, int probePosition, Page allProbeChannelsPage) + { + int nextPosition = positionLinks.next(position, probePosition, allProbeChannelsPage); + if (nextPosition < 0) { + return -1; + } + // break a position links chain if next position should be filtered
out + if (applyLessThanFunction(nextPosition, probePosition, allProbeChannelsPage)) { + return nextPosition; + } + return -1; + } + + @Override + public int start(int startingPosition, int probePosition, Page allProbeChannelsPage) + { + // check if the filtering function applies to startingPosition + if (applyLessThanFunction(startingPosition, probePosition, allProbeChannelsPage)) { + return startingPosition; + } + + if (sortedPositionLinks[startingPosition] == null) { + return -1; + } + + int left = 0; + int right = sortedPositionLinks[startingPosition].length - 1; + + // do a binary search for the first position for which the filter function applies + int offset = lowerBound(startingPosition, left, right, probePosition, allProbeChannelsPage); + if (offset < 0) { + return -1; + } + if (!applyLessThanFunction(startingPosition, offset, probePosition, allProbeChannelsPage)) { + return -1; + } + return sortedPositionLinks[startingPosition][offset]; + } + + /** + * Find the first offset in [first, last) of sortedPositionLinks[startingPosition] for which lessThanFunction applies; returns last if there is none + */ + private int lowerBound(int startingPosition, int first, int last, int probePosition, Page allProbeChannelsPage) + { + int middle; + int step; + int count = last - first; + while (count > 0) { + step = count / 2; + middle = first + step; + if (!applyLessThanFunction(startingPosition, middle, probePosition, allProbeChannelsPage)) { + first = ++middle; + count -= step + 1; + } + else { + count = step; + } + } + return first; + } + + @Override + public long getSizeInBytes() + { + return sizeInBytes; + } + + private boolean applyLessThanFunction(int leftPosition, int leftOffset, int rightPosition, Page rightPage) + { + return applyLessThanFunction(sortedPositionLinks[leftPosition][leftOffset], rightPosition, rightPage); + } + + private boolean applyLessThanFunction(long leftPosition, int rightPosition, Page rightPage) + { + return lessThanFunction.filter((int) leftPosition, rightPosition, rightPage); + } +} diff --git
a/presto-main/src/main/java/com/facebook/presto/operator/StandardJoinFilterFunction.java b/presto-main/src/main/java/com/facebook/presto/operator/StandardJoinFilterFunction.java index 6fddc032671e8..f4cbf407c4c3f 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/StandardJoinFilterFunction.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/StandardJoinFilterFunction.java @@ -19,6 +19,7 @@ import it.unimi.dsi.fastutil.longs.LongArrayList; import java.util.List; +import java.util.Optional; import static com.facebook.presto.operator.SyntheticAddress.decodePosition; import static com.facebook.presto.operator.SyntheticAddress.decodeSliceIndex; @@ -32,11 +33,13 @@ public class StandardJoinFilterFunction private final InternalJoinFilterFunction filterFunction; private final LongArrayList addresses; private final List pages; + private final Optional sortChannel; - public StandardJoinFilterFunction(InternalJoinFilterFunction filterFunction, LongArrayList addresses, List> channels) + public StandardJoinFilterFunction(InternalJoinFilterFunction filterFunction, LongArrayList addresses, List> channels, Optional sortChannel) { this.filterFunction = requireNonNull(filterFunction, "filterFunction can not be null"); this.addresses = requireNonNull(addresses, "addresses is null"); + this.sortChannel = requireNonNull(sortChannel, "sortChannel is null"); requireNonNull(channels, "channels can not be null"); ImmutableList.Builder pagesBuilder = ImmutableList.builder(); @@ -63,6 +66,12 @@ public boolean filter(int leftAddress, int rightPosition, Page rightPage) return filterFunction.filter(blockPosition, getLeftBlocks(blockIndex), rightPosition, rightPage.getBlocks()); } + @Override + public Optional getSortChannel() + { + return sortChannel; + } + private Block[] getLeftBlocks(int leftBlockIndex) { if (pages.isEmpty()) { diff --git a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java 
b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java index 216c70d8cc821..a304a2a175a94 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java @@ -50,6 +50,7 @@ public static class ProcessingOptimization private boolean distributedIndexJoinsEnabled; private boolean distributedJoinsEnabled = true; private boolean colocatedJoinsEnabled; + private boolean fastInequalityJoins = true; private boolean reorderJoins; private boolean redistributeWrites = true; private boolean optimizeMetadataQueries; @@ -164,6 +165,19 @@ public FeaturesConfig setColocatedJoinsEnabled(boolean colocatedJoinsEnabled) return this; } + @Config("fast-inequality-joins") + @ConfigDescription("Experimental: Use faster handling of inequality joins if it is possible") + public FeaturesConfig setFastInequalityJoins(boolean fastInequalityJoins) + { + this.fastInequalityJoins = fastInequalityJoins; + return this; + } + + public boolean isFastInequalityJoins() + { + return fastInequalityJoins; + } + public boolean isJoinReorderingEnabled() { return reorderJoins; diff --git a/presto-main/src/main/java/com/facebook/presto/sql/gen/JoinCompiler.java b/presto-main/src/main/java/com/facebook/presto/sql/gen/JoinCompiler.java index 30ce3181c6f07..c013c0812bf7b 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/gen/JoinCompiler.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/gen/JoinCompiler.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.sql.gen; +import com.facebook.presto.Session; import com.facebook.presto.bytecode.BytecodeBlock; import com.facebook.presto.bytecode.BytecodeNode; import com.facebook.presto.bytecode.ClassDefinition; @@ -31,7 +32,6 @@ import com.facebook.presto.operator.LookupSourceSupplier; import com.facebook.presto.operator.PagesHash; import com.facebook.presto.operator.PagesHashStrategy; -import 
com.facebook.presto.spi.ConnectorSession; import com.facebook.presto.spi.Page; import com.facebook.presto.spi.PageBuilder; import com.facebook.presto.spi.block.Block; @@ -39,6 +39,7 @@ import com.facebook.presto.spi.type.BigintType; import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.gen.JoinFilterFunctionCompiler.JoinFilterFunctionFactory; +import com.facebook.presto.sql.planner.SortExpressionExtractor.SortExpression; import com.google.common.base.Throwables; import com.google.common.cache.CacheBuilder; import com.google.common.cache.CacheLoader; @@ -71,6 +72,7 @@ import static com.facebook.presto.bytecode.expression.BytecodeExpressions.constantLong; import static com.facebook.presto.bytecode.expression.BytecodeExpressions.constantNull; import static com.facebook.presto.bytecode.expression.BytecodeExpressions.constantTrue; +import static com.facebook.presto.bytecode.expression.BytecodeExpressions.newInstance; import static com.facebook.presto.bytecode.expression.BytecodeExpressions.notEqual; import static com.facebook.presto.sql.gen.SqlTypeBytecodeExpression.constantType; import static com.google.common.collect.ImmutableList.toImmutableList; @@ -87,7 +89,7 @@ public class JoinCompiler public LookupSourceSupplierFactory load(CacheKey key) throws Exception { - return internalCompileLookupSourceFactory(key.getTypes(), key.getOutputChannels(), key.getJoinChannels()); + return internalCompileLookupSourceFactory(key.getTypes(), key.getOutputChannels(), key.getJoinChannels(), key.getSortChannel()); } }); @@ -100,13 +102,13 @@ public LookupSourceSupplierFactory load(CacheKey key) public Class load(CacheKey key) throws Exception { - return internalCompileHashStrategy(key.getTypes(), key.getOutputChannels(), key.getJoinChannels()); + return internalCompileHashStrategy(key.getTypes(), key.getOutputChannels(), key.getJoinChannels(), key.getSortChannel()); } }); - public LookupSourceSupplierFactory compileLookupSourceFactory(List types, List joinChannels) 
+ public LookupSourceSupplierFactory compileLookupSourceFactory(List types, List joinChannels, Optional sortChannel) { - return compileLookupSourceFactory(types, joinChannels, Optional.empty()); + return compileLookupSourceFactory(types, joinChannels, sortChannel, Optional.empty()); } @Managed @@ -123,13 +125,14 @@ public CacheStatsMBean getHashStrategiesStats() return new CacheStatsMBean(hashStrategies); } - public LookupSourceSupplierFactory compileLookupSourceFactory(List types, List joinChannels, Optional> outputChannels) + public LookupSourceSupplierFactory compileLookupSourceFactory(List types, List joinChannels, Optional sortChannel, Optional> outputChannels) { try { return lookupSourceFactories.get(new CacheKey( types, outputChannels.orElse(rangeList(types.size())), - joinChannels)); + joinChannels, + sortChannel)); } catch (ExecutionException | UncheckedExecutionException | ExecutionError e) { throw Throwables.propagate(e.getCause()); @@ -151,7 +154,8 @@ public PagesHashStrategyFactory compilePagesHashStrategyFactory(List types return new PagesHashStrategyFactory(hashStrategies.get(new CacheKey( types, outputChannels.orElse(rangeList(types.size())), - joinChannels))); + joinChannels, + Optional.empty()))); } catch (ExecutionException | UncheckedExecutionException | ExecutionError e) { throw Throwables.propagate(e.getCause()); @@ -165,9 +169,9 @@ private List rangeList(int endExclusive) .collect(toImmutableList()); } - private LookupSourceSupplierFactory internalCompileLookupSourceFactory(List types, List outputChannels, List joinChannels) + private LookupSourceSupplierFactory internalCompileLookupSourceFactory(List types, List outputChannels, List joinChannels, Optional sortChannel) { - Class pagesHashStrategyClass = internalCompileHashStrategy(types, outputChannels, joinChannels); + Class pagesHashStrategyClass = internalCompileHashStrategy(types, outputChannels, joinChannels, sortChannel); Class joinHashSupplierClass = IsolatedClass.isolateClass( new 
DynamicClassLoader(getClass().getClassLoader()), @@ -179,7 +183,7 @@ private LookupSourceSupplierFactory internalCompileLookupSourceFactory(List internalCompileHashStrategy(List types, List outputChannels, List joinChannels) + private Class internalCompileHashStrategy(List types, List outputChannels, List joinChannels, Optional sortChannel) { CallSiteBinder callSiteBinder = new CallSiteBinder(); @@ -217,6 +221,7 @@ private Class internalCompileHashStrategy(List types, + List channelFields, + Optional sortChannel) + { + Parameter leftBlockIndex = arg("leftBlockIndex", int.class); + Parameter leftBlockPosition = arg("leftBlockPosition", int.class); + Parameter rightBlockIndex = arg("rightBlockIndex", int.class); + Parameter rightBlockPosition = arg("rightBlockPosition", int.class); + MethodDefinition compareMethod = classDefinition.declareMethod( + a(PUBLIC), + "compare", + type(int.class), + leftBlockIndex, + leftBlockPosition, + rightBlockIndex, + rightBlockPosition); + + if (!sortChannel.isPresent()) { + compareMethod.getBody() + .append(newInstance(UnsupportedOperationException.class)) + .throwObject(); + return; + } + + Variable thisVariable = compareMethod.getThis(); + + int index = sortChannel.get().getChannel(); + BytecodeExpression type = constantType(callSiteBinder, types.get(index)); + + BytecodeExpression leftBlock = thisVariable + .getField(channelFields.get(index)) + .invoke("get", Object.class, leftBlockIndex) + .cast(Block.class); + + BytecodeExpression rightBlock = thisVariable + .getField(channelFields.get(index)) + .invoke("get", Object.class, rightBlockIndex) + .cast(Block.class); + + BytecodeNode comparison = type.invoke("compareTo", int.class, leftBlock, leftBlockPosition, rightBlock, rightBlockPosition).ret(); + + compareMethod + .getBody() + .append(comparison); + } + private static BytecodeNode typeEquals( BytecodeExpression type, BytecodeExpression leftBlock, @@ -726,7 +780,7 @@ public LookupSourceSupplierFactory(Class joinHas { 
this.pagesHashStrategyFactory = pagesHashStrategyFactory; try { - constructor = joinHashSupplierClass.getConstructor(ConnectorSession.class, PagesHashStrategy.class, LongArrayList.class, List.class, Optional.class); + constructor = joinHashSupplierClass.getConstructor(Session.class, PagesHashStrategy.class, LongArrayList.class, List.class, Optional.class); } catch (NoSuchMethodException e) { throw Throwables.propagate(e); @@ -734,7 +788,7 @@ public LookupSourceSupplierFactory(Class joinHas } public LookupSourceSupplier createLookupSourceSupplier( - ConnectorSession session, + Session session, LongArrayList addresses, List> channels, Optional hashChannel, @@ -780,12 +834,14 @@ private static final class CacheKey private final List types; private final List outputChannels; private final List joinChannels; + private final Optional sortChannel; - private CacheKey(List types, List outputChannels, List joinChannels) + private CacheKey(List types, List outputChannels, List joinChannels, Optional sortChannel) { this.types = ImmutableList.copyOf(requireNonNull(types, "types is null")); this.outputChannels = ImmutableList.copyOf(requireNonNull(outputChannels, "outputChannels is null")); this.joinChannels = ImmutableList.copyOf(requireNonNull(joinChannels, "joinChannels is null")); + this.sortChannel = requireNonNull(sortChannel, "sortChannel is null"); } private List getTypes() @@ -803,10 +859,15 @@ private List getJoinChannels() return joinChannels; } + public Optional getSortChannel() + { + return sortChannel; + } + @Override public int hashCode() { - return Objects.hash(types, outputChannels, joinChannels); + return Objects.hash(types, outputChannels, joinChannels, sortChannel); } @Override @@ -821,7 +882,8 @@ public boolean equals(Object obj) CacheKey other = (CacheKey) obj; return Objects.equals(this.types, other.types) && Objects.equals(this.outputChannels, other.outputChannels) && - Objects.equals(this.joinChannels, other.joinChannels); + 
Objects.equals(this.joinChannels, other.joinChannels) && + Objects.equals(this.sortChannel, other.sortChannel); } } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/gen/JoinFilterFunctionCompiler.java b/presto-main/src/main/java/com/facebook/presto/sql/gen/JoinFilterFunctionCompiler.java index e73a10af7a931..27cb1976c38e6 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/gen/JoinFilterFunctionCompiler.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/gen/JoinFilterFunctionCompiler.java @@ -29,6 +29,7 @@ import com.facebook.presto.operator.StandardJoinFilterFunction; import com.facebook.presto.spi.ConnectorSession; import com.facebook.presto.spi.block.Block; +import com.facebook.presto.sql.planner.SortExpressionExtractor.SortExpression; import com.facebook.presto.sql.relational.CallExpression; import com.facebook.presto.sql.relational.LambdaDefinitionExpression; import com.facebook.presto.sql.relational.RowExpression; @@ -52,6 +53,7 @@ import java.lang.reflect.Constructor; import java.util.List; import java.util.Objects; +import java.util.Optional; import java.util.Set; import static com.facebook.presto.bytecode.Access.FINAL; @@ -90,7 +92,7 @@ public JoinFilterFunctionCompiler(Metadata metadata) public JoinFilterFunctionFactory load(JoinFilterCacheKey key) throws Exception { - return internalCompileFilterFunctionFactory(key.getFilter(), key.getLeftBlocksSize()); + return internalCompileFilterFunctionFactory(key.getFilter(), key.getLeftBlocksSize(), key.getSortChannel()); } }); @@ -101,15 +103,15 @@ public CacheStatsMBean getJoinFilterFunctionFactoryStats() return new CacheStatsMBean(joinFilterFunctionFactories); } - public JoinFilterFunctionFactory compileJoinFilterFunction(RowExpression filter, int leftBlocksSize) + public JoinFilterFunctionFactory compileJoinFilterFunction(RowExpression filter, int leftBlocksSize, Optional sortChannel) { - return joinFilterFunctionFactories.getUnchecked(new JoinFilterCacheKey(filter, 
leftBlocksSize)); + return joinFilterFunctionFactories.getUnchecked(new JoinFilterCacheKey(filter, leftBlocksSize, sortChannel)); } - private JoinFilterFunctionFactory internalCompileFilterFunctionFactory(RowExpression filterExpression, int leftBlocksSize) + private JoinFilterFunctionFactory internalCompileFilterFunctionFactory(RowExpression filterExpression, int leftBlocksSize, Optional sortChannel) { Class internalJoinFilterFunction = compileInternalJoinFilterFunction(filterExpression, leftBlocksSize); - return new IsolatedJoinFilterFunctionFactory(internalJoinFilterFunction); + return new IsolatedJoinFilterFunctionFactory(internalJoinFilterFunction, sortChannel); } private Class compileInternalJoinFilterFunction(RowExpression filterExpression, int leftBlocksSize) @@ -143,6 +145,7 @@ private void generateMethods(ClassDefinition classDefinition, CallSiteBinder cal CachedInstanceBinder cachedInstanceBinder = new CachedInstanceBinder(classDefinition, callSiteBinder); FieldDefinition sessionField = classDefinition.declareField(a(PRIVATE, FINAL), "session", ConnectorSession.class); + generateFilterMethod(classDefinition, callSiteBinder, cachedInstanceBinder, filter, leftBlocksSize, sessionField); generateConstructor(classDefinition, sessionField, cachedInstanceBinder); } @@ -160,9 +163,7 @@ private static void generateConstructor(ClassDefinition classDefinition, FieldDe .invokeConstructor(Object.class); body.append(thisVariable.setField(sessionField, sessionParameter)); - cachedInstanceBinder.generateInitializations(thisVariable, body); - body.ret(); } @@ -292,10 +293,14 @@ private static void generateToString(ClassDefinition classDefinition, CallSiteBi .retObject(); } - @FunctionalInterface public interface JoinFilterFunctionFactory { JoinFilterFunction create(ConnectorSession session, LongArrayList addresses, List> channels); + + default Optional getSortChannel() + { + return Optional.empty(); + } } private static RowExpressionVisitor fieldReferenceCompiler( @@ 
-316,11 +321,13 @@ private static final class JoinFilterCacheKey { private final RowExpression filter; private final int leftBlocksSize; + private final Optional sortChannel; - public JoinFilterCacheKey(RowExpression filter, int leftBlocksSize) + public JoinFilterCacheKey(RowExpression filter, int leftBlocksSize, Optional sortChannel) { this.filter = requireNonNull(filter, "filter can not be null"); this.leftBlocksSize = leftBlocksSize; + this.sortChannel = requireNonNull(sortChannel, "sortChannel can not be null"); } public RowExpression getFilter() @@ -333,6 +340,11 @@ public int getLeftBlocksSize() return leftBlocksSize; } + public Optional getSortChannel() + { + return sortChannel; + } + @Override public boolean equals(Object o) { @@ -368,9 +380,11 @@ private static class IsolatedJoinFilterFunctionFactory { private final Constructor internalJoinFilterFunctionConstructor; private final Constructor isolatedJoinFilterFunctionConstructor; + private final Optional sortChannel; - public IsolatedJoinFilterFunctionFactory(Class internalJoinFilterFunction) + public IsolatedJoinFilterFunctionFactory(Class internalJoinFilterFunction, Optional sortChannel) { + this.sortChannel = sortChannel; try { internalJoinFilterFunctionConstructor = internalJoinFilterFunction .getConstructor(ConnectorSession.class); @@ -379,7 +393,7 @@ public IsolatedJoinFilterFunctionFactory(Class getSortChannel() + { + return sortChannel; + } } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/LocalExecutionPlanner.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/LocalExecutionPlanner.java index 03d05f7278924..812419d3586a4 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/LocalExecutionPlanner.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/LocalExecutionPlanner.java @@ -103,6 +103,7 @@ import com.facebook.presto.sql.gen.JoinFilterFunctionCompiler.JoinFilterFunctionFactory; import 
com.facebook.presto.sql.parser.SqlParser; import com.facebook.presto.sql.planner.Partitioning.ArgumentBinding; +import com.facebook.presto.sql.planner.SortExpressionExtractor.SortExpression; import com.facebook.presto.sql.planner.optimizations.IndexJoinOptimizer; import com.facebook.presto.sql.planner.plan.AggregationNode; import com.facebook.presto.sql.planner.plan.AssignUniqueId; @@ -1600,6 +1601,8 @@ private JoinFilterFunctionFactory compileJoinFilterFunction( Expression rewrittenFilter = new SymbolToInputRewriter(joinSourcesLayout).rewrite(filterExpression); + Optional sortChannel = SortExpressionExtractor.extractSortExpression(buildLayout, rewrittenFilter); + IdentityLinkedHashMap expressionTypes = getExpressionTypesFromInput( session, metadata, @@ -1609,7 +1612,7 @@ private JoinFilterFunctionFactory compileJoinFilterFunction( emptyList() /* parameters have already been replaced */); RowExpression translatedFilter = toRowExpression(rewrittenFilter, expressionTypes); - return joinFilterFunctionCompiler.compileJoinFilterFunction(translatedFilter, buildLayout.size()); + return joinFilterFunctionCompiler.compileJoinFilterFunction(translatedFilter, buildLayout.size(), sortChannel); } private OperatorFactory createLookupJoin( diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/SortExpressionExtractor.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/SortExpressionExtractor.java new file mode 100644 index 0000000000000..2812022bf1b6e --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/SortExpressionExtractor.java @@ -0,0 +1,164 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner; + +import com.facebook.presto.sql.tree.AstVisitor; +import com.facebook.presto.sql.tree.ComparisonExpression; +import com.facebook.presto.sql.tree.Expression; +import com.facebook.presto.sql.tree.FieldReference; +import com.facebook.presto.sql.tree.Node; +import com.google.common.collect.ImmutableSet; + +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.Set; + +import static com.google.common.base.MoreObjects.toStringHelper; +import static java.util.Objects.requireNonNull; + +/** + * Currently this class handles only simple expressions like: + * + * A.a < B.x. + * + * It could be extended to handle any expressions like: + * + * A.a * sin(A.b) / log(B.x) < cos(B.z) + * + * by transforming it to: + * + * f(A.a, A.b) < g(B.x, B.z) + * + * Where f(...) and g(...) would be some functions/expressions. That + * would allow us to perform binary search on arbitrary complex expressions + * by sorting position links according to the result of f(...) function. 
+ */
+public final class SortExpressionExtractor
+{
+    private SortExpressionExtractor() {}
+
+    /**
+     * Looks for a build-side field that an inequality join filter could be sorted on.
+     *
+     * <p>The filter qualifies only when it is itself a {@code GREATER_THAN}, {@code GREATER_THAN_OR_EQUAL},
+     * {@code LESS_THAN} or {@code LESS_THAN_OR_EQUAL} comparison in which one operand is a bare
+     * {@link FieldReference} to a build field and the other operand contains no reference to any build field.
+     *
+     * @param buildLayout map whose values are the field indexes that belong to the build side
+     * @param filter the join filter expression to inspect
+     * @return the sort expression for the matching build field, or {@link Optional#empty()} when the filter does not qualify
+     */
+    public static Optional<SortExpression> extractSortExpression(Map<Symbol, Integer> buildLayout, Expression filter)
+    {
+        requireNonNull(buildLayout, "buildLayout is null");
+        Set<Integer> buildFields = ImmutableSet.copyOf(buildLayout.values());
+        if (!(filter instanceof ComparisonExpression)) {
+            return Optional.empty();
+        }
+
+        ComparisonExpression comparison = (ComparisonExpression) filter;
+        switch (comparison.getType()) {
+            case GREATER_THAN:
+            case GREATER_THAN_OR_EQUAL:
+            case LESS_THAN:
+            case LESS_THAN_OR_EQUAL:
+                // Try the right operand as the build field first; fall back to the left operand.
+                Optional<Integer> sortChannel = asBuildFieldReference(buildFields, comparison.getRight());
+                boolean hasBuildReferencesOnOtherSide = hasBuildFieldReference(buildFields, comparison.getLeft());
+                if (!sortChannel.isPresent()) {
+                    sortChannel = asBuildFieldReference(buildFields, comparison.getLeft());
+                    hasBuildReferencesOnOtherSide = hasBuildFieldReference(buildFields, comparison.getRight());
+                }
+                // The opposite operand must not touch the build side at all.
+                if (sortChannel.isPresent() && !hasBuildReferencesOnOtherSide) {
+                    return Optional.of(new SortExpression(sortChannel.get()));
+                }
+                return Optional.empty();
+            default:
+                return Optional.empty();
+        }
+    }
+
+    /**
+     * Returns the field index when {@code expression} is a bare {@link FieldReference} to a build field,
+     * otherwise {@link Optional#empty()}.
+     */
+    private static Optional<Integer> asBuildFieldReference(Set<Integer> buildLayout, Expression expression)
+    {
+        if (expression instanceof FieldReference) {
+            FieldReference field = (FieldReference) expression;
+            if (buildLayout.contains(field.getFieldIndex())) {
+                return Optional.of(field.getFieldIndex());
+            }
+        }
+        return Optional.empty();
+    }
+
+    /**
+     * Returns true when {@code expression} or any of its descendants is a {@link FieldReference} to a build field.
+     */
+    private static boolean hasBuildFieldReference(Set<Integer> buildLayout, Expression expression)
+    {
+        return new BuildFieldReferenceFinder(buildLayout).process(expression);
+    }
+
+    /**
+     * Visitor that reports whether a subtree contains a reference to one of the given build fields.
+     */
+    private static class BuildFieldReferenceFinder
+            extends AstVisitor<Boolean, Void>
+    {
+        private final Set<Integer> buildLayout;
+
+        public BuildFieldReferenceFinder(Set<Integer> buildLayout)
+        {
+            this.buildLayout = ImmutableSet.copyOf(requireNonNull(buildLayout, "buildLayout is null"));
+        }
+
+        @Override
+        protected Boolean visitNode(Node node, Void context)
+        {
+            // Stop at the first child subtree that contains a build field reference.
+            for (Node child : node.getChildren()) {
+                if (process(child, context)) {
+                    return true;
+                }
+            }
+            return false;
+        }
+
+        @Override
+        protected Boolean visitFieldReference(FieldReference fieldReference, Void context)
+        {
+            return buildLayout.contains(fieldReference.getFieldIndex());
+        }
+    }
+
+    /**
+     * Value object holding the build-side channel selected by {@link #extractSortExpression}.
+     */
+    public static class SortExpression
+    {
+        private final int channel;
+
+        public SortExpression(int channel)
+        {
+            this.channel = channel;
+        }
+
+        public int getChannel()
+        {
+            return channel;
+        }
+
+        @Override
+        public boolean equals(Object obj)
+        {
+            if (this == obj) {
+                return true;
+            }
+            if (obj == null || getClass() != obj.getClass()) {
+                return false;
+            }
+            SortExpression other = (SortExpression) obj;
+            // Primitive comparison; no need to box both ints through Objects.equals.
+            return channel == other.channel;
+        }
+
+        @Override
+        public int hashCode()
+        {
+            return Objects.hash(channel);
+        }
+
+        @Override
+        public String toString()
+        {
+            return toStringHelper(this)
+                    .add("channel", channel)
+                    .toString();
+        }
+    }
+}
diff --git a/presto-main/src/test/java/com/facebook/presto/operator/BenchmarkHashBuildAndJoinOperators.java b/presto-main/src/test/java/com/facebook/presto/operator/BenchmarkHashBuildAndJoinOperators.java index c3d218a81f9be..12a6a193891f5 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/BenchmarkHashBuildAndJoinOperators.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/BenchmarkHashBuildAndJoinOperators.java @@ -83,6 +83,9 @@ public static class BuildContext @Param({"false", "true"}) protected boolean buildHashEnabled; + @Param({"1", "5"}) + protected int buildRowsRepetition; + protected ExecutorService executor; protected List buildPages; protected Optional hashChannel; @@ -139,10 +142,11 @@ protected void initializeBuildPages() { RowPagesBuilder buildPagesBuilder = rowPagesBuilder(buildHashEnabled, hashChannels, ImmutableList.of(VARCHAR, BIGINT, BIGINT)); + int maxValue = BUILD_ROWS_NUMBER / buildRowsRepetition + 40; int rows = 0; while (rows < BUILD_ROWS_NUMBER) { int newRows = Math.min(BUILD_ROWS_NUMBER - rows, ROWS_PER_PAGE); 
- buildPagesBuilder.addSequencePage(newRows, rows + 20, rows + 30, rows + 40); + buildPagesBuilder.addSequencePage(newRows, (rows + 20) % maxValue, (rows + 30) % maxValue, (rows + 40) % maxValue); buildPagesBuilder.pageBreak(); rows += newRows; } diff --git a/presto-main/src/test/java/com/facebook/presto/operator/TestHashJoinOperator.java b/presto-main/src/test/java/com/facebook/presto/operator/TestHashJoinOperator.java index 75305b8b678c3..f6c6550be3ee9 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/TestHashJoinOperator.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/TestHashJoinOperator.java @@ -679,7 +679,7 @@ private static List getHashChannels(RowPagesBuilder probe, RowPagesBuil private static LookupSourceFactory buildHash(boolean parallelBuild, TaskContext taskContext, List hashChannels, RowPagesBuilder buildPages, Optional filterFunction) { Optional filterFunctionFactory = filterFunction - .map(function -> ((session, addresses, channels) -> new StandardJoinFilterFunction(function, addresses, channels))); + .map(function -> (session, addresses, channels) -> new StandardJoinFilterFunction(function, addresses, channels, Optional.empty())); int partitionCount = parallelBuild ? PARTITION_COUNT : 1; LocalExchange localExchange = new LocalExchange(FIXED_HASH_DISTRIBUTION, partitionCount, buildPages.getTypes(), hashChannels, buildPages.getHashChannel()); diff --git a/presto-main/src/test/java/com/facebook/presto/operator/TestPositionLinks.java b/presto-main/src/test/java/com/facebook/presto/operator/TestPositionLinks.java new file mode 100644 index 0000000000000..6fcaec594c499 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/operator/TestPositionLinks.java @@ -0,0 +1,149 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.operator;
+
+import com.facebook.presto.RowPagesBuilder;
+import com.facebook.presto.spi.Page;
+import it.unimi.dsi.fastutil.ints.IntComparator;
+import org.testng.annotations.Test;
+
+import java.util.Optional;
+import java.util.function.LongPredicate;
+
+import static com.facebook.presto.spi.type.BigintType.BIGINT;
+import static com.google.common.collect.Iterables.getOnlyElement;
+import static org.testng.Assert.assertEquals;
+
+public class TestPositionLinks
+{
+    // Single BIGINT page built with addSequencePage(20, 0); every start/next call below receives it as the page argument.
+    private static final Page TEST_PAGE = getOnlyElement(RowPagesBuilder.rowPagesBuilder(BIGINT).addSequencePage(20, 0).build());
+
+    /**
+     * Links two groups of positions (0..3 and 10..12) through {@link ArrayPositionLinks} and checks the
+     * order in which {@code start}/{@code next} walk each chain, plus an unlinked position (4).
+     */
+    @Test
+    public void testArrayPositionLinks()
+    {
+        PositionLinks.Builder builder = ArrayPositionLinks.builder(1000);
+
+        assertEquals(builder.link(1, 0), 1);
+        assertEquals(builder.link(2, 1), 2);
+        assertEquals(builder.link(3, 2), 3);
+
+        assertEquals(builder.link(11, 10), 11);
+        assertEquals(builder.link(12, 11), 12);
+
+        PositionLinks positionLinks = builder.build().apply(Optional.empty());
+
+        assertEquals(positionLinks.start(3, 0, TEST_PAGE), 3);
+        assertEquals(positionLinks.next(3, 0, TEST_PAGE), 2);
+        assertEquals(positionLinks.next(2, 0, TEST_PAGE), 1);
+        assertEquals(positionLinks.next(1, 0, TEST_PAGE), 0);
+
+        // Position 4 was never linked, so its chain ends immediately.
+        assertEquals(positionLinks.start(4, 0, TEST_PAGE), 4);
+        assertEquals(positionLinks.next(4, 0, TEST_PAGE), -1);
+
+        assertEquals(positionLinks.start(12, 0, TEST_PAGE), 12);
+        assertEquals(positionLinks.next(12, 0, TEST_PAGE), 11);
+        assertEquals(positionLinks.next(11, 0, TEST_PAGE), 10);
+    }
+
+    /**
+     * With a filter that accepts only values greater than 4, traversal of the chain 0..6 is expected to
+     * yield 5 and 6 only, while the chain 10..12 is returned in full.
+     */
+    @Test
+    public void testSortedPositionLinks()
+    {
+        JoinFilterFunction filterFunction = filterOnValue(value -> value > 4);
+
+        PositionLinks.Builder builder = buildSortedPositionLinks();
+        PositionLinks positionLinks = builder.build().apply(Optional.of(filterFunction));
+
+        assertEquals(positionLinks.start(0, 0, TEST_PAGE), 5);
+        assertEquals(positionLinks.next(5, 0, TEST_PAGE), 6);
+        assertEquals(positionLinks.next(6, 0, TEST_PAGE), -1);
+
+        assertEquals(positionLinks.start(10, 0, TEST_PAGE), 10);
+        assertEquals(positionLinks.next(10, 0, TEST_PAGE), 11);
+        assertEquals(positionLinks.next(11, 0, TEST_PAGE), 12);
+        assertEquals(positionLinks.next(12, 0, TEST_PAGE), -1);
+    }
+
+    /**
+     * With a filter that accepts only values less than 4, traversal of the chain 0..6 is expected to
+     * stop after position 3, and the chain 10..12 yields nothing.
+     */
+    @Test
+    public void testReverseSortedPositionLinks()
+    {
+        JoinFilterFunction filterFunction = filterOnValue(value -> value < 4);
+
+        PositionLinks.Builder builder = buildSortedPositionLinks();
+        PositionLinks positionLinks = builder.build().apply(Optional.of(filterFunction));
+
+        assertEquals(positionLinks.start(0, 0, TEST_PAGE), 0);
+        assertEquals(positionLinks.next(0, 0, TEST_PAGE), 1);
+        assertEquals(positionLinks.next(1, 0, TEST_PAGE), 2);
+        assertEquals(positionLinks.next(2, 0, TEST_PAGE), 3);
+        assertEquals(positionLinks.next(3, 0, TEST_PAGE), -1);
+
+        assertEquals(positionLinks.start(10, 0, TEST_PAGE), -1);
+    }
+
+    /**
+     * Builds a filter that applies {@code predicate} to the BIGINT read from block 0 of the supplied page
+     * at index {@code leftAddress}. {@code getSortChannel} is not supported by this test double.
+     */
+    private static JoinFilterFunction filterOnValue(LongPredicate predicate)
+    {
+        return new JoinFilterFunction()
+        {
+            @Override
+            public boolean filter(int leftAddress, int rightPosition, Page rightPage)
+            {
+                return predicate.test(BIGINT.getLong(rightPage.getBlock(0), leftAddress));
+            }
+
+            // NOTE(review): the Integer type argument is inferred from the int sort channel used elsewhere
+            // in this change; confirm against JoinFilterFunction.
+            @Override
+            public Optional<Integer> getSortChannel()
+            {
+                throw new UnsupportedOperationException();
+            }
+        };
+    }
+
+    /**
+     * Creates a {@link SortedPositionLinks} builder ordered by the BIGINT values in block 0 of
+     * {@code TEST_PAGE} and links positions 0..6 into one chain and 10..12 into another.
+     */
+    private static PositionLinks.Builder buildSortedPositionLinks()
+    {
+        SortedPositionLinks.Builder builder = SortedPositionLinks.builder(
+                1000,
+                new IntComparator() {
+                    @Override
+                    public int compare(int left, int right)
+                    {
+                        return BIGINT.compareTo(TEST_PAGE.getBlock(0), left, TEST_PAGE.getBlock(0), right);
+                    }
+
+                    @Override
+                    public int compare(Integer left, Integer right)
+                    {
+                        return compare(left.intValue(), right.intValue());
+                    }
+                });
+
+        assertEquals(builder.link(4, 5), 4);
+        assertEquals(builder.link(6, 4), 4);
+        assertEquals(builder.link(2, 4), 2);
+        assertEquals(builder.link(3, 2), 2);
+        assertEquals(builder.link(0, 2), 0);
+        assertEquals(builder.link(1, 0), 0);
+
+        assertEquals(builder.link(10, 11), 10);
+        assertEquals(builder.link(12, 10), 10);
+
+        return builder;
+    }
+}
diff --git a/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java b/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java index 0306ad7944bbb..4ea44f191d66c 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java @@ -40,6 +40,7 @@ public void testDefaults() .setResourceGroupsEnabled(false) .setDistributedIndexJoinsEnabled(false) .setDistributedJoinsEnabled(true) + .setFastInequalityJoins(true) .setColocatedJoinsEnabled(false) .setJoinReorderingEnabled(false) .setRedistributeWrites(true) @@ -78,6 +79,7 @@ public void testExplicitPropertyMappings() .put("deprecated.legacy-map-subscript", "true") .put("distributed-index-joins-enabled", "true") .put("distributed-joins-enabled", "false") + .put("fast-inequality-joins", "false") .put("colocated-joins-enabled", "true") .put("reorder-joins", "true") .put("redistribute-writes", "false") @@ -107,6 +109,7 @@ public void testExplicitPropertyMappings() .put("deprecated.legacy-map-subscript", "true") .put("distributed-index-joins-enabled", "true") .put("distributed-joins-enabled", "false") + .put("fast-inequality-joins", "false") .put("colocated-joins-enabled", "true") .put("reorder-joins", "true") .put("redistribute-writes", "false") @@ -134,6 +137,7 @@ public 
void testExplicitPropertyMappings() .setIterativeOptimizerTimeout(new Duration(10, SECONDS)) .setDistributedIndexJoinsEnabled(true) .setDistributedJoinsEnabled(false) + .setFastInequalityJoins(false) .setColocatedJoinsEnabled(true) .setJoinReorderingEnabled(true) .setRedistributeWrites(false) diff --git a/presto-main/src/test/java/com/facebook/presto/sql/gen/TestJoinCompiler.java b/presto-main/src/test/java/com/facebook/presto/sql/gen/TestJoinCompiler.java index 394f18af5793b..1e51c4f9fd1d6 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/gen/TestJoinCompiler.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/gen/TestJoinCompiler.java @@ -188,7 +188,7 @@ public void testMultiChannel(boolean hashEnabled) PagesHashStrategyFactory pagesHashStrategyFactory = joinCompiler.compilePagesHashStrategyFactory(types, joinChannels, Optional.of(outputChannels)); PagesHashStrategy hashStrategy = pagesHashStrategyFactory.createPagesHashStrategy(channels, hashChannel); // todo add tests for filter function - PagesHashStrategy expectedHashStrategy = new SimplePagesHashStrategy(types, outputChannels, channels, joinChannels, hashChannel); + PagesHashStrategy expectedHashStrategy = new SimplePagesHashStrategy(types, outputChannels, channels, joinChannels, hashChannel, Optional.empty()); // verify channel count assertEquals(hashStrategy.getChannelCount(), outputChannels.size()); diff --git a/presto-main/src/test/java/com/facebook/presto/sql/gen/TestJoinProbeCompiler.java b/presto-main/src/test/java/com/facebook/presto/sql/gen/TestJoinProbeCompiler.java index ad0f74034ad52..217287a031915 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/gen/TestJoinProbeCompiler.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/gen/TestJoinProbeCompiler.java @@ -84,7 +84,7 @@ public void testSingleChannel(boolean hashEnabled) ImmutableList types = ImmutableList.of(VARCHAR, DOUBLE); ImmutableList outputTypes = ImmutableList.of(VARCHAR); List outputChannels = 
ImmutableList.of(0); - LookupSourceSupplierFactory lookupSourceSupplierFactory = joinCompiler.compileLookupSourceFactory(types, Ints.asList(0)); + LookupSourceSupplierFactory lookupSourceSupplierFactory = joinCompiler.compileLookupSourceFactory(types, Ints.asList(0), Optional.empty()); // crate hash strategy with a single channel blocks -- make sure there is some overlap in values List varcharChannel = ImmutableList.of( @@ -118,7 +118,7 @@ public void testSingleChannel(boolean hashEnabled) outputTypes = ImmutableList.of(VARCHAR, BigintType.BIGINT); } LookupSource lookupSource = lookupSourceSupplierFactory.createLookupSourceSupplier( - taskContext.getSession().toConnectorSession(), + taskContext.getSession(), addresses, channels, hashChannel, diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestSortExpressionExtractor.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestSortExpressionExtractor.java new file mode 100644 index 0000000000000..63c6c181bbd0b --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestSortExpressionExtractor.java @@ -0,0 +1,88 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+package com.facebook.presto.sql.planner;
+
+import com.facebook.presto.sql.planner.SortExpressionExtractor.SortExpression;
+import com.facebook.presto.sql.tree.ArithmeticBinaryExpression;
+import com.facebook.presto.sql.tree.ComparisonExpression;
+import com.facebook.presto.sql.tree.ComparisonExpressionType;
+import com.facebook.presto.sql.tree.Expression;
+import com.facebook.presto.sql.tree.FieldReference;
+import com.facebook.presto.sql.tree.FunctionCall;
+import com.facebook.presto.sql.tree.QualifiedName;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+import org.testng.annotations.Test;
+
+import java.util.Map;
+import java.util.Optional;
+
+import static org.testng.AssertJUnit.assertEquals;
+
+public class TestSortExpressionExtractor
+{
+    // Build side owns fields 1 and 2; field 11 used in the cases below is therefore not a build field.
+    private static final Map<Symbol, Integer> BUILD_LAYOUT = ImmutableMap.of(
+            new Symbol("b1"), 1,
+            new Symbol("b2"), 2);
+
+    /**
+     * Covers comparisons with the build field on either side, plus two shapes that must be rejected.
+     */
+    @Test
+    public void testGetSortExpression()
+    {
+        // Build field on the right-hand side.
+        assertGetSortExpression(
+                new ComparisonExpression(
+                        ComparisonExpressionType.GREATER_THAN,
+                        new FieldReference(11),
+                        new FieldReference(1)),
+                1);
+
+        // Build field on the left-hand side.
+        assertGetSortExpression(
+                new ComparisonExpression(
+                        ComparisonExpressionType.LESS_THAN_OR_EQUAL,
+                        new FieldReference(2),
+                        new FieldReference(11)),
+                2);
+
+        assertGetSortExpression(
+                new ComparisonExpression(
+                        ComparisonExpressionType.GREATER_THAN,
+                        new FieldReference(2),
+                        new FieldReference(11)),
+                2);
+
+        // Rejected: the other operand (2 + 11) also references a build field.
+        assertNoSortExpression(
+                new ComparisonExpression(
+                        ComparisonExpressionType.GREATER_THAN,
+                        new FieldReference(1),
+                        new ArithmeticBinaryExpression(ArithmeticBinaryExpression.Type.ADD, new FieldReference(2), new FieldReference(11))));
+
+        // Rejected: the build field is wrapped in a function call rather than referenced directly.
+        assertNoSortExpression(
+                new ComparisonExpression(
+                        ComparisonExpressionType.GREATER_THAN,
+                        new FunctionCall(QualifiedName.of("sin"), ImmutableList.of(new FieldReference(1))),
+                        new FieldReference(11)));
+    }
+
+    /**
+     * Asserts that no sort expression is extracted from {@code expression}.
+     */
+    private static void assertNoSortExpression(Expression expression)
+    {
+        Optional<SortExpression> actual = SortExpressionExtractor.extractSortExpression(BUILD_LAYOUT, expression);
+        assertEquals(Optional.empty(), actual);
+    }
+
+    /**
+     * Asserts that the sort expression extracted from {@code expression} targets {@code expectedChannel}.
+     */
+    private static void assertGetSortExpression(Expression expression, int expectedChannel)
+    {
+        Optional<SortExpression> expected = Optional.of(new SortExpression(expectedChannel));
+        Optional<SortExpression> actual = SortExpressionExtractor.extractSortExpression(BUILD_LAYOUT, expression);
+        assertEquals(expected, actual);
+    }
+}
diff --git a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java index 36b6f7d8e2272..4042db12b06d6 100644 --- a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java +++ b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java @@ -2107,6 +2107,31 @@ public void testHistogram() assertQuery("SELECT lines, COUNT(*) FROM (SELECT orderkey, COUNT(*) lines FROM lineitem GROUP BY orderkey) U GROUP BY lines"); } + @Test + public void testJoinWithLessThanInJoinClause() + throws Exception + { + assertQuery("SELECT n.nationkey, r.regionkey FROM region r JOIN nation n ON n.regionkey = r.regionkey AND n.name < r.name"); + assertQuery("SELECT l.suppkey, n.nationkey, l.partkey, n.regionkey FROM nation n JOIN lineitem l ON l.suppkey = n.nationkey AND l.partkey < n.regionkey"); + } + + @Test + public void testJoinWithGreaterThanInJoinClause() + throws Exception + { + assertQuery("SELECT n.nationkey, r.regionkey FROM region r JOIN nation n ON n.regionkey = r.regionkey AND n.name > r.name AND r.regionkey = 0"); + assertQuery("SELECT l.suppkey, n.nationkey, l.partkey, n.regionkey FROM nation n JOIN lineitem l ON l.suppkey = n.nationkey AND l.partkey > n.regionkey"); + } + + @Test + public void testJoinWithLessThanOnDatesInJoinClause() + throws Exception + { + assertQuery( + "SELECT o.orderkey, o.orderdate, l.shipdate FROM orders o JOIN lineitem l ON l.orderkey = o.orderkey AND l.shipdate < o.orderdate + INTERVAL '10' 
DAY", + "SELECT o.orderkey, o.orderdate, l.shipdate FROM orders o JOIN lineitem l ON l.orderkey = o.orderkey AND l.shipdate < DATEADD('DAY', 10, o.orderdate)"); + } + @Test public void testSimpleJoin() {