Skip to content

Commit b5cf7b0

Browse files
implemented static stride scheduler class/algorithm (currently unused)
go/static-stride-scheduler
1 parent 0f2c43a commit b5cf7b0

File tree

1 file changed

+108
-0
lines changed

1 file changed

+108
-0
lines changed

xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancer.java

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,37 @@ private void updateWeight() {
310310
}
311311
this.scheduler = scheduler;
312312
}
313+
314+
// check correctness of behavior
315+
private void updateWeightSSS() {
316+
int weightedChannelCount = 0;
317+
double avgWeight = 0;
318+
for (Subchannel value : list) {
319+
double newWeight = ((WrrSubchannel) value).getWeight();
320+
if (newWeight > 0) {
321+
avgWeight += newWeight;
322+
weightedChannelCount++;
323+
}
324+
}
325+
326+
if (weightedChannelCount >= 1) {
327+
avgWeight /= 1.0 * weightedChannelCount;
328+
} else {
329+
avgWeight = 1;
330+
}
331+
332+
List<Float> newWeights = new ArrayList<>();
333+
for (int i = 0; i < list.size(); i++) {
334+
WrrSubchannel subchannel = (WrrSubchannel) list.get(i);
335+
double newWeight = subchannel.getWeight();
336+
newWeights.add(
337+
i,
338+
newWeight > 0 ? (float) newWeight : (float) avgWeight); // check desired type (float?)
339+
}
340+
341+
StaticStrideScheduler ssScheduler = new StaticStrideScheduler(newWeights);
342+
this.ssScheduler = ssScheduler;
343+
}
313344

314345
@Override
315346
public String toString() {
@@ -432,6 +463,83 @@ int pick() {
432463
}
433464
}
434465

466+
// TODO: add javadocs comments
467+
@VisibleForTesting
468+
static final class StaticStrideScheduler {
469+
private Vector<Long> scaledWeights;
470+
private int sizeDivisor;
471+
private long sequence;
472+
private static final int K_MAX_WEIGHT = 65535; // uint16? can be uint8
473+
private static final long UINT32_MAX = 429967295L; // max value for uint32
474+
475+
StaticStrideScheduler(List<Float> weights) {
476+
int numChannels = weights.size();
477+
int numZeroWeightChannels = 0;
478+
double sumWeight = 0;
479+
float maxWeight = 0;
480+
for (float weight : weights) {
481+
if (weight == 0) {
482+
numZeroWeightChannels++;
483+
}
484+
sumWeight += weight;
485+
maxWeight = Math.max(weight, maxWeight);
486+
}
487+
488+
// checkArgument(numChannels <= 1, "Couldn't build scheduler: requires at least two weights");
489+
// checkArgument(numZeroWeightChannels == numChannels, "Couldn't build scheduler: only zero
490+
// weights");
491+
492+
double scalingFactor = K_MAX_WEIGHT / maxWeight;
493+
long meanWeight =
494+
Math.round(scalingFactor * sumWeight / (numChannels - numZeroWeightChannels));
495+
496+
// scales weights s.t. max(weights) == K_MAX_WEIGHT, meanWeight is scaled accordingly
497+
Vector<Long> scaledWeights = new Vector<>(numChannels);
498+
// vectors are deprecated post Java 9?
499+
for (int i = 0; i < numChannels; i++) {
500+
if (weights.get(i) == 0) {
501+
scaledWeights.add(meanWeight);
502+
} else {
503+
scaledWeights.add(Math.round(weights.get(i) * scalingFactor));
504+
}
505+
}
506+
507+
this.scaledWeights = scaledWeights;
508+
this.sizeDivisor = numChannels; // why not just call it numChannels or numBackends
509+
this.sequence = (long) (Math.random() * UINT32_MAX); // why not initialize to [0,numChannels)
510+
}
511+
512+
private long nextSequence() {
513+
long sequence = this.sequence;
514+
this.sequence = (this.sequence + 1) % UINT32_MAX; // check wraparound logic
515+
return sequence;
516+
}
517+
518+
// is this getter necessary? (is it ever called outside of this class)
519+
public Vector<Long> getWeights() {
520+
return this.scaledWeights;
521+
}
522+
523+
private void addChannel() {}
524+
525+
// selects index of our next backend server
526+
int pickChannel() {
527+
while (true) {
528+
long sequence = this.nextSequence();
529+
int backendIndex = (int) sequence % this.sizeDivisor;
530+
long generation = sequence / this.sizeDivisor;
531+
// is this really that much more efficient than a 2d array?
532+
long weight = this.scaledWeights.get(backendIndex);
533+
long kOffset = K_MAX_WEIGHT / 2;
534+
long mod = (weight * generation + backendIndex * kOffset) % K_MAX_WEIGHT;
535+
if (mod < K_MAX_WEIGHT - weight) { // review this math
536+
continue;
537+
}
538+
return backendIndex;
539+
}
540+
}
541+
}
542+
435543
/** Holds the state of the object. */
436544
@VisibleForTesting
437545
static class ObjectState {

0 commit comments

Comments
 (0)