Skip to content

Commit e831863

Browse files
Spread layers more uniformly when using partition_uniform (#4053)
* update partition_uniform util function * formatting --------- Co-authored-by: Olatunji Ruwase <[email protected]>
1 parent 1ba4098 commit e831863

File tree

1 file changed

+10
-5
lines changed

1 file changed

+10
-5
lines changed

deepspeed/runtime/utils.py

+10-5
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
import psutil
1515
import gc
1616
from math import sqrt
17-
from math import floor
1817
from bisect import bisect_left
1918
from packaging import version as pkg_version
2019

@@ -552,17 +551,23 @@ def prefix_sum_inc(weights):
552551

553552

554553
def partition_uniform(num_items, num_parts):
554+
import numpy
555555
parts = [0] * (num_parts + 1)
556556
# First check for the trivial edge case
557557
if num_items <= num_parts:
558558
for p in range(num_parts + 1):
559559
parts[p] = min(p, num_items)
560560
return parts
561561

562-
chunksize = floor(num_items / num_parts)
563-
for p in range(num_parts):
564-
parts[p] = min(chunksize * p, num_items)
565-
parts[num_parts] = num_items
562+
chunksize = num_items // num_parts
563+
residual = num_items - (chunksize * num_parts)
564+
565+
parts = numpy.arange(0, (num_parts + 1) * chunksize, chunksize)
566+
567+
for i in range(residual):
568+
parts[i + 1:] += 1
569+
parts = parts.tolist()
570+
566571
return parts
567572

568573

0 commit comments

Comments
 (0)