diff --git a/presto-main/src/main/java/com/facebook/presto/execution/scheduler/ConsistentHashingNodeProvider.java b/presto-main/src/main/java/com/facebook/presto/execution/scheduler/ConsistentHashingNodeProvider.java index 4907e7bf5a6ef..ffa831b35a20a 100644 --- a/presto-main/src/main/java/com/facebook/presto/execution/scheduler/ConsistentHashingNodeProvider.java +++ b/presto-main/src/main/java/com/facebook/presto/execution/scheduler/ConsistentHashingNodeProvider.java @@ -20,9 +20,12 @@ import com.google.common.hash.HashFunction; import java.util.Collection; +import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.NavigableMap; +import java.util.Set; +import java.util.SortedMap; import java.util.TreeMap; import static com.facebook.presto.common.type.encoding.StringUtils.UTF_8; @@ -35,6 +38,7 @@ public class ConsistentHashingNodeProvider { private static final HashFunction HASH_FUNCTION = murmur3_32(); private final NavigableMap candidates; + private final int nodeCount; static ConsistentHashingNodeProvider create(Collection nodes, int weight) { @@ -44,27 +48,49 @@ static ConsistentHashingNodeProvider create(Collection nodes, int activeNodesByConsistentHashing.put(murmur3_32().hashString(format("%s%d", node.getNodeIdentifier(), i), UTF_8).asInt(), node); } } - return new ConsistentHashingNodeProvider(activeNodesByConsistentHashing); + return new ConsistentHashingNodeProvider(activeNodesByConsistentHashing, nodes.size()); } - public ConsistentHashingNodeProvider(NavigableMap candidates) + private ConsistentHashingNodeProvider(NavigableMap candidates, int nodeCount) { this.candidates = requireNonNull(candidates, "candidates is null"); + this.nodeCount = nodeCount; } @Override public List get(String key, int count) { + if (count > nodeCount) { + count = nodeCount; + } ImmutableList.Builder nodes = ImmutableList.builder(); - for (int i = 0; i < count; i++) { - int hashKey = HASH_FUNCTION.hashString(format("%s%d", key, i), UTF_8).asInt(); - Map.Entry entry = candidates.ceilingEntry(hashKey); - if (entry != null) { - nodes.add(candidates.ceilingEntry(hashKey).getValue().getHostAndPort()); - } - else { - nodes.add(candidates.firstEntry().getValue().getHostAndPort()); + Set unique = new HashSet<>(); + int hashKey = HASH_FUNCTION.hashString(format("%s", key), UTF_8).asInt(); + Map.Entry entry = candidates.ceilingEntry(hashKey); + HostAddress candidate; + SortedMap nextEntries; + if (entry != null) { + candidate = entry.getValue().getHostAndPort(); + nextEntries = candidates.tailMap(entry.getKey(), false); + } + else { + candidate = candidates.firstEntry().getValue().getHostAndPort(); + nextEntries = candidates.tailMap(candidates.firstKey(), false); + } + unique.add(candidate); + nodes.add(candidate); + while (unique.size() < count) { + for (Map.Entry next : nextEntries.entrySet()) { + candidate = next.getValue().getHostAndPort(); + if (!unique.contains(candidate)) { + unique.add(candidate); + nodes.add(candidate); + if (unique.size() == count) { + break; + } + } } + nextEntries = candidates; } return nodes.build(); } diff --git a/presto-main/src/main/java/com/facebook/presto/execution/scheduler/ModularHashingNodeProvider.java b/presto-main/src/main/java/com/facebook/presto/execution/scheduler/ModularHashingNodeProvider.java index 827d9773e074f..5b3809fbe2e0a 100644 --- a/presto-main/src/main/java/com/facebook/presto/execution/scheduler/ModularHashingNodeProvider.java +++ b/presto-main/src/main/java/com/facebook/presto/execution/scheduler/ModularHashingNodeProvider.java @@ -44,6 +44,9 @@ public List get(String identifier, int count) int mod = identifier.hashCode() % size; int position = mod < 0 ? mod + size : mod; List chosenCandidates = new ArrayList<>(); + if (count > size) { + count = size; + } for (int i = 0; i < count && i < sortedCandidates.size(); i++) { chosenCandidates.add(sortedCandidates.get((position + i) % size).getHostAndPort()); } diff --git a/presto-main/src/test/java/com/facebook/presto/execution/TestNodeScheduler.java b/presto-main/src/test/java/com/facebook/presto/execution/TestNodeScheduler.java index 86d3c52e2d5f2..a3ba1499eb429 100644 --- a/presto-main/src/test/java/com/facebook/presto/execution/TestNodeScheduler.java +++ b/presto-main/src/test/java/com/facebook/presto/execution/TestNodeScheduler.java @@ -94,6 +94,7 @@ import static com.facebook.presto.spi.schedule.NodeSelectionStrategy.HARD_AFFINITY; import static com.facebook.presto.spi.schedule.NodeSelectionStrategy.NO_PREFERENCE; import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static java.lang.String.format; import static java.util.Objects.requireNonNull; import static java.util.concurrent.Executors.newCachedThreadPool; import static java.util.concurrent.Executors.newScheduledThreadPool; @@ -611,11 +612,11 @@ public void testAffinityAssignmentWithConsistentHashing() // In setup node 1-3 are added to node manager SplitPlacementResult splitPlacementResult = nodeSelector.computeAssignments(splits, ImmutableList.of()); assertEquals(splitPlacementResult.getAssignments().keySet().size(), 3); - // node1: split 2, 3, 5, 6, 9 + // node1: split 1, 3, 4, 5, 6, 7, 8, 9 Collection node1Splits = splitPlacementResult.getAssignments().get(node1).stream().map(Split::getConnectorSplit).collect(toImmutableSet()); - // node2: split 1, 7 + // node2: 0 Collection node2Splits = splitPlacementResult.getAssignments().get(node2).stream().map(Split::getConnectorSplit).collect(toImmutableSet()); - // node3: split 0, 4, 8 + // node3: split 2 Collection node3Splits = splitPlacementResult.getAssignments().get(node3).stream().map(Split::getConnectorSplit).collect(toImmutableSet()); // Scheduling the same splits on the same set of nodes should give the same assignment @@ -669,22 +670,22 @@ public void testAffinityAssignmentWithConsistentHashingWithVirtualNodes() // entry5 ( 2145381619): node1 SplitPlacementResult splitPlacementResult = nodeSelector.computeAssignments(splits, ImmutableList.of()); // hashing value for splits: - // 0: -1879591379 -> entry1 -> node3 - // 1: -2031875777 -> entry0 -> node2 - // 2: -163077544 -> entry3 -> node3 - // 3: 749129358 -> entry3 -> node3 - // 4: -1784631546 -> entry1 -> node3 - // 5: -118156056 -> entry3 -> node3 - // 6: 388471277 -> entry3 -> node3 - // 7: -2084245305 -> entry0 -> node2 - // 8: -1127017311 -> entry1 -> node3 - // 9: 1305218356 -> entry5 -> node1 + // 0: -1962219106 -> entry0 -> node2 + // 1: 145569539 -> entry3 -> node3 + // 2: -1599101205 -> entry1 -> node3 + // 3: -165119218 -> entry3 -> node3 + // 4: 1142216720 -> entry4 -> node1 + // 5: 1347620135 -> entry5 -> node1 + // 6: 1232195252 -> entry5 -> node1 + // 7: 427886318 -> entry3 -> node3 + // 8: 1469878697 -> entry5 -> node1 + // 9: 296801082 -> entry3 -> node3 assertEquals(splitPlacementResult.getAssignments().keySet().size(), 3); - // node1: split 9 + // node1: split 4, 5, 6, 8 Collection node1Splits = splitPlacementResult.getAssignments().get(node1).stream().map(Split::getConnectorSplit).collect(toImmutableSet()); - // node2: split 1, 7 + // node2: split 0 Collection node2Splits = splitPlacementResult.getAssignments().get(node2).stream().map(Split::getConnectorSplit).collect(toImmutableSet()); - // node3: split 0, 2, 3, 4, 5, 6, 8 + // node3: split 1, 2, 3, 7, 9 Collection node3Splits = splitPlacementResult.getAssignments().get(node3).stream().map(Split::getConnectorSplit).collect(toImmutableSet()); // Scheduling the same splits on the same set of nodes should give the same assignment @@ -707,21 +708,21 @@ public void testAffinityAssignmentWithConsistentHashingWithVirtualNodes() nodeSelector = nodeScheduler.createNodeSelector(session, CONNECTOR_ID, 3); splitPlacementResult = nodeSelector.computeAssignments(splits, ImmutableList.of()); // hashing value for splits: - // 0: -1879591379 -> entry1 -> node4 - // 1: -2031875777 -> entry0 -> node2 - // 2: -163077544 -> entry4 -> node3 - // 3: 749129358 -> entry4 -> node3 - // 4: -1784631546 -> entry1 -> node4 - // 5: -118156056 -> entry4 -> node3 - // 6: 388471277 -> entry4 -> node3 - // 7: -2084245305 -> entry0 -> node2 - // 8: -1127017311 -> entry2 -> node3 - // 9: 1305218356 -> entry6 -> node4 - assertEquals(splitPlacementResult.getAssignments().keySet().size(), 3); - assertEquals(splitPlacementResult.getAssignments().get(node1), ImmutableSet.of()); + // 0: -1962219106 -> entry0 -> node2 + // 1: 145569539 -> entry4 -> node3 + // 2: -1599101205 -> entry2 -> node3 + // 3: -165119218 -> entry4 -> node3 + // 4: 1142216720 -> entry5 -> node1 + // 5: 1347620135 -> entry6 -> node4 + // 6: 1232195252 -> entry6 -> node4 + // 7: 427886318 -> entry4 -> node3 + // 8: 1469878697 -> entry6 -> node4 + // 9: 296801082 -> entry4 -> node3 + assertEquals(splitPlacementResult.getAssignments().keySet().size(), 4); + assertEquals(splitPlacementResult.getAssignments().get(node1).stream().map(Split::getConnectorSplit).map(ConnectorSplit::getSplitIdentifier).collect(toImmutableSet()), ImmutableSet.of(4)); assertEquals(splitPlacementResult.getAssignments().get(node2).stream().map(Split::getConnectorSplit).collect(toImmutableSet()), node2Splits); - assertEquals(splitPlacementResult.getAssignments().get(node3).stream().map(Split::getConnectorSplit).map(ConnectorSplit::getSplitIdentifier).collect(toImmutableSet()), ImmutableSet.of(2, 3, 5, 6, 8)); - assertEquals(splitPlacementResult.getAssignments().get(node4).stream().map(Split::getConnectorSplit).map(ConnectorSplit::getSplitIdentifier).collect(toImmutableSet()), ImmutableSet.of(0, 4, 9)); + assertEquals(splitPlacementResult.getAssignments().get(node3).stream().map(Split::getConnectorSplit).map(ConnectorSplit::getSplitIdentifier).collect(toImmutableSet()), ImmutableSet.of(1, 2, 3, 7, 9)); + assertEquals(splitPlacementResult.getAssignments().get(node4).stream().map(Split::getConnectorSplit).map(ConnectorSplit::getSplitIdentifier).collect(toImmutableSet()), ImmutableSet.of(5, 6, 8)); } @Test @@ -1315,7 +1316,7 @@ public NodeSelectionStrategy getNodeSelectionStrategy() @Override public List getPreferredNodes(NodeProvider nodeProvider) { - return nodeProvider.get(String.valueOf(scheduleIdentifierId), 1); + return nodeProvider.get(format("split%d", scheduleIdentifierId), 1); } @Override diff --git a/presto-main/src/test/java/com/facebook/presto/execution/scheduler/TestConsistentHashingNodeProvider.java b/presto-main/src/test/java/com/facebook/presto/execution/scheduler/TestConsistentHashingNodeProvider.java index 188c6e6c79acb..55063fcf6dedd 100644 --- a/presto-main/src/test/java/com/facebook/presto/execution/scheduler/TestConsistentHashingNodeProvider.java +++ b/presto-main/src/test/java/com/facebook/presto/execution/scheduler/TestConsistentHashingNodeProvider.java @@ -16,6 +16,7 @@ import com.facebook.presto.client.NodeVersion; import com.facebook.presto.metadata.InternalNode; import com.facebook.presto.spi.HostAddress; +import com.google.common.collect.ImmutableSet; import org.testng.annotations.Test; import java.net.URI; @@ -26,7 +27,10 @@ import java.util.stream.IntStream; import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.ImmutableSet.toImmutableSet; import static java.lang.String.format; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertNotEquals; import static org.testng.Assert.assertTrue; public class TestConsistentHashingNodeProvider @@ -39,10 +43,24 @@ public void testDistribution() Random random = new Random(); Map result = new HashMap<>(); for (int i = 0; i < 1_000_000; i++) { - HostAddress hostAddress = nodeProvider.get(format("split%d", random.nextInt()), 1).get(0); + List candidates = nodeProvider.get(format("split%d", random.nextInt()), 2); + assertNotEquals(candidates.get(1), candidates.get(0)); + HostAddress hostAddress = candidates.get(0); int count = result.getOrDefault(hostAddress, 0); result.put(hostAddress, count + 1); } assertTrue(result.values().stream().allMatch(count -> count >= 80000 && count <= 120000)); } + + @Test + public void testMultipleCandidates() + { + List nodes = IntStream.range(0, 10).mapToObj(i -> new InternalNode(format("other%d", i), URI.create(format("http://127.0.0.%d:100", i)), NodeVersion.UNKNOWN, false)).collect(toImmutableList()); + ConsistentHashingNodeProvider nodeProvider = ConsistentHashingNodeProvider.create(nodes, 1); + assertEquals(ImmutableSet.copyOf(nodeProvider.get("split1", 10)), nodes.stream().map(InternalNode::getHostAndPort).collect(toImmutableSet())); + assertEquals(ImmutableSet.copyOf(nodeProvider.get("split1", 11)), nodes.stream().map(InternalNode::getHostAndPort).collect(toImmutableSet())); + + ConsistentHashingNodeProvider nodeProviderWithWeight = ConsistentHashingNodeProvider.create(nodes, 100); + assertEquals(ImmutableSet.copyOf(nodeProvider.get("split1", 10)), nodes.stream().map(InternalNode::getHostAndPort).collect(toImmutableSet())); + } }