Skip to content
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion deepspeed/ops/adam/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@
from .cpu_adam import DeepSpeedCPUAdam
from .fused_adam import FusedAdam
from .zenflow_cpu_adam import ZenFlowCPUAdam
from .zenflow_torch_adam import ZenFlowSelectiveAdamW
from .zenflow_torch_adam import ZenFlowSelectiveAdamW, ZenFlowSelectiveAdamW_stage3
269 changes: 236 additions & 33 deletions deepspeed/ops/adam/zenflow_torch_adam.py

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions deepspeed/runtime/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1861,6 +1861,7 @@ def _configure_zero_optimizer(self, optimizer):
overlap_comm=self.zero_overlap_comm(),
offload_optimizer_config=self.zero_offload_optimizer(),
offload_param_config=self.zero_offload_param(),
zenflow_config=self.zenflow_config(),
sub_group_size=self.zero_sub_group_size(),
offload_ratio=self.zero_partial_offload(),
mpu=self.mpu,
Expand Down
774 changes: 774 additions & 0 deletions deepspeed/runtime/zenflow/engine_stage3.py

Large diffs are not rendered by default.

22 changes: 22 additions & 0 deletions deepspeed/runtime/zero/parameter_offload.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ def __init__(
module,
timers,
ds_config,
zenflow=False,
overlap_comm=True,
prefetch_bucket_size=50000000,
max_reuse_distance=1000000000,
Expand All @@ -115,6 +116,7 @@ def __init__(

self.module = module
self.timers = timers
self.zenflow = zenflow
self.dtype = list(module.parameters())[0].dtype
self.dp_process_group = dp_process_group
self.offload_device = None
Expand Down Expand Up @@ -472,6 +474,11 @@ def pre_sub_module_forward_function(self, sub_module):
param_coordinator.record_module(sub_module)
param_coordinator.fetch_sub_module(sub_module, forward=True)

if self.zenflow:
params_to_fetch = set(iter_params(sub_module, recurse=z3_leaf_module(sub_module)))
for param in params_to_fetch:
param.data = param.data.t() if len(param.ds_shape) != 1 else param.data

see_memory_usage(f"Before sub module function {sub_module.__class__.__name__} after fetch", force=False)

@torch.no_grad()
Expand All @@ -480,6 +487,11 @@ def post_sub_module_forward_function(self, sub_module):
f"After sub module function {sub_module.__class__.__name__} {sub_module.ds_id} before release",
force=False)

if self.zenflow:
params_to_fetch = set(iter_params(sub_module, recurse=z3_leaf_module(sub_module)))
for param in params_to_fetch:
param.data = param.data.t() if len(param.ds_shape) != 1 else param.data

param_coordinator = self.get_param_coordinator()
param_coordinator.release_sub_module(sub_module, forward=True)

Expand All @@ -496,13 +508,23 @@ def pre_sub_module_backward_function(self, sub_module):
param_coordinator.record_module(sub_module)
param_coordinator.fetch_sub_module(sub_module, forward=False)

if self.zenflow:
params_to_fetch = set(iter_params(sub_module, recurse=z3_leaf_module(sub_module)))
for param in params_to_fetch:
param.data = param.data.t() if len(param.ds_shape) != 1 else param.data

@torch.no_grad()
def post_sub_module_backward_function(self, sub_module):
# assert sub_module.training, "backward pass is invalid for module in evaluation mode"
see_memory_usage(
f"After sub module backward function {sub_module.__class__.__name__} {sub_module.ds_id} before release",
force=False)

if self.zenflow:
params_to_fetch = set(iter_params(sub_module, recurse=z3_leaf_module(sub_module)))
for param in params_to_fetch:
param.data = param.data.t() if len(param.ds_shape) != 1 else param.data

self.get_param_coordinator().release_sub_module(sub_module, forward=False)

see_memory_usage(
Expand Down
61 changes: 58 additions & 3 deletions deepspeed/runtime/zero/stage3.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from deepspeed.runtime.zero.config import ZeroStageEnum
from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum, OffloadStateTypeEnum
from deepspeed.runtime.zero.parameter_offload import DeepSpeedZeRoOffload
import deepspeed.runtime.zenflow.engine_stage3 as zf_engine_stage3
from deepspeed.runtime.zero.utils import get_mapping_to_flat_buffer
from deepspeed.runtime.zero.offload_states import offload_adam_states, reload_adam_states
from deepspeed.ops.adam import DeepSpeedCPUAdam
Expand Down Expand Up @@ -160,6 +161,7 @@ def __init__(
overlap_comm=False,
offload_optimizer_config=None,
offload_param_config=None,
zenflow_config=None,
sub_group_size=1000000000000,
offload_ratio=0.0,
mpu=None,
Expand Down Expand Up @@ -226,6 +228,9 @@ def __init__(
self.partial_offload = offload_ratio
self.enable_sanity_checks = enable_sanity_checks

self.create_zenflow_hooks()
self._initialize_zenflow_stage3_prologue(module, zenflow_config)

#num of ranks in a ZeRO param partitioning group
self.zero_hpz_partition_size = zero_hpz_partition_size

Expand All @@ -241,6 +246,7 @@ def __init__(
module=module,
timers=timers,
ds_config=ds_config,
zenflow=self.zenflow,
overlap_comm=overlap_comm,
prefetch_bucket_size=prefetch_bucket_size,
max_reuse_distance=max_reuse_distance,
Expand Down Expand Up @@ -276,6 +282,8 @@ def __init__(
for i in range(1, len(self.optimizer.param_groups)):
self.backup_optimizer.add_param_group(self.optimizer.param_groups[i])

self._initialize_zenflow_stage3_epilogue(zenflow_config, overlap_comm)

self.module = module
self.elastic_checkpoint = elastic_checkpoint

Expand Down Expand Up @@ -476,11 +484,32 @@ def destroy(self):
print_rank_0("Removed grad acc hooks", force=False)
self.ipg_buckets.clear()

def create_zenflow_hooks(self):
from functools import partial
hook_names = [
"_initialize_zenflow_stage3_prologue",
"_initialize_zenflow_stage3_epilogue",
"zenflow_cpu_optimizer_step",
"_sync_selective_optimizer_lr",
"selective_optimizer_step",
"is_zenflow_select_boundary",
"update_selected_channels",
"_process_selected_fp32_groups_grad",
"zenflow_backward_prologue",
"zenflow_backward_epilogue",
"log_selective_optimizer_timers",
]

for name in hook_names:
fn = getattr(zf_engine_stage3, name)
setattr(self, name, partial(fn, self))

def initialize_ds_offload(
self,
module,
timers,
ds_config,
zenflow,
overlap_comm,
prefetch_bucket_size,
max_reuse_distance,
Expand All @@ -499,6 +528,7 @@ def initialize_ds_offload(
return DeepSpeedZeRoOffload(module=module,
timers=timers,
ds_config=ds_config,
zenflow=zenflow,
overlap_comm=overlap_comm,
prefetch_bucket_size=prefetch_bucket_size,
max_reuse_distance=max_reuse_distance,
Expand Down Expand Up @@ -738,6 +768,10 @@ def _create_fp16_partitions_with_defragmentation(self, fp16_param_groups):
self.fp16_groups.append(sub_group)
self.fp16_partitioned_groups.append([param.ds_tensor for param in sub_group])

if self.zenflow:
for param in sub_group:
param.group_id = param_group_idx

# record sub group -> group mapping
self.sub_group_to_group_id[sub_group_idx] = param_group_idx

Expand Down Expand Up @@ -997,7 +1031,10 @@ def step_with_gradscaler(optimizer):
self.torch_autocast_gradscaler.step(optimizer)
self.torch_autocast_gradscaler.update()
else:
optimizer.step()
if not self.zenflow:
optimizer.step()
else:
self.zenflow_cpu_optimizer_step()

if self.offload_optimizer:
cur_device = self.subgroup_to_device[sub_group_id]
Expand Down Expand Up @@ -1267,8 +1304,14 @@ def __add_grad_to_ipg_bucket(self, param: Parameter) -> None:
# move the gradient to a contiguous buffer
with get_accelerator().stream(self.reduce_and_partition_stream):
# move the parameter's gradient to the contiguous flat buffer
new_grad_tensor = bucket.buffer.narrow(0, bucket.elements, param.grad.numel()).view_as(param.grad)
new_grad_tensor.copy_(param.grad, non_blocking=True)
if self.zenflow and len(param.ds_shape) != 1:
transposed_shape = param.grad.t().shape
new_grad_tensor = bucket.buffer.narrow(0, bucket.elements,
param.grad.numel()).view(transposed_shape)
new_grad_tensor.copy_(param.grad.t().contiguous(), non_blocking=True)
else:
new_grad_tensor = bucket.buffer.narrow(0, bucket.elements, param.grad.numel()).view_as(param.grad)
new_grad_tensor.copy_(param.grad, non_blocking=True)
if not get_accelerator().is_synchronized_device():
param.grad.record_stream(get_accelerator().current_stream())
param.grad.data = new_grad_tensor
Expand Down Expand Up @@ -1308,6 +1351,12 @@ def __reduce_and_partition_ipg_grads(self, communication_data_type: torch.dtype)
params_in_bucket.sort(key=lambda p: p.ds_id)
grad_partitions = self.__avg_scatter_grads(params_in_bucket, communication_data_type)

if self.is_zenflow_select_boundary():
self.update_selected_channels(params_in_bucket, grad_partitions)

if self.zenflow and self.micro_step >= self.full_warm_up_rounds:
self._process_selected_fp32_groups_grad(params_in_bucket, grad_partitions)

self.partition_grads(params_in_bucket, grad_partitions)

params_in_bucket.clear()
Expand Down Expand Up @@ -2272,6 +2321,9 @@ def backward(self, loss, retain_graph=False):
if self.swap_optimizer:
self.optimizer_swapper.pre_backward()

if self.zenflow:
self.zenflow_backward_prologue()

see_memory_usage("Before backward", force=False)

if self.custom_loss_scaler:
Expand All @@ -2282,6 +2334,9 @@ def backward(self, loss, retain_graph=False):
else:
self.loss_scaler.backward(loss.float(), retain_graph=retain_graph)

if self.zenflow:
self.zenflow_backward_epilogue()

if self.swap_optimizer:
self.optimizer_swapper.post_backward()

Expand Down
4 changes: 2 additions & 2 deletions tests/unit/runtime/zenflow/test_zf.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def run_training_distributed(self, config_dict):
model.destroy()


@pytest.mark.parametrize("stage", [1, 2])
@pytest.mark.parametrize("stage", [1, 2, 3])
@pytest.mark.parametrize("full_warm_up_rounds", [0, 3])
@pytest.mark.parametrize("offload_selective_optimizer", [True, False])
@pytest.mark.parametrize("select_strategy,select_interval,update_interval", [
Expand All @@ -93,7 +93,7 @@ def test_zenflow_single_gpu(self, stage, offload_selective_optimizer, select_str
tester.run_training_distributed(config_dict)


@pytest.mark.parametrize("stage", [1, 2])
@pytest.mark.parametrize("stage", [1, 2, 3])
@pytest.mark.parametrize("full_warm_up_rounds", [0, 3])
@pytest.mark.parametrize("offload_selective_optimizer", [True, False])
@pytest.mark.parametrize("select_strategy,select_interval,update_interval", [
Expand Down
Loading