Skip to content

Commit ba1a0b7

Browse files
fixed pickChannel() by removing kOffset, fixed atomic integer
1 parent acd3425 commit ba1a0b7

File tree

1 file changed

+20
-17
lines changed

1 file changed

+20
-17
lines changed

xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancer.java

Lines changed: 20 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -458,18 +458,18 @@ int pick() {
458458
@VisibleForTesting
459459
static final class StaticStrideScheduler {
460460
private final long[] scaledWeights;
461-
private final AtomicInteger sizeDivisor;
462-
private long sequence;
461+
private final int sizeDivisor;
462+
private final AtomicInteger sequence;
463463
private static final int K_MAX_WEIGHT = 65535; // uint16? can be uint8
464-
private static final long UINT32_MAX = 429967295L; // max value for uint32
464+
private static final int UINT32_MAX = 429967295; // max value for uint32
465465

466466
StaticStrideScheduler(float[] weights) {
467467
int numChannels = weights.length;
468468
int numZeroWeightChannels = 0;
469469
double sumWeight = 0;
470470
float maxWeight = 0;
471471
for (float weight : weights) {
472-
if (Math.abs(weight - 0.0) < 0.0001) {
472+
if (Math.abs(weight - 0.0) < 0.0001) { // just equal to 0?
473473
numZeroWeightChannels++;
474474
}
475475
sumWeight += weight;
@@ -479,41 +479,44 @@ static final class StaticStrideScheduler {
479479
checkArgument(numChannels >= 1, "Couldn't build scheduler: requires at least one weight");
480480

481481
double scalingFactor = K_MAX_WEIGHT / maxWeight;
482+
if (numZeroWeightChannels == numChannels) {
483+
System.out.println("ALL 0 WEIGHT CHANNELS");
484+
}
482485
long meanWeight = numZeroWeightChannels == numChannels ? 1 :
483486
Math.round(scalingFactor * sumWeight / (numChannels - numZeroWeightChannels));
484487

485488
// scales weights s.t. max(weights) == K_MAX_WEIGHT, meanWeight is scaled accordingly
486489
long[] scaledWeights = new long[numChannels];
487490
for (int i = 0; i < numChannels; i++) {
488-
if (weights[i] == 0) {
491+
if (Math.abs(weights[i]) < 0.0001) { // just equal to 0?
489492
scaledWeights[i] = meanWeight;
490493
} else {
491494
scaledWeights[i] = Math.round(weights[i] * scalingFactor);
492495
}
493496
}
494497

495498
this.scaledWeights = scaledWeights;
496-
this.sizeDivisor = new AtomicInteger(numChannels); // why not call numChannels or numBackends
497-
this.sequence = (long) (Math.random() * UINT32_MAX); // why not initialize to [0,numChannels)
499+
this.sizeDivisor = numChannels;
500+
// this.sequence = new AtomicInteger((int) (Math.random() * UINT32_MAX));
501+
this.sequence = new AtomicInteger(0);
502+
// optimization: isn't sequence guaranteed to be in the first generation?
503+
// failing test cases when initialized to non-zero value
498504
}
499505

500-
private long nextSequence() {
501-
long sequence = this.sequence;
502-
this.sequence = (this.sequence + 1) % UINT32_MAX; // check wraparound logic
503-
return sequence;
506+
private int nextSequence() {
507+
return this.sequence.getAndUpdate(seq -> ((seq + 1) % UINT32_MAX));
504508
}
505509

506510
// selects index of our next backend server
507511
int pickChannel() {
508512
while (true) {
509-
long sequence = this.nextSequence();
510-
int backendIndex = (int) sequence % this.sizeDivisor.get();
511-
long generation = sequence / this.sizeDivisor.get();
513+
int sequence = this.nextSequence();
514+
int backendIndex = sequence % this.sizeDivisor;
515+
long generation = sequence / this.sizeDivisor;
512516
// is this really that much more efficient than a 2d array?
513517
long weight = this.scaledWeights[backendIndex];
514-
long kOffset = K_MAX_WEIGHT / 2;
515-
long mod = (weight * generation + backendIndex * kOffset) % K_MAX_WEIGHT;
516-
if (mod < K_MAX_WEIGHT - weight) { // review this math
518+
if ((weight * generation) % K_MAX_WEIGHT < K_MAX_WEIGHT - weight) {
519+
// wow how does this work/how was it discovered
517520
continue;
518521
}
519522
return backendIndex;

0 commit comments

Comments
 (0)