|
| 1 | +/* |
| 2 | + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. |
| 3 | + * |
| 4 | + * Licensed under the Apache License, Version 2.0 (the "License"). |
| 5 | + * You may not use this file except in compliance with the License. |
| 6 | + * You may obtain a copy of the License at |
| 7 | + * |
| 8 | + * http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | + * |
| 10 | + * Unless required by applicable law or agreed to in writing, software |
| 11 | + * distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | + * See the License for the specific language governing permissions and |
| 14 | + * limitations under the License. |
| 15 | + */ |
| 16 | + |
| 17 | +package software.amazon.jdbc; |
| 18 | + |
| 19 | +import java.sql.SQLException; |
| 20 | +import java.util.Comparator; |
| 21 | +import java.util.HashMap; |
| 22 | +import java.util.List; |
| 23 | +import java.util.Map; |
| 24 | +import java.util.Properties; |
| 25 | +import java.util.Random; |
| 26 | +import java.util.concurrent.locks.ReentrantLock; |
| 27 | +import java.util.regex.Matcher; |
| 28 | +import java.util.regex.Pattern; |
| 29 | +import java.util.stream.Collectors; |
| 30 | +import org.checkerframework.checker.nullness.qual.NonNull; |
| 31 | +import org.checkerframework.checker.nullness.qual.Nullable; |
| 32 | +import software.amazon.jdbc.hostavailability.HostAvailability; |
| 33 | +import software.amazon.jdbc.util.Messages; |
| 34 | + |
| 35 | +public class WeightedRandomHostSelector implements HostSelector { |
| 36 | + public static final AwsWrapperProperty WEIGHTED_RANDOM_HOST_WEIGHT_PAIRS = new AwsWrapperProperty( |
| 37 | + "weightedRandomHostWeightPairs", null, |
| 38 | + "Comma separated list of database host-weight pairs in the format of `<host>:<weight>`."); |
| 39 | + public static final String STRATEGY_WEIGHTED_RANDOM = "weightedRandom"; |
| 40 | + static final int DEFAULT_WEIGHT = 1; |
| 41 | + static final Pattern HOST_WEIGHT_PAIRS_PATTERN = |
| 42 | + Pattern.compile("((?<host>[^:/?#]*):(?<weight>[0-9]*))"); |
| 43 | + |
| 44 | + private Map<String, Integer> cachedHostWeightMap; |
| 45 | + private String cachedHostWeightMapString; |
| 46 | + private Random random; |
| 47 | + |
| 48 | + private final ReentrantLock lock = new ReentrantLock(); |
| 49 | + |
| 50 | + public WeightedRandomHostSelector() { |
| 51 | + this(new Random()); |
| 52 | + } |
| 53 | + |
| 54 | + public WeightedRandomHostSelector(final Random random) { |
| 55 | + this.random = random; |
| 56 | + } |
| 57 | + |
| 58 | + public HostSpec getHost( |
| 59 | + @NonNull List<HostSpec> hosts, |
| 60 | + @NonNull HostRole role, |
| 61 | + @Nullable Properties props) throws SQLException { |
| 62 | + |
| 63 | + final Map<String, Integer> hostWeightMap = |
| 64 | + this.getHostWeightPairMap(WEIGHTED_RANDOM_HOST_WEIGHT_PAIRS.getString(props)); |
| 65 | + |
| 66 | + // Get and check eligible hosts |
| 67 | + final List<HostSpec> eligibleHosts = hosts.stream() |
| 68 | + .filter(hostSpec -> |
| 69 | + role.equals(hostSpec.getRole()) && hostSpec.getAvailability().equals(HostAvailability.AVAILABLE)) |
| 70 | + .sorted(Comparator.comparing(HostSpec::getHost)) |
| 71 | + .collect(Collectors.toList()); |
| 72 | + |
| 73 | + if (eligibleHosts.isEmpty()) { |
| 74 | + throw new SQLException(Messages.get("HostSelector.noHostsMatchingRole", new Object[] {role})); |
| 75 | + } |
| 76 | + |
| 77 | + final Map<String, NumberRange> hostWeightRangeMap = new HashMap<>(); |
| 78 | + int counter = 1; |
| 79 | + for (HostSpec host : eligibleHosts) { |
| 80 | + if (!hostWeightMap.containsKey(host.getHost())) { |
| 81 | + continue; |
| 82 | + } |
| 83 | + final int hostWeight = hostWeightMap.get(host.getHost()); |
| 84 | + if (hostWeight > 0) { |
| 85 | + final int rangeStart = counter; |
| 86 | + final int rangeEnd = counter + hostWeight - 1; |
| 87 | + hostWeightRangeMap.put(host.getHost(), new NumberRange(rangeStart, rangeEnd)); |
| 88 | + counter = counter + hostWeight; |
| 89 | + } else { |
| 90 | + hostWeightRangeMap.put(host.getHost(), new NumberRange(counter, counter)); |
| 91 | + counter++; |
| 92 | + } |
| 93 | + } |
| 94 | + |
| 95 | + if (this.random == null) { |
| 96 | + this.random = new Random(); |
| 97 | + } |
| 98 | + int randomInt = this.random.nextInt(counter); |
| 99 | + |
| 100 | + // Check random number is in host weight range map |
| 101 | + for (final HostSpec host : eligibleHosts) { |
| 102 | + NumberRange range = hostWeightRangeMap.get(host.getHost()); |
| 103 | + if (range != null && range.isInRange(randomInt)) { |
| 104 | + return host; |
| 105 | + } |
| 106 | + } |
| 107 | + |
| 108 | + throw new SQLException(Messages.get("HostSelector.weightedRandomUnableToGetHost", new Object[] {role})); |
| 109 | + } |
| 110 | + |
| 111 | + private Map<String, Integer> getHostWeightPairMap(final String hostWeightMapString) throws SQLException { |
| 112 | + try { |
| 113 | + lock.lock(); |
| 114 | + if (this.cachedHostWeightMapString != null |
| 115 | + && this.cachedHostWeightMapString.trim().equals(hostWeightMapString.trim()) |
| 116 | + && this.cachedHostWeightMap != null |
| 117 | + && !this.cachedHostWeightMap.isEmpty()) { |
| 118 | + return this.cachedHostWeightMap; |
| 119 | + } |
| 120 | + |
| 121 | + final Map<String, Integer> hostWeightMap = new HashMap<>(); |
| 122 | + if (hostWeightMapString == null || hostWeightMapString.trim().isEmpty()) { |
| 123 | + return hostWeightMap; |
| 124 | + } |
| 125 | + final String[] hostWeightPairs = hostWeightMapString.split(","); |
| 126 | + for (final String hostWeightPair : hostWeightPairs) { |
| 127 | + final Matcher matcher = HOST_WEIGHT_PAIRS_PATTERN.matcher(hostWeightPair); |
| 128 | + if (!matcher.matches()) { |
| 129 | + throw new SQLException(Messages.get("HostSelector.weightedRandomInvalidHostWeightPairs")); |
| 130 | + } |
| 131 | + |
| 132 | + final String hostName = matcher.group("host").trim(); |
| 133 | + final String hostWeight = matcher.group("weight").trim(); |
| 134 | + if (hostName.isEmpty() || hostWeight.isEmpty()) { |
| 135 | + throw new SQLException(Messages.get("HostSelector.weightedRandomInvalidHostWeightPairs")); |
| 136 | + } |
| 137 | + |
| 138 | + try { |
| 139 | + final int weight = Integer.parseInt(hostWeight); |
| 140 | + if (weight < DEFAULT_WEIGHT) { |
| 141 | + throw new SQLException(Messages.get("HostSelector.weightedRandomInvalidHostWeightPairs")); |
| 142 | + } |
| 143 | + hostWeightMap.put(hostName, weight); |
| 144 | + } catch (NumberFormatException e) { |
| 145 | + throw new SQLException(Messages.get("HostSelector.roundRobinInvalidHostWeightPairs")); |
| 146 | + } |
| 147 | + } |
| 148 | + this.cachedHostWeightMap = hostWeightMap; |
| 149 | + this.cachedHostWeightMapString = hostWeightMapString; |
| 150 | + return hostWeightMap; |
| 151 | + } finally { |
| 152 | + lock.unlock(); |
| 153 | + } |
| 154 | + } |
| 155 | + |
| 156 | + private static class NumberRange { |
| 157 | + private int start; |
| 158 | + private int end; |
| 159 | + |
| 160 | + public NumberRange(int start, int end) { |
| 161 | + this.start = start; |
| 162 | + this.end = end; |
| 163 | + } |
| 164 | + |
| 165 | + public boolean isInRange(int value) { |
| 166 | + return start <= value && value <= end; |
| 167 | + } |
| 168 | + } |
| 169 | +} |
0 commit comments