@@ -468,11 +468,12 @@ static final class StaticStrideScheduler {
468468 double sumWeight = 0 ;
469469 float maxWeight = 0 ;
470470 for (float weight : weights ) {
471- if (Math . abs ( weight - 0.0 ) < 0.0001 ) { // just equal to 0?
471+ if (weight < 0.0001 ) { // just equal to 0?
472472 numZeroWeightChannels ++;
473+ } else {
474+ sumWeight += weight ;
475+ maxWeight = Math .max (weight , maxWeight );
473476 }
474- sumWeight += weight ;
475- maxWeight = Math .max (weight , maxWeight );
476477 }
477478
478479 checkArgument (numChannels >= 1 , "Couldn't build scheduler: requires at least one weight" );
@@ -484,7 +485,7 @@ static final class StaticStrideScheduler {
484485 // scales weights s.t. max(weights) == K_MAX_WEIGHT, meanWeight is scaled accordingly
485486 int [] scaledWeights = new int [numChannels ];
486487 for (int i = 0 ; i < numChannels ; i ++) {
487- if (Math . abs ( weights [i ]) < 0.0001 ) { // just equal to 0?
488+ if (weights [i ] < 0.0001 ) { // just equal to 0?
488489 scaledWeights [i ] = meanWeight ;
489490 } else {
490491 scaledWeights [i ] = (int ) Math .round (weights [i ] * scalingFactor );
@@ -512,7 +513,7 @@ int pickChannel() {
512513 // is this really that much more efficient than a 2d array?
513514 long weight = this .scaledWeights [backendIndex ];
514515 if ((weight * generation ) % K_MAX_WEIGHT < K_MAX_WEIGHT - weight ) {
515- // wow how does this work/how was it discovered
516+ // how does this work/how was it discovered
516517 continue ;
517518 }
518519 return backendIndex ;
0 commit comments