2 changes: 1 addition & 1 deletion deepspeed/runtime/swap_tensor/async_swapper.py
@@ -70,7 +70,7 @@ def _report_statistics(self, message):
         if dist.get_rank() == 0:
             element_size = torch.tensor([], dtype=self.dtype).element_size()
             swapped_GB = (self.num_elements_swapped * element_size) / (1024**3)
-            logger.info(
+            logger.debug(
                 f'{message} num_elems = {self.num_elements_swapped}, {swapped_GB:5.2f} GB'
             )

@@ -305,7 +305,7 @@ def swap_in(self, params, async_op=True, swap_in_buffers=None):
                 f'Num inflight: params {len(self.inflight_params)}, buffers {len(self.inflight_swap_in_buffers)}, numel = {self.inflight_numel}',
                 force=True)
             print_rank_0(
-                f'Num available: param {len(self.available_params)}, numel = {self.available_numel}',
+                f'Num available params: count = {len(self.available_params)}, ids = {self.available_params}, numel = {self.available_numel}',
                 force=True)

         assert len(swap_in_paths) <= len(self.available_buffer_ids), f"Not enough buffers {len(self.available_buffer_ids)} for swapping {len(swap_in_paths)}"
13 changes: 12 additions & 1 deletion deepspeed/runtime/zero/partitioned_param_coordinator.py
@@ -421,12 +421,23 @@ def __params_to_release(self,
         params_to_release = set(p.ds_id for p in iter_params(submodule_to_release)
                                 if not p.ds_persist)

+        # Problem: When prefetcher scans the param trace, it skips AVAILABLE params.
+        # This creates issues if those params are released before the skipped uses:
+        # 1) It hurts performance as the skipped uses are never prefetched.
+        # 2) For nvme params, we run out of swap buffers because the prefetch order
+        # diverges from the trace.
+        # Solution: Don't release params whose reuse was skipped by prefetch. This is
+        # possible because we detect such skips during prefetch and mark those params.
+        for param in iter_params(submodule_to_release):
+            if self.__most_recent_step_id_param_fetched_for[param] > step_id:
+                params_to_release.discard(param.ds_id)
+
         # examine all modules within `max_reuse_dist_in_numel` of the current step,
         # if we see any of the candidate parameters to be released reoccur while
         # doing this, remove them from the set of parameters to release.
         params_traversed = 0
         for module in self.__submodule_order[step_id:]:
-            if params_traversed > self.__max_reuse_dist_in_numel:
+            if params_traversed >= self.__max_reuse_dist_in_numel:
                 break
             for param in iter_params(module):
                 params_to_release.discard(param.ds_id)
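The new comment and loop above boil down to a simple rule: if the prefetcher has already fetched a parameter for a step later than the one currently being released, keep that parameter resident instead of releasing it. A rough, self-contained sketch of that rule follows, using plain dicts and ints with hypothetical names rather than DeepSpeed's coordinator API:

# Sketch of the "keep params whose reuse was skipped by prefetch" rule.
from typing import Dict, Iterable, Set

def releasable_param_ids(step_param_ids: Iterable[int],
                         step_id: int,
                         persistent_ids: Set[int],
                         most_recent_fetch_step: Dict[int, int]) -> Set[int]:
    """Return ids of the submodule's params that are safe to release at `step_id`."""
    params = list(step_param_ids)
    candidates = {pid for pid in params if pid not in persistent_ids}
    for pid in params:
        # Prefetch already pulled this param in for a later use; releasing it now
        # would undo that work and, for NVMe params, can exhaust swap buffers.
        if most_recent_fetch_step.get(pid, -1) > step_id:
            candidates.discard(pid)
    return candidates

# Param 7 was last fetched for step 12 while step 10 is being released, so it
# stays resident; param 3 was last fetched for step 10 and can be released.
assert releasable_param_ids([3, 7], 10, set(), {3: 10, 7: 12}) == {3}

The reuse-distance scan shown later in the hunk (now with `>=` instead of `>`) is a separate, pre-existing check that also keeps params which reoccur within `max_reuse_dist_in_numel` of the current step.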
47 changes: 28 additions & 19 deletions tests/unit/test_zero.py
@@ -486,35 +486,42 @@ def __init__(

         self.loss = L1Loss(reduction="none")

-    def forward(self, x: Tensor, y: Tensor, prefetching: bool) -> Dict[str, Tensor]:
+    def forward(self,
+                x: Tensor,
+                y: Tensor,
+                use_module_trace: bool,
+                param_prefetching: bool) -> Dict[str,
+                                                 Tensor]:
         _assert_partition_status(
             self,
             {
                 ZeroParamStatus.NOT_AVAILABLE,
                 ZeroParamStatus.INFLIGHT,
                 ZeroParamStatus.AVAILABLE
-            } if prefetching else {ZeroParamStatus.NOT_AVAILABLE})
+            } if use_module_trace else {ZeroParamStatus.NOT_AVAILABLE})

-        layerwise_expected_states = {
-            ZeroParamStatus.INFLIGHT if prefetching else ZeroParamStatus.NOT_AVAILABLE,
+        pre_layer_expected_states = {
+            ZeroParamStatus.INFLIGHT
+            if param_prefetching else ZeroParamStatus.NOT_AVAILABLE,
             ZeroParamStatus.AVAILABLE,
         }

-        _assert_partition_status(self.__layer1, layerwise_expected_states)
+        post_layer_expected_states = {
+            ZeroParamStatus.AVAILABLE
+            if param_prefetching else ZeroParamStatus.NOT_AVAILABLE,
+        }
+
+        _assert_partition_status(self.__layer1, pre_layer_expected_states)
         hidden1 = self.__layer1(x)
-        _assert_partition_status(self.__layer1, {ZeroParamStatus.NOT_AVAILABLE})
+        _assert_partition_status(self.__layer1, post_layer_expected_states)

-        _assert_partition_status(self.__layer2, layerwise_expected_states)
+        _assert_partition_status(self.__layer2, pre_layer_expected_states)
         hidden2 = self.__layer2(hidden1)
-        _assert_partition_status(self.__layer2, {ZeroParamStatus.NOT_AVAILABLE})
+        _assert_partition_status(self.__layer2, post_layer_expected_states)

-        _assert_partition_status(self.__layer3, layerwise_expected_states)
+        _assert_partition_status(self.__layer3, pre_layer_expected_states)
         y_hat = self.__layer3(hidden2)
-        _assert_partition_status(self.__layer3,
-                                 {
-                                     ZeroParamStatus.AVAILABLE
-                                     if prefetching else ZeroParamStatus.NOT_AVAILABLE
-                                 })
+        _assert_partition_status(self.__layer3, post_layer_expected_states)

         loss = self.loss(y_hat, y)

@@ -524,7 +531,7 @@ def forward(self, x: Tensor, y: Tensor, prefetching: bool) -> Dict[str, Tensor]:
                 ZeroParamStatus.NOT_AVAILABLE,
                 ZeroParamStatus.INFLIGHT,
                 ZeroParamStatus.AVAILABLE
-            } if prefetching else {ZeroParamStatus.NOT_AVAILABLE})
+            } if use_module_trace else {ZeroParamStatus.NOT_AVAILABLE})

         return {
             "hidden1": hidden1,
@@ -539,14 +546,14 @@ def forward(self, x: Tensor, y: Tensor, prefetching: bool) -> Dict[str, Tensor]:
 @pytest.mark.parametrize("contiguous_gradients", [True, False])
 @pytest.mark.parametrize("offload_optimizer", [True, False])
 @pytest.mark.parametrize("zero_grad", [True, False])
-@pytest.mark.parametrize("iteration", list(range(1)))
+@pytest.mark.parametrize("prefetching", [True, False])
 def test_zero3_param_partitioning_base(
         param_persistence_threshold: int,
         fp16_enabled: bool,
         contiguous_gradients: bool,
         offload_optimizer: bool,
         zero_grad: bool,
-        iteration: int,
+        prefetching: bool,
 ) -> None:
     @distributed_test(world_size=[2])
     def _test_zero3_param_partitioning():
@@ -557,14 +564,15 @@ def _test_zero3_param_partitioning():
         n = 5
         weights = [Parameter(torch.zeros((m, n), dtype=torch.float32)) for _ in range(3)]
         model = EltwiseMultiplicationTestNetwork(*weights)
-
+        prefetch_bucket_size = sum([p.numel() for p in model.parameters(recurse=True)])
         cfg = {
             "train_micro_batch_size_per_gpu": 1,
             "zero_optimization": {
                 "stage": 3,
                 "stage3_max_reuse_distance": 0,
                 "stage3_param_persistence_threshold": param_persistence_threshold,
                 "contiguous_gradients": contiguous_gradients,
+                "stage3_prefetch_bucket_size": prefetch_bucket_size if prefetching else 0
             },
             "optimizer": {
                 "type": "Adam",
@@ -672,7 +680,8 @@ def create_tensor(vals, dtype: torch.dtype = None) -> Tensor:
                     n),
                 dtype=torch.float16 if fp16_enabled else torch.float32,
                 device=ds_engine.device),
-            prefetching=train_iter > 0,
+            use_module_trace=train_iter > 0,
+            param_prefetching=prefetching and train_iter > 0,
         )
         assert torch.allclose(activations["hidden1"], expected_hidden1)
         assert torch.allclose(activations["hidden2"], expected_hidden2)