From 3309b49d8c177cd491b35fadb3ed17be09adcbd1 Mon Sep 17 00:00:00 2001 From: Tingfeng Lan Date: Wed, 25 Jun 2025 02:32:32 -0400 Subject: [PATCH 01/26] Add ZenFlow optimizers (zero stage 1&2) for ZeRO integration - Add ZenFlowCPUAdam and ZenFlowSelectiveAdamW for selective updates - Implement ZenFlowZeroOptimizer and its parallel variant - Support gradient offloading and communication overlap - Implement (un)flatten ops for column-major layout Signed-off-by: Tingfeng Lan Co-authored-by: Yusen Wu --- deepspeed/ops/adam/__init__.py | 2 + deepspeed/ops/adam/zenflow_cpu_adam.py | 138 +++ deepspeed/ops/adam/zenflow_torch_adam.py | 883 +++++++++++++++++ deepspeed/runtime/zero/stage_1_and_2.py | 33 +- deepspeed/runtime/zero/zenflow/__init__.py | 4 + .../zero/zenflow/zenflow_stage_1_and_2.py | 929 ++++++++++++++++++ .../runtime/zero/zenflow/zenflow_utils.py | 42 + 7 files changed, 2024 insertions(+), 7 deletions(-) create mode 100644 deepspeed/ops/adam/zenflow_cpu_adam.py create mode 100644 deepspeed/ops/adam/zenflow_torch_adam.py create mode 100644 deepspeed/runtime/zero/zenflow/__init__.py create mode 100644 deepspeed/runtime/zero/zenflow/zenflow_stage_1_and_2.py create mode 100644 deepspeed/runtime/zero/zenflow/zenflow_utils.py diff --git a/deepspeed/ops/adam/__init__.py b/deepspeed/ops/adam/__init__.py index a29bb9447d01..039106a1fb84 100755 --- a/deepspeed/ops/adam/__init__.py +++ b/deepspeed/ops/adam/__init__.py @@ -5,3 +5,5 @@ from .cpu_adam import DeepSpeedCPUAdam from .fused_adam import FusedAdam +from .zenflow_cpu_adam import ZenFlowCPUAdam +from .zenflow_torch_adam import ZenFlowSelectiveAdamW, ZenFlowSelectiveAdamW_stage3 diff --git a/deepspeed/ops/adam/zenflow_cpu_adam.py b/deepspeed/ops/adam/zenflow_cpu_adam.py new file mode 100644 index 000000000000..dedebc53559f --- /dev/null +++ b/deepspeed/ops/adam/zenflow_cpu_adam.py @@ -0,0 +1,138 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from deepspeed.ops.adam import DeepSpeedCPUAdam +import torch + + +class ZenFlowCPUAdam(DeepSpeedCPUAdam): + + def __init__(self, *args, overlap_step=False, **kwargs): + super(ZenFlowCPUAdam, self).__init__(*args, **kwargs) + self.overlap_step = overlap_step + if self.overlap_step: + print("ZenFlowCPUAdam initialized with overlap step.") + self.step = self._sequential_step + else: + print("ZenFlowCPUAdam initialized with normal step.") + self.step = self._parallel_step + + @torch.no_grad() + def _sequential_step(self, step_id, closure=None): + """Update the model parameters. + + .. note:: + This method will be called internally by ZeRO-Offload. DeepSpeed + users should still use ``engine.step()`` as shown in the + `Getting Started + `_ guide. + + Args: + closure (callable, optional): closure to compute the loss. + Defaults to ``None``. + + Returns: + loss: if ``closure`` is provided. Otherwise ``None``. + """ + + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + # intended device for step + device = torch.device('cpu') + + for group_id, group in enumerate(self.param_groups): + for param_id, p in enumerate(group['params']): + + if p.grad is None: + continue + + assert p.device == device, f"CPUAdam param is on {p.device} and must be 'cpu', make " \ + "sure you enabled 'offload_optimizer': 'cpu' in your ZeRO config." 
+ + state = self.state[p] + # State initialization + if len(state) == 0: + #print(f'group {group_id} param {param_id} = {p.numel()}') + state['step'] = 0 + + #use full precision by default unless self.fp32_optimizer_states is off + state_dtype = torch.float if self.fp32_optimizer_states else p.dtype + + # gradient momentums + state['exp_avg'] = torch.zeros_like(p.data, dtype=state_dtype, device=device) + #memory_format=torch.preserve_format) + # gradient variances + state['exp_avg_sq'] = torch.zeros_like(p.data, dtype=state_dtype, device=device) + #memory_format=torch.preserve_format) + + state['step'] = step_id + beta1, beta2 = group['betas'] + self.ds_opt_adam.adam_update(self.opt_id, state['step'], group['lr'], beta1, beta2, group['eps'], + group['weight_decay'], group['bias_correction'], p.data, p.grad.data, + state['exp_avg'], state['exp_avg_sq']) + return loss + + @torch.no_grad() + def _parallel_step(self, step_id, now_state, group_info, closure=None): + """Update the model parameters. + + .. note:: + This method will be called internally by ZeRO-Offload. DeepSpeed + users should still use ``engine.step()`` as shown in the + `Getting Started + `_ guide. + + Args: + closure (callable, optional): closure to compute the loss. + Defaults to ``None``. + + Returns: + loss: if ``closure`` is provided. Otherwise ``None``. + """ + + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + # intended device for step + device = torch.device('cpu') + + stale_param = None + + for group_id, group in enumerate(self.param_groups): + for param_id, p in enumerate(group['params']): + assert p.data.is_shared(), "param.data must be in shared memory" + if not hasattr(p, 'overlap_grad'): + continue + + assert p.device == device, f"CPUAdam param is on {p.device} and must be 'cpu', make " \ + "sure you enabled 'offload_optimizer': 'cpu' in your ZeRO config." + + state = self.state[p] + # State initialization + if len(state) == 0: + #print(f'group {group_id} param {param_id} = {p.numel()}') + # print("creating", flush=True) + state['step'] = 0 + + #use full precision by default unless self.fp32_optimizer_states is off + state_dtype = torch.float if self.fp32_optimizer_states else p.dtype + exp_avg = torch.zeros_like(p.data, dtype=state_dtype, device=device) + exp_avg_sq = torch.zeros_like(p.data, dtype=state_dtype, device=device) + state['exp_avg'] = [exp_avg, exp_avg.clone()] + state['exp_avg_sq'] = [exp_avg_sq, exp_avg_sq.clone()] + + state['step'] = step_id + beta1, beta2 = group_info['betas'] + self.ds_opt_adam.adam_update(self.opt_id, state['step'], group_info['lr'], beta1, beta2, + group_info['eps'], group_info['weight_decay'], + group_info['bias_correction'], p.data, p.overlap_grad[now_state].data, + state['exp_avg'][now_state], state['exp_avg_sq'][now_state]) + p.stale_param.data.copy_(p.data.clone()) + return loss diff --git a/deepspeed/ops/adam/zenflow_torch_adam.py b/deepspeed/ops/adam/zenflow_torch_adam.py new file mode 100644 index 000000000000..f682427ec1de --- /dev/null +++ b/deepspeed/ops/adam/zenflow_torch_adam.py @@ -0,0 +1,883 @@ +# Copyright (c) Microsoft Corporation. 
+# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch +from typing import cast, List, Optional, Tuple, Union +from torch import Tensor + +from torch.optim.optimizer import ( + _default_to_fused_or_foreach, + _disable_dynamo_if_unsupported, + _get_capturable_supported_devices, + _get_value, + _stack_if_compiling, + _view_as_real, + DeviceDict, + Optimizer, +) + + +class ZenFlowSelectiveAdamW(torch.optim.AdamW): + + def __init__(self, *args, offload=False, bucket_size=5e8, **kwargs): + super(ZenFlowSelectiveAdamW, self).__init__(*args, **kwargs) + + if offload: + self.step = self._step_with_offload + self.temp_copy_param = self._temp_copy_param_with_offload + self.group_step = self._group_step_with_offload + self.bucket_size = bucket_size + else: + self.step = self._step_without_offload + self.temp_copy_param = self._temp_copy_param_without_offload + self.group_step = self._group_step_without_offload + + @torch.no_grad() + def _temp_copy_param_with_offload(self, group_to_paramlist): + for group_id, params in group_to_paramlist.items(): + for param in params: + if hasattr(param, "selected_grad"): + temp_selected_param = param.data[:, param.selected_indices].clone().detach() if len( + param.shape) != 1 else param.data.clone().detach() + param.temp_selected_param = temp_selected_param.cpu() + + @torch.no_grad() + def _temp_copy_param_without_offload(self, group_to_paramlist): + for group_id, params in group_to_paramlist.items(): + for param in params: + if hasattr(param, "selected_grad"): + param.temp_selected_param = param.data[:, param.selected_indices].clone().detach() if len( + param.shape) != 1 else param.data.clone().detach() + + def copy_mv_from_cpu(self, params): + for param in params: + param.exp_avg = param.exp_avg_cpu_data.to(param.device, non_blocking=True) + param.exp_avg_sq = param.exp_avg_sq_cpu_data.to(param.device, non_blocking=True) + + def copy_mv_to_cpu(self, params): + for param in params: + param.exp_avg_cpu_data.copy_(param.exp_avg.data, non_blocking=True) + param.exp_avg_sq_cpu_data.copy_(param.exp_avg_sq.data, non_blocking=True) + param.exp_avg = None + param.exp_avg_sq = None + + def clear_selected_mv(self): + print("Zenflow: clearing selective optimizer states...") + for group in self.param_groups: + for param in group['params']: + state = self.state.setdefault(param, {}) + if len(state) == 0: + continue + if self.offload: + param.exp_avg_cpu_data.zero_() + param.exp_avg_sq_cpu_data.zero_() + else: + state["exp_avg"].zero_() + state["exp_avg_sq"].zero_() + + @torch.no_grad() + def _step_without_offload(self): + for group in self.param_groups: + + params_with_grad: List[Tensor] = [] + grads: List[Tensor] = [] + exp_avgs: List[Tensor] = [] + exp_avg_sqs: List[Tensor] = [] + max_exp_avg_sqs: List[Tensor] = [] + state_steps: List[Tensor] = [] + amsgrad: bool = group["amsgrad"] + beta1, beta2 = cast(Tuple[float, float], group["betas"]) + + for param in group["params"]: + if hasattr(param, "selected_grad"): + selected_param = param.data[:, param.selected_indices] if len(param.shape) != 1 else param.data + if hasattr(param, 'temp_selected_param') and param.temp_selected_param is not None: + selected_param.copy_(param.temp_selected_param) + + state = self.state.setdefault(param, {}) + if len(state) == 0: + state["step"] = torch.zeros((), dtype=param.dtype, device=selected_param.device) + state["exp_avg"] = torch.zeros_like(selected_param) + state["exp_avg_sq"] = torch.zeros_like(selected_param) + if amsgrad: + state["max_exp_avg_sq"] = 
torch.zeros_like(selected_param) + + params_with_grad.append(selected_param) + grads.append(param.selected_grad) + exp_avgs.append(state["exp_avg"]) + exp_avg_sqs.append(state["exp_avg_sq"]) + if amsgrad: + max_exp_avg_sqs.append(state["max_exp_avg_sq"]) + state_steps.append(state["step"]) + + adamw( + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + amsgrad=amsgrad, + beta1=beta1, + beta2=beta2, + lr=group["lr"], + weight_decay=group["weight_decay"], + eps=group["eps"], + maximize=False, + ) + + for param in group["params"]: + if hasattr(param, "temp_selected_param"): + param.temp_selected_param = None + param.selected_grad = None + + @torch.no_grad() + def _step_with_offload(self): + for group_id, group in enumerate(self.param_groups): + params = group["params"] + + bucket = [] + bucket_numel = 0 + + def flush_bucket(): + if not bucket: + return + for param in bucket: + if hasattr(param, "temp_selected_param") and param.temp_selected_param is not None: + selected_param = param.data[:, param.selected_indices] if len(param.shape) != 1 else param.data + temp_selected_param = param.temp_selected_param.to(param.device, non_blocking=True) + selected_param.copy_(temp_selected_param) + param.temp_selected_param = None + + self.group_step({group_id: bucket}) + bucket.clear() + + for param in params: + if hasattr(param, "selected_grad"): + bucket.append(param) + bucket_numel += param.numel() + if bucket_numel >= self.bucket_size: + flush_bucket() + bucket_numel = 0 + + flush_bucket() + + @torch.no_grad() + def _group_step_without_offload(self, group_to_paramlist): + for group_id, params in group_to_paramlist.items(): + group = self.param_groups[group_id] + + params_with_grad: List[Tensor] = [] + grads: List[Tensor] = [] + exp_avgs: List[Tensor] = [] + exp_avg_sqs: List[Tensor] = [] + max_exp_avg_sqs: List[Tensor] = [] + state_steps: List[Tensor] = [] + amsgrad: bool = group["amsgrad"] + beta1, beta2 = cast(Tuple[float, float], group["betas"]) + + for param in params: + if hasattr(param, "selected_grad"): + selected_param = param.data[:, param.selected_indices] if len(param.shape) != 1 else param.data + + state = self.state.setdefault(param, {}) + if len(state) == 0: + state["step"] = torch.zeros((), dtype=param.dtype, device=selected_param.device) + state["exp_avg"] = torch.zeros_like(selected_param) + state["exp_avg_sq"] = torch.zeros_like(selected_param) + if amsgrad: + state["max_exp_avg_sq"] = torch.zeros_like(selected_param) + + params_with_grad.append(selected_param) + grads.append(param.selected_grad) + exp_avgs.append(state["exp_avg"]) + exp_avg_sqs.append(state["exp_avg_sq"]) + if amsgrad: + max_exp_avg_sqs.append(state["max_exp_avg_sq"]) + state_steps.append(state["step"]) + + adamw( + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + amsgrad=amsgrad, + beta1=beta1, + beta2=beta2, + lr=group["lr"], + weight_decay=group["weight_decay"], + eps=group["eps"], + maximize=False, + ) + + for param in params: + param.selected_grad = None + + @torch.no_grad() + def _group_step_with_offload(self, group_to_paramlist): + for group_id, params in group_to_paramlist.items(): + group = self.param_groups[group_id] + + self.copy_mv_from_cpu(params) + params_with_grad: List[Tensor] = [] + grads: List[Tensor] = [] + exp_avgs: List[Tensor] = [] + exp_avg_sqs: List[Tensor] = [] + max_exp_avg_sqs: List[Tensor] = [] + state_steps: List[Tensor] = [] + amsgrad: bool = group["amsgrad"] + beta1, beta2 = cast(Tuple[float, float], 
group["betas"]) + + for param in params: + if hasattr(param, "selected_grad"): + selected_param = param.data[:, param.selected_indices] if len(param.shape) != 1 else param.data + + state = self.state.setdefault(param, {}) + if len(state) == 0: + state["step"] = torch.zeros((), dtype=param.dtype, device=selected_param.device) + if amsgrad: + state["max_exp_avg_sq"] = torch.zeros_like(selected_param) + + params_with_grad.append(selected_param) + grads.append(param.selected_grad) + exp_avgs.append(param.exp_avg.view_as(selected_param)) + exp_avg_sqs.append(param.exp_avg_sq.view_as(selected_param)) + if amsgrad: + max_exp_avg_sqs.append(state["max_exp_avg_sq"]) + state_steps.append(state["step"]) + + adamw( + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + amsgrad=amsgrad, + beta1=beta1, + beta2=beta2, + lr=group["lr"], + weight_decay=group["weight_decay"], + eps=group["eps"], + maximize=False, + ) + + self.copy_mv_to_cpu(params) + + for param in params: + param.selected_grad = None + + +class ZenFlowSelectiveAdamW_stage3(torch.optim.AdamW): + + def __init__(self, *args, **kwargs): + super(ZenFlowSelectiveAdamW_stage3, self).__init__(*args, **kwargs) + + @torch.no_grad() + def temp_copy_param(self, paramlist): + for param in paramlist: + if hasattr(param, "selected_grad"): + num_column, num_row = param.ds_shape if len(param.ds_shape) != 1 else (param.ds_shape[0], 1) + + if num_row != 1: + param_2d = param.ds_tensor.data.narrow(0, param.complete_column_offset, param.complete_numel).view( + param.complete_numel // num_row, num_row) + param.temp_selected_param = param_2d[param.selected_indices, :].clone().detach() + else: + param.temp_selected_param = param.ds_tensor.data.clone().detach() + + def clear_selected_mv(self): + print("clearing...") + for group in self.param_groups: + for param in group['params']: + state = self.state.setdefault(param, {}) + if len(state) == 0: + continue + state["exp_avg"].zero_() + state["exp_avg_sq"].zero_() + + @torch.no_grad() + def step(self): + for group in self.param_groups: + + params_with_grad: List[Tensor] = [] + grads: List[Tensor] = [] + exp_avgs: List[Tensor] = [] + exp_avg_sqs: List[Tensor] = [] + max_exp_avg_sqs: List[Tensor] = [] + state_steps: List[Tensor] = [] + amsgrad: bool = group["amsgrad"] + beta1, beta2 = cast(Tuple[float, float], group["betas"]) + + for param in group["params"]: + if hasattr(param, "selected_grad"): + num_column, num_row = param.ds_shape if len(param.ds_shape) != 1 else (param.ds_shape[0], 1) + if num_row != 1: + param_2d = param.ds_tensor.data.narrow(0, param.complete_column_offset, + param.complete_numel).view( + param.complete_numel // num_row, num_row) + selected_param = param_2d[param.selected_indices, :] + else: + selected_param = param.ds_tensor.data + if hasattr(param, 'temp_selected_param') and param.temp_selected_param is not None: + selected_param.copy_(param.temp_selected_param) + + state = self.state.setdefault(param, {}) + if len(state) == 0: + state["step"] = torch.zeros((), dtype=param.dtype, device=selected_param.device) + state["exp_avg"] = torch.zeros_like(selected_param) + state["exp_avg_sq"] = torch.zeros_like(selected_param) + if amsgrad: + state["max_exp_avg_sq"] = torch.zeros_like(selected_param) + + params_with_grad.append(selected_param) + grads.append(param.selected_grad) + exp_avgs.append(state["exp_avg"]) + exp_avg_sqs.append(state["exp_avg_sq"]) + if amsgrad: + max_exp_avg_sqs.append(state["max_exp_avg_sq"]) + state_steps.append(state["step"]) + adamw( + 
params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + amsgrad=amsgrad, + beta1=beta1, + beta2=beta2, + lr=group["lr"], + weight_decay=group["weight_decay"], + eps=group["eps"], + maximize=False, + ) + + for param in group["params"]: + if hasattr(param, "temp_selected_param"): + param.temp_selected_param = None + param.selected_grad = None + + @torch.no_grad() + def group_step(self, paramlist): + + group_to_paramlist = {} + for param in paramlist: + group_id = param.group_id + if group_id not in group_to_paramlist: + group_to_paramlist[group_id] = [] + group_to_paramlist[group_id].append(param) + + for group_id in sorted(group_to_paramlist.keys()): + params = group_to_paramlist[group_id] + group = self.param_groups[group_id] + + params_with_grad: List[Tensor] = [] + grads: List[Tensor] = [] + exp_avgs: List[Tensor] = [] + exp_avg_sqs: List[Tensor] = [] + max_exp_avg_sqs: List[Tensor] = [] + state_steps: List[Tensor] = [] + amsgrad: bool = group["amsgrad"] + beta1, beta2 = cast(Tuple[float, float], group["betas"]) + + for param in params: + if hasattr(param, "selected_grad"): + num_column, num_row = param.ds_shape if len(param.ds_shape) != 1 else (param.ds_shape[0], 1) + + if num_row != 1: + param_2d = param.ds_tensor.data.narrow(0, param.complete_column_offset, + param.complete_numel).view( + param.complete_numel // num_row, num_row) + selected_param = param_2d[param.selected_indices, :] + else: + selected_param = param.ds_tensor.data + + state = self.state.setdefault(param, {}) + if len(state) == 0: + state["step"] = torch.zeros((), dtype=param.dtype, device=selected_param.device) + state["exp_avg"] = torch.zeros_like(selected_param) + state["exp_avg_sq"] = torch.zeros_like(selected_param) + if amsgrad: + state["max_exp_avg_sq"] = torch.zeros_like(selected_param) + + params_with_grad.append(selected_param) + grads.append(param.selected_grad) + exp_avgs.append(state["exp_avg"]) + exp_avg_sqs.append(state["exp_avg_sq"]) + if amsgrad: + max_exp_avg_sqs.append(state["max_exp_avg_sq"]) + state_steps.append(state["step"]) + + adamw( + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + amsgrad=amsgrad, + beta1=beta1, + beta2=beta2, + lr=group["lr"], + weight_decay=group["weight_decay"], + eps=group["eps"], + maximize=False, + ) + + for param in params: + param.selected_grad = None + + +def _single_tensor_adamw( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + max_exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + grad_scale: Optional[Tensor], + found_inf: Optional[Tensor], + *, + amsgrad: bool, + beta1: float, + beta2: float, + lr: Union[Tensor, float], + weight_decay: float, + eps: float, + maximize: bool, + capturable: bool, + differentiable: bool, + has_complex: bool, +): + assert grad_scale is None and found_inf is None + + if torch.jit.is_scripting(): + # this assert is due to JIT being dumb and not realizing that the ops below + # have overloads to handle both float and Tensor lrs, so we just assert it's + # a float since most people using JIT are using floats + assert isinstance(lr, float) + + for i, param in enumerate(params): + grad = grads[i] if not maximize else -grads[i] + exp_avg = exp_avgs[i] + exp_avg_sq = exp_avg_sqs[i] + step_t = state_steps[i] + + # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] + if not torch._utils.is_compiling() and capturable: + capturable_supported_devices = 
_get_capturable_supported_devices() + assert ( + param.device.type == step_t.device.type and param.device.type in capturable_supported_devices + ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}." + + if torch.is_complex(param): + grad = torch.view_as_real(grad) + exp_avg = torch.view_as_real(exp_avg) + exp_avg_sq = torch.view_as_real(exp_avg_sq) + if amsgrad: + max_exp_avg_sqs[i] = torch.view_as_real(max_exp_avg_sqs[i]) + param = torch.view_as_real(param) + + # update step + step_t += 1 + + # Perform stepweight decay + param.mul_(1 - lr * weight_decay) + + # Decay the first and second moment running average coefficient + exp_avg.lerp_(grad, 1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + + if capturable or differentiable: + step = step_t + + bias_correction1 = 1 - beta1**step + bias_correction2 = 1 - beta2**step + + step_size = lr / bias_correction1 + step_size_neg = step_size.neg() + + bias_correction2_sqrt = bias_correction2.sqrt() + + if amsgrad: + # Maintains the maximum of all 2nd moment running avg. till now + if differentiable: + max_exp_avg_sq = max_exp_avg_sqs[i].clone() + else: + max_exp_avg_sq = max_exp_avg_sqs[i] + + max_exp_avg_sqs[i].copy_(torch.maximum(max_exp_avg_sq, exp_avg_sq)) + + # Uses the max. for normalizing running avg. of gradient + # Folds in (admittedly ugly) 1-elem step_size math here to avoid extra param-set-sized read+write + # (can't fold it into addcdiv_ below because addcdiv_ requires value is a Number, not a Tensor) + denom = (max_exp_avg_sqs[i].sqrt() / (bias_correction2_sqrt * step_size_neg)).add_(eps / step_size_neg) + else: + denom = (exp_avg_sq.sqrt() / (bias_correction2_sqrt * step_size_neg)).add_(eps / step_size_neg) + + param.addcdiv_(exp_avg, denom) + else: + step = _get_value(step_t) + + bias_correction1 = 1 - beta1**step + bias_correction2 = 1 - beta2**step + + step_size = lr / bias_correction1 + + bias_correction2_sqrt = bias_correction2**0.5 + + if amsgrad: + # Maintains the maximum of all 2nd moment running avg. till now + torch.maximum(max_exp_avg_sqs[i], exp_avg_sq, out=max_exp_avg_sqs[i]) + + # Use the max. for normalizing running avg. 
of gradient + denom = (max_exp_avg_sqs[i].sqrt() / bias_correction2_sqrt).add_(eps) + else: + denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps) + + param.addcdiv_(exp_avg, denom, value=-step_size) + + # Lastly, switch back to complex view + if amsgrad and torch.is_complex(params[i]): + max_exp_avg_sqs[i] = torch.view_as_complex(max_exp_avg_sqs[i]) + + +def _multi_tensor_adamw( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + max_exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + grad_scale: Optional[Tensor], + found_inf: Optional[Tensor], + *, + amsgrad: bool, + beta1: float, + beta2: float, + lr: Union[Tensor, float], + weight_decay: float, + eps: float, + maximize: bool, + capturable: bool, + differentiable: bool, + has_complex: bool, +): + if len(params) == 0: + return + + if isinstance(lr, Tensor) and not capturable: + raise RuntimeError("lr as a Tensor is not supported for capturable=False and foreach=True") + + # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] + if not torch._utils.is_compiling() and capturable: + capturable_supported_devices = _get_capturable_supported_devices(supports_xla=False) + assert all( + p.device.type == step.device.type and p.device.type in capturable_supported_devices + for p, step in zip(params, state_steps) + ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}." + + assert not differentiable, "_foreach ops don't support autograd" + + assert grad_scale is None and found_inf is None + + grouped_tensors = Optimizer._group_tensors_by_device_and_dtype( + [params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps] # type: ignore[list-item] + ) + for ( + device_params_, + device_grads_, + device_exp_avgs_, + device_exp_avg_sqs_, + device_max_exp_avg_sqs_, + device_state_steps_, + ), _ in grouped_tensors.values(): + device_params = cast(List[Tensor], device_params_) + device_grads = cast(List[Tensor], device_grads_) + device_exp_avgs = cast(List[Tensor], device_exp_avgs_) + device_exp_avg_sqs = cast(List[Tensor], device_exp_avg_sqs_) + device_state_steps = cast(List[Tensor], device_state_steps_) + + if has_complex: + if amsgrad: + device_max_exp_avg_sqs = cast(List[Tensor], device_max_exp_avg_sqs_) + _view_as_real( + device_params, + device_grads, + device_exp_avgs, + device_exp_avg_sqs, + device_max_exp_avg_sqs, + ) + else: + _view_as_real(device_params, device_grads, device_exp_avgs, device_exp_avg_sqs) + + if maximize: + device_grads = torch._foreach_neg(device_grads) # type: ignore[assignment] + + # Update steps + # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over + # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just + # wrapped it once now. The alpha is required to assure we go to the right overload. 
+ if not torch._utils.is_compiling() and device_state_steps[0].is_cpu: + torch._foreach_add_(device_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0) + else: + torch._foreach_add_(device_state_steps, 1) + + # Perform stepweight decay + if weight_decay != 0: + torch._foreach_mul_(device_params, 1 - lr * weight_decay) + + # Decay the first and second moment running average coefficient + torch._foreach_lerp_(device_exp_avgs, device_grads, 1 - beta1) + + torch._foreach_mul_(device_exp_avg_sqs, beta2) + torch._foreach_addcmul_(device_exp_avg_sqs, device_grads, device_grads, 1 - beta2) + + # Delete the local intermediate since it won't be used anymore to save on peak memory + del device_grads + + bias_correction1: Union[Tuple[Tensor, ...], List[Tensor]] + bias_correction2: Union[Tuple[Tensor, ...], List[Tensor]] + bias_correction2_sqrt: Union[Tuple[Tensor, ...], List[Tensor]] + + if capturable: + bias_correction1 = torch._foreach_pow(beta1, device_state_steps) + bias_correction2 = torch._foreach_pow(beta2, device_state_steps) + # foreach_sub doesn't allow a scalar as the first arg + torch._foreach_sub_(bias_correction1, 1) + torch._foreach_sub_(bias_correction2, 1) + # we do not negate bias_correction1 as it'll need to be negated later anyway + torch._foreach_neg_(bias_correction2) + + # foreach_div doesn't allow a scalar as the first arg + torch._foreach_div_(bias_correction1, lr) + torch._foreach_reciprocal_(bias_correction1) + + torch._foreach_sqrt_(bias_correction2) + + # Re-assign for clarity as we maintain minimal intermediates: we'll have + # step_size = - lr / (1 - beta1 ^ t) where t = num_steps + # bias_correction2_sqrt = sqrt(1 - beta2 ^ t) + step_size = bias_correction1 + bias_correction2_sqrt = bias_correction2 + + if amsgrad: + device_max_exp_avg_sqs = cast(List[Tensor], device_max_exp_avg_sqs_) + + # Maintains the maximum of all 2nd moment running avg. till now + torch._foreach_maximum_(device_max_exp_avg_sqs, device_exp_avg_sqs) + + # Use the max. for normalizing running avg. of gradient + exp_avg_sq_sqrt = torch._foreach_sqrt(device_max_exp_avg_sqs) + else: + exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs) + + torch._foreach_div_(exp_avg_sq_sqrt, bias_correction2_sqrt) + torch._foreach_add_(exp_avg_sq_sqrt, eps) + torch._foreach_div_(exp_avg_sq_sqrt, step_size) + + # at this point, exp_avg_sq_sqrt = - (1 - beta^t) * [sqrt(exp_avg_sq / (1 - beta2^t)) + eps] / lr + torch._foreach_addcdiv_(device_params, device_exp_avgs, exp_avg_sq_sqrt) + else: + bias_correction1 = [1 - beta1**_get_value(step) for step in device_state_steps] + bias_correction2 = [1 - beta2**_get_value(step) for step in device_state_steps] + + step_size = _stack_if_compiling([(lr / bc) * -1 for bc in bias_correction1]) + + bias_correction2_sqrt = [ + bc**0.5 for bc in bias_correction2 # type: ignore[arg-type] + ] + + if amsgrad: + device_max_exp_avg_sqs = cast(List[Tensor], device_max_exp_avg_sqs_) + + # Maintains the maximum of all 2nd moment running avg. till now + torch._foreach_maximum_(device_max_exp_avg_sqs, device_exp_avg_sqs) + + # Use the max. for normalizing running avg. 
of gradient + exp_avg_sq_sqrt = torch._foreach_sqrt(device_max_exp_avg_sqs) + else: + exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs) + + torch._foreach_div_(exp_avg_sq_sqrt, bias_correction2_sqrt) + torch._foreach_add_(exp_avg_sq_sqrt, eps) + torch._foreach_addcdiv_( + device_params, + device_exp_avgs, + exp_avg_sq_sqrt, + step_size, # type: ignore[arg-type] + ) + + +def _fused_adamw( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + max_exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + grad_scale: Optional[Tensor], + found_inf: Optional[Tensor], + *, + amsgrad: bool, + beta1: float, + beta2: float, + lr: Union[Tensor, float], + weight_decay: float, + eps: float, + maximize: bool, + capturable: bool, # Needed for consistency. + differentiable: bool, + has_complex: bool, # Needed for consistency. +) -> None: + if not params: + return + if differentiable: + raise RuntimeError("Adam with fused=True does not support differentiable=True") + + grad_scale_dict: DeviceDict = ({grad_scale.device: grad_scale} if grad_scale is not None else {}) + found_inf_dict: DeviceDict = ({found_inf.device: found_inf} if found_inf is not None else {}) + + # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer + # treating it as a scalar. + lr_dict: Optional[DeviceDict] = ({lr.device: lr} if isinstance(lr, Tensor) and str(lr.device) != "cpu" else None) + + grouped_tensors = Optimizer._group_tensors_by_device_and_dtype( + [params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps] # type: ignore[list-item] + ) + for (device, _), ( + ( + device_params_, + device_grads_, + device_exp_avgs_, + device_exp_avg_sqs_, + device_max_exp_avg_sqs, + device_state_steps_, + ), + _, + ) in grouped_tensors.items(): + device_params = cast(List[Tensor], device_params_) + device_grads = cast(List[Tensor], device_grads_) + device_exp_avgs = cast(List[Tensor], device_exp_avgs_) + device_exp_avg_sqs = cast(List[Tensor], device_exp_avg_sqs_) + device_state_steps = cast(List[Tensor], device_state_steps_) + + if device.type == "mps": # type: ignore[union-attr] + assert found_inf is None and grad_scale is None + + device_grad_scale, device_found_inf = None, None + if grad_scale is not None: + device_grad_scale = grad_scale_dict.setdefault(device, grad_scale.to(device, non_blocking=True)) + if found_inf is not None: + device_found_inf = found_inf_dict.setdefault(device, found_inf.to(device, non_blocking=True)) + if lr_dict is not None and device not in lr_dict: + lr = lr_dict.setdefault( + device, + lr.to(device=device, non_blocking=True) # type: ignore[union-attr] + ) + torch._foreach_add_(device_state_steps, 1) + torch._fused_adamw_( + device_params, + device_grads, + device_exp_avgs, + device_exp_avg_sqs, + device_max_exp_avg_sqs, # type: ignore[arg-type] + device_state_steps, + amsgrad=amsgrad, + lr=lr, # type: ignore[arg-type] + beta1=beta1, + beta2=beta2, + weight_decay=weight_decay, + eps=eps, + maximize=maximize, + grad_scale=device_grad_scale, + found_inf=device_found_inf, + ) + if device_found_inf is not None: + torch._foreach_sub_(device_state_steps, [device_found_inf] * len(device_state_steps)) + + +@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_adamw) +def adamw( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + max_exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + # kwonly args with defaults are not supported by functions compiled with 
torchscript issue #70627 + # setting this as kwarg for now as functional API is compiled by torch/distributed/optim + foreach: Optional[bool] = None, + capturable: bool = False, + differentiable: bool = False, + fused: Optional[bool] = None, + grad_scale: Optional[Tensor] = None, + found_inf: Optional[Tensor] = None, + has_complex: bool = False, + *, + amsgrad: bool, + beta1: float, + beta2: float, + lr: Union[float, Tensor], + weight_decay: float, + eps: float, + maximize: bool, +): + r"""Functional API that performs AdamW algorithm computation. + + See :class:`~torch.optim.AdamW` for details. + """ + if not torch._utils.is_compiling() and not all(isinstance(t, torch.Tensor) for t in state_steps): + raise RuntimeError("API has changed, `state_steps` argument must contain a list of singleton tensors") + + # Respect when the user inputs False/True for foreach or fused. We only want to change + # the default when neither have been user-specified. Note that we default to foreach + # and pass False to use_fused. This is not a mistake--we want to give the fused impl + # bake-in time before making it the default, even if it is typically faster. + if fused is None and foreach is None: + _, foreach = _default_to_fused_or_foreach(params, differentiable, use_fused=False) + # Do not flip on foreach for the unsupported case where lr is a Tensor and capturable=False. + if foreach and isinstance(lr, Tensor) and not capturable: + foreach = False + if fused is None: + fused = False + if foreach is None: + foreach = False + + if foreach and torch.jit.is_scripting(): + raise RuntimeError("torch.jit.script not supported with foreach optimizers") + if fused and torch.jit.is_scripting(): + raise RuntimeError("torch.jit.script not supported with fused optimizers") + + if fused and not torch.jit.is_scripting(): + func = _fused_adamw + elif foreach and not torch.jit.is_scripting(): + func = _multi_tensor_adamw + else: + func = _single_tensor_adamw + + func( + params, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + amsgrad=amsgrad, + beta1=beta1, + beta2=beta2, + lr=lr, + weight_decay=weight_decay, + eps=eps, + maximize=maximize, + capturable=capturable, + differentiable=differentiable, + grad_scale=grad_scale, + found_inf=found_inf, + has_complex=has_complex, + ) diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py index 861f7d23c9c2..e90831382146 100755 --- a/deepspeed/runtime/zero/stage_1_and_2.py +++ b/deepspeed/runtime/zero/stage_1_and_2.py @@ -8,6 +8,7 @@ from packaging import version as pkg_version from collections import OrderedDict from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors +from deepspeed.runtime.zero.zenflow import zenflow_utils from deepspeed.runtime.base_optimizer import ZeROOptimizer from deepspeed.runtime.fp16.loss_scaler import CreateLossScaler @@ -111,6 +112,7 @@ def __init__(self, init_optimizer, param_names, timers, + optimizer_params, static_loss_scale=1.0, dynamic_loss_scale=False, dynamic_loss_args=None, @@ -125,6 +127,7 @@ def __init__(self, reduce_scatter=True, overlap_comm=False, offload_optimizer_config=None, + zenflow_config=None, mpu=None, clip_grad=0.0, gradient_accumulation_dtype=torch.float32, @@ -146,6 +149,8 @@ def __init__(self, self.cpu_offload = False self.cpu_offload_pin_memory = False + self.zenflow = True if zenflow_config is not None else False + if dist.get_rank() == 0: logger.info(f"Reduce bucket size {reduce_bucket_size}") logger.info(f"Allgather bucket size 
{allgather_bucket_size}") @@ -167,9 +172,9 @@ def __init__(self, raise SystemError("Accelerator is not detected, cannot perform low precision training (e.g., fp16, bf16).") self.optimizer = init_optimizer - # Use torch (un)flatten ops - self.flatten = _flatten_dense_tensors - self.unflatten = _unflatten_dense_tensors + # Use torch or zenflow (un)flatten ops + self.flatten = _flatten_dense_tensors if not self.zenflow else zenflow_utils._flatten_dense_tensors + self.unflatten = _unflatten_dense_tensors if not self.zenflow else zenflow_utils._unflatten_dense_tensors # ZeRO stage 1 (False) or 2 (True) self.partition_gradients = partition_grads @@ -965,8 +970,12 @@ def reduce_independent_p_g_buckets_and_remove_grads(self, param, i): else: # keeping the gradients contiguous to prevent memory fragmentation, and avoid flattening new_grad_tensor = self.ipg_buffer[self.ipg_index].narrow(0, self.elements_in_ipg_bucket, param.numel()) - new_grad_tensor.copy_(grad_reduc.view(-1)) - grad_reduc.data = new_grad_tensor.data.view_as(grad_reduc) + new_grad_tensor.copy_( + grad_reduc.view(-1) if not self.zenflow else grad_reduc.permute( + *reversed(range(grad_reduc.ndim))).contiguous().view(-1)) + grad_reduc.data = new_grad_tensor.data.view_as(grad_reduc) if ( + not self.zenflow or grad_reduc.dim() == 1) else new_grad_tensor.data.view_as( + grad_reduc.transpose(0, 1)) self.elements_in_ipg_bucket += param.numel() @@ -1376,7 +1385,13 @@ def reduce_ipg_grads(self): assert self.get_param_id(self.extra_large_param_to_reduce ) == param_id, "param in ipg bucket does not match extra-large param" extra_large_grad_reduc = self.get_gradient_for_reduction(self.extra_large_param_to_reduce) - self.average_tensor(extra_large_grad_reduc.view(-1)) + + extra_large_grad_reduc_for_average = extra_large_grad_reduc.view(-1) if not self.zenflow \ + else extra_large_grad_reduc.permute(*reversed(range(extra_large_grad_reduc.ndim))).contiguous().view(-1) + extra_large_grad_reduc.data = extra_large_grad_reduc_for_average.data.view_as(extra_large_grad_reduc) if (not self.zenflow or self.extra_large_param_to_reduce.dim() == 1) \ + else extra_large_grad_reduc_for_average.data.view_as(extra_large_grad_reduc.transpose(0, 1)) + + self.average_tensor(extra_large_grad_reduc_for_average) self.extra_large_param_to_reduce = None else: self.average_tensor(self.ipg_buffer[self.ipg_index].narrow(0, 0, self.elements_in_ipg_bucket)) @@ -1826,7 +1841,11 @@ def _optimizer_step(self, group_no): # self.optimizer.step(fp16_param_groups=[self.get_bit16_param_group(group_no)]) #else: # self.optimizer.step() - self.optimizer.step() + if self.zenflow: + self.zenflow_cpu_optimizer_step(group_no) + else: + self.optimizer.step() + self.optimizer.param_groups = original_param_groups # We need to link optimizer state after the first step() call diff --git a/deepspeed/runtime/zero/zenflow/__init__.py b/deepspeed/runtime/zero/zenflow/__init__.py new file mode 100644 index 000000000000..208299fb8c50 --- /dev/null +++ b/deepspeed/runtime/zero/zenflow/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team diff --git a/deepspeed/runtime/zero/zenflow/zenflow_stage_1_and_2.py b/deepspeed/runtime/zero/zenflow/zenflow_stage_1_and_2.py new file mode 100644 index 000000000000..c1489848cb13 --- /dev/null +++ b/deepspeed/runtime/zero/zenflow/zenflow_stage_1_and_2.py @@ -0,0 +1,929 @@ +# Copyright (c) Microsoft Corporation. 
+# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import os +import torch +from deepspeed import comm as dist +import torch.multiprocessing as mp + +from deepspeed.runtime.zero.stage_1_and_2 import DeepSpeedZeroOptimizer +from deepspeed.runtime.utils import (see_memory_usage) +from deepspeed.ops.adam import ZenFlowSelectiveAdamW + +from deepspeed.moe.utils import is_moe_param + +from deepspeed.accelerator import get_accelerator + +from deepspeed.runtime.utils import all_gather_dp_groups + +# Toggle this to true to enable correctness test +# with gradient partitioning and without +pg_correctness_test = False + +OPTIMIZER_ALLGATHER_TIMER = 'optimizer_allgather' +OPTIMIZER_GRADIENTS_TIMER = 'optimizer_gradients' +OPTIMIZER_STEP_TIMER = 'optimizer_step' +OPTIMIZER_TRANSMIT_TIMER = 'optimizer_transmit_time' +OPTIMIZER_CALC_TIMER = 'optimizer_calc_time' +OPTIMIZER_RECV_PARAMS_TIMER = 'optimizer_receive_params_time' +OPTIMIZER_UPDATE_MODEL_TIMER = 'optimizer_update_model_time' +OPTIMIZER_TIMERS = [ + OPTIMIZER_ALLGATHER_TIMER, OPTIMIZER_GRADIENTS_TIMER, OPTIMIZER_STEP_TIMER, OPTIMIZER_TRANSMIT_TIMER, + OPTIMIZER_CALC_TIMER, OPTIMIZER_RECV_PARAMS_TIMER, OPTIMIZER_UPDATE_MODEL_TIMER +] +INITIAL_MICRO_STEP_ID = -1 + +SELECTIVE_OPTIMIZER_UPDATE_TIMER = 'selective_optimizer_update' +SELECTIVE_OPTIMIZER_PROCESS_TIMER = 'selective_optimizer_process' +SELECTIVE_OPTIMIZER_STEP_TIMER = 'selective_optimizer_step' +SELECTIVE_OPTIMIZER_SYNC_TIMER = 'selective_optimizer_sync' +SELECTIVE_OPTIMIZER_TIMERS = [ + SELECTIVE_OPTIMIZER_UPDATE_TIMER, SELECTIVE_OPTIMIZER_PROCESS_TIMER, SELECTIVE_OPTIMIZER_STEP_TIMER, + SELECTIVE_OPTIMIZER_SYNC_TIMER +] + + +class ZenFlowZeroOptimizer(DeepSpeedZeroOptimizer): + + def __init__(self, + init_optimizer, + param_names, + timers, + optimizer_params, + static_loss_scale=1.0, + dynamic_loss_scale=False, + dynamic_loss_args=None, + verbose=True, + contiguous_gradients=True, + reduce_bucket_size=500000000, + use_multi_rank_bucket_allreduce=True, + allgather_bucket_size=5000000000, + dp_process_group=None, + expert_parallel_group=None, + expert_data_parallel_group=None, + reduce_scatter=True, + overlap_comm=False, + offload_optimizer_config=None, + zenflow_config=None, + mpu=None, + clip_grad=0.0, + gradient_accumulation_dtype=torch.float32, + communication_data_type=torch.float16, + postscale_gradients=True, + gradient_predivide_factor=1.0, + gradient_accumulation_steps=1, + ignore_unused_parameters=True, + partition_grads=True, + round_robin_gradients=False, + has_moe_layers=False, + fp16_master_weights_and_gradients=False, + elastic_checkpoint=False): + + super().__init__(init_optimizer, param_names, timers, optimizer_params, static_loss_scale, dynamic_loss_scale, + dynamic_loss_args, verbose, contiguous_gradients, reduce_bucket_size, + use_multi_rank_bucket_allreduce, allgather_bucket_size, dp_process_group, + expert_parallel_group, expert_data_parallel_group, reduce_scatter, overlap_comm, + offload_optimizer_config, zenflow_config, mpu, clip_grad, gradient_accumulation_dtype, + communication_data_type, postscale_gradients, gradient_predivide_factor, + gradient_accumulation_steps, ignore_unused_parameters, partition_grads, round_robin_gradients, + has_moe_layers, fp16_master_weights_and_gradients, elastic_checkpoint) + + self.micro_step = -1 + self.full_warm_up_rounds = zenflow_config.full_warm_up_rounds + self.offload_selective_optimizer = zenflow_config.offload + + if self.offload_selective_optimizer: + assert overlap_comm, "offload selective optimizer should be 
used with overlap_comm" + + self._configure_zenflow(zenflow_config) + + + self.selective_optimizer = ZenFlowSelectiveAdamW([{"params": group} for group in self.bit16_groups], \ + offload=zenflow_config.offload, + bucket_size=self.allgather_bucket_size, + **optimizer_params) + self.num_total_param = sum(sum(1 for param in group if len(param.shape) != 1) for group in self.bit16_groups) + + @classmethod + def create(cls, zenflow_config): + if zenflow_config.overlap_step: + return ZenFlowZeroOptimizerParallel + else: + return ZenFlowZeroOptimizerSequential + + def _configure_zenflow(self, zenflow_config): + """ + Configure ZenFlow optimizer + """ + if not self.cpu_offload: + raise ValueError("Zenflow must be used with cpu offload") + + self.select_strategy = zenflow_config.select_strategy + if self.select_strategy == 'auto': + self.select_strategy = "epoch" + if isinstance(zenflow_config.select_interval, int): + raise Warning( + "If use auto select strategy, select_interval will be set to 1 and select_strategy will be set to epoch, thus select_interval would be overwritten." + ) + self.select_interval = 1 + else: + if isinstance(zenflow_config.select_interval, str): + raise ValueError("If don't use auto select strategy, select_interval must be a number.") + self.select_interval = int(zenflow_config.select_interval) + + if isinstance(zenflow_config.update_interval, str): + self.auto_update = True + self.update_interval = 0 + else: + self.auto_update = False + self.update_interval = int(zenflow_config.update_interval) + + if self.select_strategy == 'epoch': + self.select_interval = self.select_interval * zenflow_config.steps_per_epoch + + if not self.auto_update and self.select_interval != 0 and self.select_interval < self.update_interval: + raise ValueError("Select interval must be greater or equal to update interval") + + self.topk_ratio = zenflow_config.topk_ratio + + self.param_id_index_buffer_offset = {} + self.param_id_grad_buffer_offset = {} + + if self.auto_update: + self.param_id_sum_buffer_offset = {} + self.auto_ratio = zenflow_config.auto_ratio + self.zenflow_need_update = [False, False] + self.zenflow_state = 0 + self.num_need_update = 0 + + def is_zenflow_select_boundary(self): + return self.zenflow and (self.micro_step - self.full_warm_up_rounds) >= 0 and ( + (self.micro_step - self.full_warm_up_rounds) == 0 or + (self.select_interval != 0 and self.micro_step % self.select_interval == 0)) + + def sync_fp32_param_from_gpu(self): + if self.micro_step == 0: + return + + for i, group in enumerate(self.bit16_groups): + partition_id = dist.get_rank(group=self.real_dp_process_group[i]) + + bit16_partitions = self.parallel_partitioned_bit16_groups[i] + fp32_partition = self.single_partition_of_fp32_groups[i] + + with torch.no_grad(): + fp32_partition.copy_(bit16_partitions[partition_id].to(dtype=fp32_partition.dtype, + device=fp32_partition.device)) + + def update_selected_channels(self, tensor, total_size): + curr_size = 0 + curr_index_buffer_size = 0 + rank_and_offsets = [] + prev_id, prev_process_group = -1, None + + process_group = self.dp_process_group + rank = dist.get_rank(process_group) + + self.index_buffer = torch.empty(total_size, dtype=torch.int32, device='cuda') + + # count = 0 + for i, param_idx_in_group, param_id in self.params_in_ipg_bucket: + param = self.bit16_groups[i][param_idx_in_group] + + if len(param.shape) == 1: + continue + + if not hasattr(param, 'selected_indices'): + param.selected_indices = None + + partition_ids = self.param_to_partition_ids[i][param_id] + + 
# Get all partition ids + their offsets + partition_ids_w_offsets = [] + for partition_id in partition_ids: + offset = self.grad_start_offset[i][partition_id][param_id] + partition_ids_w_offsets.append((partition_id, offset)) + partition_ids_w_offsets.sort(key=lambda t: t[1]) + + # Calculate rank and offsets for grad slices + for idx in range(len(partition_ids_w_offsets)): + partition_id, offset = partition_ids_w_offsets[idx] + + if idx == len(partition_ids_w_offsets) - 1: + numel = param.numel() - offset + else: + numel = partition_ids_w_offsets[idx + 1][1] - offset + + num_row, num_col = param.shape if len(param.shape) == 2 else (1, param.shape[0]) + start_column = 0 if not offset else int((offset - 1) / num_row) + 1 + end_column = int((offset + numel) / num_row) + num_select = int(self.topk_ratio * (end_column - start_column)) + + if partition_id == rank: + + start_idx = int(curr_size + start_column * num_row - offset) + num_elements = (end_column - start_column) * num_row + sum_per_column = tensor.narrow(0, start_idx, num_elements) + sum_per_column = sum_per_column.view(end_column - start_column, num_row) + sum_array = sum_per_column.abs().sum(dim=1) + + _, top_indices = torch.topk(sum_array, num_select) + top_indices += start_column + self.index_buffer.narrow(0, curr_index_buffer_size, num_select).copy_(top_indices) + + if partition_id == prev_id and process_group == prev_process_group: + prev_pid, prev_size, prev_numel = rank_and_offsets[-1] + rank_and_offsets[-1] = (prev_pid, prev_size, prev_numel + num_select) + else: + rank_and_offsets.append((partition_id, curr_index_buffer_size, num_select)) + + if param_id not in self.param_id_index_buffer_offset: + self.param_id_index_buffer_offset[param_id] = [] + self.param_id_index_buffer_offset[param_id].append((curr_index_buffer_size, num_select)) + + curr_size += numel + curr_index_buffer_size += num_select + + for src_rank, offset, num_select in rank_and_offsets: + index_slice = self.index_buffer.narrow(0, offset, num_select) + dist.broadcast(index_slice, src=src_rank, group=process_group) + + for i, param_idx_in_group, param_id in self.params_in_ipg_bucket: + param = self.bit16_groups[i][param_idx_in_group] + + if len(param.shape) == 1: + continue + + param.selected_indices = None + param.partition_selected_indices = [] + + for offset, num_select in self.param_id_index_buffer_offset[param_id]: + selected = self.index_buffer.narrow(0, offset, num_select).clone().sort()[0] + if param.selected_indices is None: + param.selected_indices = selected + else: + param.selected_indices = torch.cat([param.selected_indices, selected]) + param.partition_selected_indices.append(selected) + + self.param_id_index_buffer_offset[param_id] = [] + + num_row, num_col = param.shape if len(param.shape) == 2 else (1, param.shape[0]) + param.selected_indices.sort() + param.selected_shape = (param.selected_indices.shape[0], + num_row) if num_row != 1 else (param.selected_indices.shape[0], ) + + self.index_buffer = None + + def process_selected_fp32_groups_grad(self, tensor, total_size): + """ + Process gradients for selected columns in FP32 groups + + Args: + param: The parameter to process + param_id: ID of the parameter + """ + + curr_size = 0 + curr_grad_buffer_size = 0 + curr_sum_buffer_size = 0 + rank_and_offsets = [] + prev_id, prev_process_group = -1, None + + process_group = self.dp_process_group + rank = dist.get_rank(process_group) + + self.grad_buffer = torch.empty(total_size, dtype=self.dtype, device='cuda') + + if self.auto_update: + self.sum_buffer 
= torch.empty(len(self.params_in_ipg_bucket) + dist.get_world_size(group=process_group), + dtype=torch.bfloat16, + device='cuda') + + group_to_paramlist = {} + + # count = 0 + for i, param_idx_in_group, param_id in self.params_in_ipg_bucket: + param = self.bit16_groups[i][param_idx_in_group] + + if not hasattr(param, 'selected_indices'): + param.selected_indices = None + + partition_ids = self.param_to_partition_ids[i][param_id] + + # Get all partition ids + their offsets + partition_ids_w_offsets = [] + for partition_id in partition_ids: + offset = self.grad_start_offset[i][partition_id][param_id] + partition_ids_w_offsets.append((partition_id, offset)) + partition_ids_w_offsets.sort(key=lambda t: t[1]) + + # Calculate rank and offsets for grad slices + for idx in range(len(partition_ids_w_offsets)): + partition_id, offset = partition_ids_w_offsets[idx] + + if idx == len(partition_ids_w_offsets) - 1: + numel = param.numel() - offset + else: + numel = partition_ids_w_offsets[idx + 1][1] - offset + + num_row, num_col = param.shape if len(param.shape) == 2 else (1, param.shape[0]) + start_column = 0 if not offset else int((offset - 1) / num_row) + 1 + end_column = int((offset + numel) / num_row) + num_select = int(self.topk_ratio * (end_column - start_column)) if len(param.shape) == 2 else numel + grad_size = num_select * num_row + + if partition_id == rank: + selected_grad = param.grad[ + param.partition_selected_indices[idx], :] if num_row != 1 else param.grad[offset:offset + + numel] + self.grad_buffer.narrow(0, curr_grad_buffer_size, grad_size).copy_(selected_grad.view(-1)) + + if self.auto_update: + self.sum_buffer[curr_sum_buffer_size] = tensor.narrow(0, int(curr_size), + int(numel)).abs().sum() + + if partition_id == prev_id and process_group == prev_process_group: + if self.auto_update: + prev_pid, prev_size, prev_numel, prev_sum_size, prev_sum_num = rank_and_offsets[-1] + rank_and_offsets[-1] = (prev_pid, prev_size, prev_numel + grad_size, prev_sum_size, + prev_sum_num + 1) + else: + prev_pid, prev_size, prev_numel = rank_and_offsets[-1] + rank_and_offsets[-1] = (prev_pid, prev_size, prev_numel + grad_size) + else: + if self.auto_update: + rank_and_offsets.append( + (partition_id, curr_grad_buffer_size, grad_size, curr_sum_buffer_size, 1)) + else: + rank_and_offsets.append((partition_id, curr_grad_buffer_size, grad_size)) + + if param_id not in self.param_id_grad_buffer_offset: + self.param_id_grad_buffer_offset[param_id] = [] + if self.auto_update and param_id not in self.param_id_sum_buffer_offset: + self.param_id_sum_buffer_offset[param_id] = [] + self.param_id_grad_buffer_offset[param_id].append((curr_grad_buffer_size, grad_size)) + if self.auto_update: + self.param_id_sum_buffer_offset[param_id].append(curr_sum_buffer_size) + + curr_size += numel + curr_grad_buffer_size += grad_size + curr_sum_buffer_size += 1 + + for item in rank_and_offsets: + if self.auto_update: + src_rank, offset, grad_size, sum_offset, sum_num = item + else: + src_rank, offset, grad_size = item + + grad_slice = self.grad_buffer.narrow(0, offset, grad_size) + dist.broadcast(grad_slice, src=src_rank, group=process_group) + + if self.auto_update: + sum_slice = self.sum_buffer.narrow(0, sum_offset, sum_num) + dist.broadcast(sum_slice, src=src_rank, group=process_group) + + for i, param_idx_in_group, param_id in self.params_in_ipg_bucket: + param = self.bit16_groups[i][param_idx_in_group] + + selected_grad = None + for offset, grad_size in self.param_id_grad_buffer_offset[param_id]: + selected_grad_buffer = 
self.grad_buffer.narrow(0, offset, grad_size).clone().detach() + if selected_grad is None: + selected_grad = selected_grad_buffer + else: + selected_grad = torch.cat([selected_grad, selected_grad_buffer]) + param.selected_grad = selected_grad.view(param.selected_shape).t() if len( + param.shape) != 1 else selected_grad + + if self.offload_selective_optimizer and not hasattr(param, 'exp_avg_cpu_data'): + buffer = torch.zeros(param.selected_grad.numel(), dtype=param.dtype, device=self.device) + param.exp_avg_cpu_data = get_accelerator().pin_memory( + buffer) if self.cpu_offload_pin_memory else buffer + param.exp_avg_sq_cpu_data = get_accelerator().pin_memory( + buffer.clone()) if self.cpu_offload_pin_memory else buffer.clone() + + param_list = group_to_paramlist.setdefault(i, []) + param_list.append(param) + + self.param_id_grad_buffer_offset[param_id] = [] + + if self.auto_update: + grad_total_sum = 0 + num_row, num_col = param.shape if len(param.shape) == 2 else (1, param.shape[0]) + if num_row == 1: + continue + + for offset in self.param_id_sum_buffer_offset[param_id]: + grad_total_sum += self.sum_buffer.narrow(0, offset, 1) + + grad_critic_sum = param.selected_grad.abs().sum() + + if not hasattr(param, 'non_critic_sum'): + param.non_critic_sum = 0 + if not hasattr(param, 'avg_critic_sum'): + param.avg_critic_sum = 0 + + param.avg_critic_sum = (param.avg_critic_sum * (self.update_interval - 1) + + grad_critic_sum) / self.update_interval / (self.topk_ratio * 10) + param.non_critic_sum += (grad_total_sum - grad_critic_sum) / ((1 - self.topk_ratio) * 10) + if param.non_critic_sum >= param.avg_critic_sum: + self.num_need_update += 1 + + if self.num_need_update >= int(self.auto_ratio * self.num_total_param): + self.zenflow_need_update[self.zenflow_state] = True + + self.param_id_sum_buffer_offset[param_id] = [] + + if not self.is_gradient_accumulation_boundary: + self.selective_optimizer.group_step(group_to_paramlist) + else: + self.selective_optimizer.temp_copy_param(group_to_paramlist) + + self.grad_buffer = None + if self.auto_update: + self.sum_buffer = None + + def average_tensor(self, tensor): + if self.overlap_comm: + stream = self.reduction_stream + if not get_accelerator().resolves_data_dependency(): + stream.wait_stream(get_accelerator().current_stream()) + get_accelerator().current_stream().wait_stream(stream) + else: + stream = get_accelerator().current_stream() + + with get_accelerator().stream(stream): + if not self.reduce_scatter: + self.gradient_reduction_w_predivide(tensor) + return + + # Accumulate destination ranks and bucket offsets for each gradient slice. + # Note: potential future optimization, record access pattern of parameters + # in backward pass and partition gradients w.r.t. access pattern so that our + # bucket is guaranteed to be contiguous w.r.t. 
ranks + rank_and_offsets = [] + real_dp_process_group = [] + curr_size = 0 + prev_id, prev_process_group = -1, None + + curr_column_size = 0 + curr_selected_reduce_size = 0 + + process_group = self.dp_process_group + # count = 0 + for i, param_idx_in_group, param_id in self.params_in_ipg_bucket: + param = self.bit16_groups[i][param_idx_in_group] + + process_group = self.dp_process_group + + if self.ipg_bucket_has_moe_params: + process_group = self.expert_dp_process_group[param.group_name] if is_moe_param( + param) else self.dp_process_group + + partition_ids = self.param_to_partition_ids[i][param_id] + assert all([p_id < dist.get_world_size(group=process_group) for p_id in partition_ids + ]), f"world size {dist.get_world_size(group=process_group)} and p_ids: {partition_ids}" + partition_size = self.partition_size[i] + # Get all partition ids + their offsets + partition_ids_w_offsets = [] + for partition_id in partition_ids: + offset = self.grad_start_offset[i][partition_id][param_id] + partition_ids_w_offsets.append((partition_id, offset)) + partition_ids_w_offsets.sort(key=lambda t: t[1]) + + num_row, num_col = param.shape if len(param.shape) == 2 else (1, param.shape[0]) + curr_column_size += int(num_col * self.topk_ratio) if num_row != 1 else 0 + + # Calculate rank and offsets for grad slices + for idx in range(len(partition_ids_w_offsets)): + partition_id, offset = partition_ids_w_offsets[idx] + + # if dist.get_rank() == 0 and count < 100: + # print(f"Rank {dist.get_rank()} rank offset id {idx} calculated dp size {dist.get_world_size(group=process_group)} real dp size {dist.get_world_size(self.real_dp_process_group[i])} and dst: {partition_id}") + # count += 1 + + # Calculate numel for grad slice depending on partition location + if idx == len(partition_ids_w_offsets) - 1: + # Last partition_id uses its own offset + numel = param.numel() - offset + else: + # Set numel to next partition's offset + numel = partition_ids_w_offsets[idx + 1][1] - offset + + # Merge bucket ranges if they belong to the same rank + if partition_id == prev_id and process_group == prev_process_group: + prev_pid, prev_size, prev_numel = rank_and_offsets[-1] + rank_and_offsets[-1] = (prev_pid, prev_size, prev_numel + numel) + else: + rank_and_offsets.append((partition_id, curr_size, numel)) + real_dp_process_group.append(process_group) + curr_size += numel + curr_selected_reduce_size += int(numel * self.topk_ratio) if num_row != 1 else numel + + prev_id, prev_process_group = partition_id, process_group + + tensor.div_(dist.get_world_size(group=self.dp_process_group) / float(self.sequence_parallel_size)) + + buckets = {} + for i, (dst, bucket_offset, numel) in enumerate(rank_and_offsets): + grad_slice = tensor.narrow(0, int(bucket_offset), int(numel)) + bucket_key = real_dp_process_group[i] if self.use_multi_rank_bucket_allreduce else ( + dst, real_dp_process_group[i]) + if bucket_key not in buckets: + buckets[bucket_key] = [] + if self.use_multi_rank_bucket_allreduce: + buckets[bucket_key].append((dst, grad_slice)) + else: + buckets[bucket_key].append(grad_slice) + + for bucket_key in buckets: + if self.use_multi_rank_bucket_allreduce: + self.allreduce_and_scatter(buckets[bucket_key], + numel_per_bucket=self.reduce_bucket_size, + divide=False, + process_group=bucket_key) + else: + dst, process_group = bucket_key + self.allreduce_no_retain(buckets[bucket_key], + numel_per_bucket=self.reduce_bucket_size, + rank=dst, + divide=False, + process_group=process_group) + + if self.is_zenflow_select_boundary(): + 
self.timers(SELECTIVE_OPTIMIZER_UPDATE_TIMER).start() + # print("update selected") + self.update_selected_channels(tensor, curr_column_size) + self.timers(SELECTIVE_OPTIMIZER_UPDATE_TIMER).stop() + elif self.zenflow: + self.timers(SELECTIVE_OPTIMIZER_UPDATE_TIMER).start() + self.timers(SELECTIVE_OPTIMIZER_UPDATE_TIMER).stop() + + if self.zenflow and self.micro_step >= self.full_warm_up_rounds: + self.timers(SELECTIVE_OPTIMIZER_PROCESS_TIMER).start() + self.process_selected_fp32_groups_grad(tensor, curr_selected_reduce_size) + self.timers(SELECTIVE_OPTIMIZER_PROCESS_TIMER).stop() + + def backward(self, loss, retain_graph=False): + """ + :attr:`backward` performs the following steps: + + 1. fp32_loss = loss.float() + 2. scaled_loss = fp32_loss*loss_scale + 3. scaled_loss.backward(), which accumulates scaled gradients into the ``.grad`` attributes of the model's fp16 leaves + """ + self.backward_prologue() + self.micro_step += 1 + + if self.auto_update: + self.zenflow_need_update[self.zenflow_state] = False + self.num_need_update = 0 + if self.zenflow_need_update[self.zenflow_state ^ 1]: + self.update_interval = 0 + for group in self.bit16_groups: + for p in group: + p.non_critic_sum = 0 + self.update_interval += 1 + + if self.is_zenflow_select_boundary(): + self.timers(SELECTIVE_OPTIMIZER_SYNC_TIMER).start() + self.sync_fp32_param_from_gpu() + self.selective_optimizer.clear_selected_mv() + self.timers(SELECTIVE_OPTIMIZER_SYNC_TIMER).stop() + + if self.custom_loss_scaler: + scaled_loss = self.external_loss_scale * loss + scaled_loss.backward(retain_graph=retain_graph) + else: + self.loss_scaler.backward(loss.float(), retain_graph=retain_graph) + + self.backward_epilogue() + + def log_selective_optimizer_timers(self): + self.timers.log(SELECTIVE_OPTIMIZER_TIMERS) + + def _sync_selective_optimizer_lr(self): + for group_selected, group in zip(self.selective_optimizer.param_groups, self.optimizer.param_groups): + group_selected["lr"] = group["lr"] + + def _selective_optimizer_step(self, group_no): + original_param_groups = self.selective_optimizer.param_groups + self.selective_optimizer.param_groups = [original_param_groups[group_no]] + self.selective_optimizer.step() + self.selective_optimizer.param_groups = original_param_groups + + def selective_optimizer_step(self, closure=None): + for i, group in enumerate(self.bit16_groups): + self.timers(SELECTIVE_OPTIMIZER_STEP_TIMER).start() + self._selective_optimizer_step(i) + self.timers(SELECTIVE_OPTIMIZER_STEP_TIMER).stop() + + self.timers.log(SELECTIVE_OPTIMIZER_TIMERS) + + +class ZenFlowZeroOptimizerSequential(ZenFlowZeroOptimizer): + + def __init__(self, *args, **kwargs): + super(ZenFlowZeroOptimizerSequential, self).__init__(*args, **kwargs) + + def zenflow_cpu_optimizer_step(self, group_no): + self.optimizer.step(step_id=self.micro_step + 1) + + +def disable_accelerator(): + accelerator = get_accelerator() + accelerator.is_available = lambda: False + accelerator.device_count = lambda: 0 + accelerator.current_device = lambda: -1 + # Optionally mark it as initialized if needed + if hasattr(accelerator, "_initialized"): + accelerator._initialized = True + + +def zenflow_optimizer_process(pipe, curr_rank, total_rank, param_groups, shared_overlap_grad_map, + shared_stale_param_map): + os.environ["CUDA_VISIBLE_DEVICES"] = "" + disable_accelerator() + + CPUADAM_CORE_START = 65 + CPUADAM_CORE_END = 112 + TOTAL_CORES = CPUADAM_CORE_END - CPUADAM_CORE_START + + cores_per_rank = TOTAL_CORES // total_rank + extra = TOTAL_CORES % total_rank + start_offset = 
curr_rank * cores_per_rank + min(curr_rank, extra) + end_offset = start_offset + cores_per_rank + (1 if curr_rank < extra else 0) + assigned_cores = set(range(CPUADAM_CORE_START + start_offset, CPUADAM_CORE_START + end_offset)) + + try: + os.sched_setaffinity(0, assigned_cores) + print(f"[Optimizer Thread] Rank {curr_rank} bound to CPU cores: {os.sched_getaffinity(0)}", flush=True) + except AttributeError: + print("[Optimizer Thread] sched_setaffinity not supported on this system.") + except Exception as e: + print(f"[Optimizer Thread] Failed to set affinity: {e}") + + from deepspeed.ops.adam import ZenFlowCPUAdam + optimizer = ZenFlowCPUAdam(param_groups, overlap_step=True) + + pipe.send({"type": "ready"}) + + # TODO: replace this with rpc + + while True: + cmd = pipe.recv() + if cmd["type"] == "step": + now_state = cmd["now_state"] + micro_step = cmd["micro_step"] + group_infos = cmd["group_infos"] + + for group_no, group_info in enumerate(group_infos): + original_param_groups = optimizer.param_groups + optimizer.param_groups = [original_param_groups[group_no]] + group = optimizer.param_groups[0] + + for param_idx, param in enumerate(group["params"]): + key = (group_no, param_idx) + if key in shared_overlap_grad_map: + param.overlap_grad = shared_overlap_grad_map[key] + if key in shared_stale_param_map: + param.stale_param = shared_stale_param_map[key] + + optimizer.step(step_id=micro_step + 1, now_state=now_state, group_info=group_info) + + optimizer.param_groups = original_param_groups + + pipe.send({"type": "done"}) + elif cmd["type"] == "exit": + break + + +class ZenFlowZeroOptimizerParallel(ZenFlowZeroOptimizer): + + def __init__(self, *args, **kwargs): + super(ZenFlowZeroOptimizerParallel, self).__init__(*args, **kwargs) + self.process_pool = mp.Pool(1) + self.process_optimizer_established = False + self.first_update_round_after_warmup = True + + def initialize_optimizer_states(self): + + for i, group in enumerate(self.bit16_groups): + single_grad_partition = torch.zeros(int(self.partition_size[i]), + dtype=self.single_partition_of_fp32_groups[i].dtype, + device=self.device) + self.single_partition_of_fp32_groups[i].grad = None + buffer = get_accelerator().pin_memory( + single_grad_partition) if self.cpu_offload_pin_memory else single_grad_partition + self.single_partition_of_fp32_groups[i].overlap_grad = [buffer, buffer.clone()] + + # Initialize the optimizer states with the flattened fp32 partition. + # State initialization for the Adagrad optimizer occurs at construction as opposed to other optimizers + # which do lazy initialization of the state at the first call to step. 
+ if isinstance(self.optimizer, torch.optim.Adagrad): + self.optimizer = torch.optim.Adagrad(self.single_partition_of_fp32_groups, **self.optimizer.defaults) + + if not self.cpu_offload: + for group in self.single_partition_of_fp32_groups: + group.grad = None #class init + + return + + def _get_offload_gradient_dict(self): + for param_group_index, _ in enumerate(self.optimizer.param_groups): + self.offload_gradient_dict[param_group_index] = [] + for lp_param in self.params_in_partition[param_group_index]: + param_id = self.get_param_id(lp_param) + [_, _, dest_offset, num_elements] = self.grad_position[param_id] + dest_tensor = self.single_partition_of_fp32_groups[param_group_index].overlap_grad[0].view(-1).narrow( + 0, dest_offset, num_elements) + self.offload_gradient_dict[param_group_index].append(dest_tensor) + + def get_overlap_step_state(self): + if self.micro_step < self.full_warm_up_rounds: + return self.micro_step & 1 + else: + if not self.auto_update: + return (self.micro_step // self.update_interval) & 1 + else: + return self.zenflow_state + + def async_inplace_copy_grad_to_fp32_buffer_from_gpu(self, param): + param_id = self.get_param_id(param) + now_state = self.get_overlap_step_state() + + [i, source_offset, dest_offset, num_elements] = self.grad_position[param_id] + + dest_tensor = self.single_partition_of_fp32_groups[i].overlap_grad[now_state].view(-1).narrow( + 0, dest_offset, num_elements) + + grad_accum = self.get_param_gradient_attribute(param) + if grad_accum is None: + src_tensor = grad_accum.view(-1).narrow(0, source_offset, num_elements) + else: + src_tensor = grad_accum.view(-1).narrow(0, source_offset, num_elements) + if not self.fp16_master_weights_and_gradients: + src_tensor = src_tensor.float() + + dest_tensor.copy_(src_tensor, non_blocking=True) + param.grad = None #offload only + + def start_optimizer_process(self): + from multiprocessing import Pipe, get_context, Manager + + ctx = get_context("spawn") + self.parent_conn, self.child_conn = Pipe() + + manager = Manager() + self.shared_overlap_grad_map = manager.dict() + self.shared_stale_param_map = manager.dict() + + for group_no, group in enumerate(self.optimizer.param_groups): + for param_idx, param in enumerate(group['params']): + param.data.share_memory_() + if not hasattr(param, 'stale_param'): + param.stale_param = torch.zeros_like(param.data, dtype=param.dtype, device=param.device) + param.stale_param.data.share_memory_() + key = (group_no, param_idx) + self.shared_stale_param_map[key] = param.stale_param + if param.overlap_grad is not None: + param.overlap_grad[0].data.share_memory_() + param.overlap_grad[1].data.share_memory_() + key = (group_no, param_idx) + self.shared_overlap_grad_map[key] = param.overlap_grad + + param_groups_data = self.optimizer.param_groups + curr_rank = dist.get_rank() + total_rank = dist.get_world_size() + + self.process = ctx.Process( + target=zenflow_optimizer_process, + args=(self.child_conn, curr_rank, total_rank, param_groups_data, self.shared_overlap_grad_map, + self.shared_stale_param_map), + ) + self.process.daemon = True + self.process.start() + + msg = self.parent_conn.recv() + assert msg["type"] == "ready", "Optimizer process did not initialize correctly." 
+ + self.process_optimizer_established = True + + def wait_last_update_and_copy(self): + + if not hasattr(self, 'parent_conn'): + return + + if self.micro_step + 1 > self.full_warm_up_rounds and self.first_update_round_after_warmup: + self.first_update_round_after_warmup = False + return + + self.timers(OPTIMIZER_RECV_PARAMS_TIMER).start() + msg = self.parent_conn.recv() + assert msg["type"] == "done", "Optimizer process did not finish stepping correctly." + self.timers(OPTIMIZER_RECV_PARAMS_TIMER).stop() + + for i, group in enumerate(self.bit16_groups): + partition_id = dist.get_rank(group=self.real_dp_process_group[i]) + bit16_partitions = self.parallel_partitioned_bit16_groups[i] + fp32_partition = self.optimizer.param_groups[i]['params'][0].stale_param.data + self.timers(OPTIMIZER_TRANSMIT_TIMER).start() + bit16_partitions[partition_id].data.copy_(fp32_partition.to(get_accelerator().current_device_name()).data) + self.timers(OPTIMIZER_TRANSMIT_TIMER).stop() + + see_memory_usage('After optimizer before all-gather') + if self.cpu_offload: + self.reset_cpu_buffers() + + self.timers(OPTIMIZER_ALLGATHER_TIMER).start() + # Gather the updated weights from everyone. + # Then all partitions of the model parameters are updated and ready for next round forward. + all_gather_dp_groups(groups_flat=self.bit16_groups_flat, + partitioned_param_groups=self.parallel_partitioned_bit16_groups, + dp_process_group=self.real_dp_process_group, + start_alignment_factor=self.nccl_start_alignment_factor, + allgather_bucket_size=self.allgather_bucket_size) + self.timers(OPTIMIZER_ALLGATHER_TIMER).stop() + + self.timers(OPTIMIZER_UPDATE_MODEL_TIMER).start() + # TODO: we probably don't need this? just to be safe + for i in range(len(self.bit16_groups)): + self._update_model_bit16_weights(i) + self.timers(OPTIMIZER_UPDATE_MODEL_TIMER).stop() + + self.timers.log(OPTIMIZER_TIMERS) + see_memory_usage('After zero_optimizer step') + + def zenflow_cpu_optimizer_step(self, now_state, scaled_global_grad_norm): + + if not self.process_optimizer_established: + self.start_optimizer_process() + + group_infos = [] + for group_no, group in enumerate(self.bit16_groups): + single_grad_partition = self.single_partition_of_fp32_groups[group_no].overlap_grad[now_state] + self.unscale_and_clip_grads([single_grad_partition], scaled_global_grad_norm) + + group_info = { + "lr": self.optimizer.param_groups[group_no]["lr"], + "betas": self.optimizer.param_groups[group_no]["betas"], + "eps": self.optimizer.param_groups[group_no]["eps"], + "weight_decay": self.optimizer.param_groups[group_no]["weight_decay"], + "bias_correction": self.optimizer.param_groups[group_no]["bias_correction"], + } + + group_infos.append(group_info) + + self.parent_conn.send({ + "type": "step", + "now_state": now_state, + "micro_step": self.micro_step, + "group_infos": group_infos + }) + + def step(self, closure=None): + """ + Not supporting closure. 
+ """ + self.micro_step_id = INITIAL_MICRO_STEP_ID + + see_memory_usage(f"In step before checking overflow") + + # First compute norm for all group so we know if there is overflow + if self.dtype == torch.float16: + self.check_overflow() + + self._update_scale(self.overflow) + if self.overflow: + see_memory_usage('After overflow before clearing gradients') + self.zero_grad(set_to_none=True) + if self.cpu_offload: + self.reset_cpu_buffers() + else: + self.averaged_gradients = {} + + see_memory_usage('After overflow after clearing gradients') + + for timer in OPTIMIZER_TIMERS: + self.timers(timer).start() + self.timers(timer).stop() + return + + prev_scale = self.loss_scale + # Step 1:- Calculate gradient norm using bit-16 grads + see_memory_usage('Before norm calculation') + scaled_global_grad_norm = self.scaled_global_norm() + self._global_grad_norm = scaled_global_grad_norm / prev_scale + see_memory_usage('After norm before optimizer') + + if self.micro_step < self.full_warm_up_rounds: + self.zenflow_cpu_optimizer_step(self.get_overlap_step_state(), scaled_global_grad_norm) + + self.wait_last_update_and_copy() + + if self.micro_step >= self.full_warm_up_rounds: + self.zenflow_cpu_optimizer_step(self.get_overlap_step_state(), scaled_global_grad_norm) + + return diff --git a/deepspeed/runtime/zero/zenflow/zenflow_utils.py b/deepspeed/runtime/zero/zenflow/zenflow_utils.py new file mode 100644 index 000000000000..a544aa3531bd --- /dev/null +++ b/deepspeed/runtime/zero/zenflow/zenflow_utils.py @@ -0,0 +1,42 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch + + +def _flatten_dense_tensors(tensors): + """Flatten dense tensors into a contiguous 1D buffer. Assume tensors are of + same dense type. + + Since inputs are dense, the resulting tensor will be a concatenated 1D + buffer. Element-wise operation on this buffer will be equivalent to + operating individually. + + Args: + tensors (Iterable[Tensor]): dense tensors to flatten. + + Returns: + A contiguous 1D buffer containing input tensors. + """ + transposed_tensors = [t.transpose(0, 1).contiguous() if t.dim() == 2 else t for t in tensors] + return torch._C._nn.flatten_dense_tensors(transposed_tensors) + + +def _unflatten_dense_tensors(flat, tensors): + """View a flat buffer using the sizes of tensors. Assume that tensors are of + same dense type, and that flat is given by _flatten_dense_tensors. + + Args: + flat (Tensor): flattened dense tensors to unflatten. + tensors (Iterable[Tensor]): dense tensors whose sizes will be used to + unflatten flat. + + Returns: + Unflattened dense tensors with sizes same as tensors and values from + flat. 
+ """ + transposed_tensors = [t.transpose(0, 1) if t.dim() == 2 else t for t in tensors] + unflat = torch._C._nn.unflatten_dense_tensors(flat, transposed_tensors) + return [t.transpose(0, 1) if t.dim() == 2 else t for t in unflat] From 4e9fe2a21c0276fb2bf65122ddb074c749434b65 Mon Sep 17 00:00:00 2001 From: Tingfeng Lan Date: Wed, 25 Jun 2025 02:44:06 -0400 Subject: [PATCH 02/26] Add ZenFlowConfig for optimizer configuration - Define ZenFlowConfig with support for selective update parameters - Add validation for ZenFlow-related config fields Signed-off-by: Tingfeng Lan Co-authored-by: Yusen Wu --- deepspeed/runtime/zero/config.py | 5 ++- deepspeed/runtime/zero/offload_config.py | 55 +++++++++++++++++++++++- 2 files changed, 58 insertions(+), 2 deletions(-) diff --git a/deepspeed/runtime/zero/config.py b/deepspeed/runtime/zero/config.py index 19ee9b51702e..5739f07cc535 100644 --- a/deepspeed/runtime/zero/config.py +++ b/deepspeed/runtime/zero/config.py @@ -9,7 +9,7 @@ from pydantic import Field, model_validator from deepspeed.runtime.config_utils import get_scalar_param, pp_int, DeepSpeedConfigModel from deepspeed.utils import logger -from .offload_config import DeepSpeedZeroOffloadParamConfig, DeepSpeedZeroOffloadOptimizerConfig, OffloadDeviceEnum +from .offload_config import DeepSpeedZeroOffloadParamConfig, DeepSpeedZeroOffloadOptimizerConfig, OffloadDeviceEnum, ZenFlowConfig # ZeRO optimization. By default, this optimization is not enabled. # Users have to configure the desired optimization (0 means disabled) in params.json as below example: @@ -165,6 +165,9 @@ class DeepSpeedZeroConfig(DeepSpeedConfigModel): for :any:`DeepSpeedZeroOffloadOptimizerConfig`. """ + zenflow: Optional[ZenFlowConfig] = None + """Enable ZenFlow""" + sub_group_size: int = Field(pp_int(1e9), ge=0) """ Tile size for parameter processing to fit massive models (with trillions of diff --git a/deepspeed/runtime/zero/offload_config.py b/deepspeed/runtime/zero/offload_config.py index ca35d7a7d169..1400c09e7514 100644 --- a/deepspeed/runtime/zero/offload_config.py +++ b/deepspeed/runtime/zero/offload_config.py @@ -6,7 +6,7 @@ from enum import Enum from pathlib import Path from pydantic import Field, model_validator -from typing import Optional +from typing import Optional, Union from deepspeed.runtime.config_utils import DeepSpeedConfigModel, pp_int @@ -100,6 +100,59 @@ def set_pipeline(self): return self +class ZenFlowConfig(DeepSpeedConfigModel): + """Configuration options for ZenFlow optimization module.""" + + topk_ratio: float = Field(0.1, ge=0.0, le=1.0) + """Ratio of top-k important gradient columns to retain (range: 0.0 to 1.0).""" + + select_strategy: str = "auto" + """Strategy for selecting important gradient indices. + Options: "auto", "step", or "epoch".""" + + select_interval: Union[str, int] = "auto" + """Interval at which to reselect important gradient indices. + Can be "auto" or a fixed integer step/epoch interval.""" + + update_interval: Union[str, int] = "auto" + """Interval for applying accumulated unimportant gradients to model parameters. 
+ Can be "auto" or a fixed integer step interval.""" + + overlap_step: bool = False + """Whether to overlap CPU-side optimizer steps with forward/backward computation.""" + + offload: bool = False + """Whether to offload selective optimizer states to CPU to save memory.""" + + auto_ratio: float = Field(0.99, ge=0.0, le=1.0) + """Threshold used in the "auto" strategy to determine update_interval.""" + + full_warm_up_rounds: int = 0 + """Number of initial rounds during which all gradients are fully updated (no selection).""" + + steps_per_epoch: Optional[int] = Field( + default=None, + description= + "Number of steps per epoch. This field is initialized during execution and should not be set by users.", + exclude=True) + + @model_validator(mode="after") + def validate_fields(self): + if self.select_strategy not in ["auto", "step", "epoch"]: + raise ValueError('select_strategy must be one of "auto", "step", or "epoch"') + + if isinstance(self.select_interval, str) and self.select_interval != "auto": + raise ValueError('If select_interval is a string, it must be "auto"') + + if isinstance(self.update_interval, str) and self.update_interval != "auto": + raise ValueError('If update_interval is a string, it must be "auto"') + + if not isinstance(self.full_warm_up_rounds, int): + raise ValueError('full_warm_up_rounds must be an integer') + + return self + + class OffloadStateTypeEnum(str, Enum): """ Enum for internal buffer types """ optim_states = "optim_states" From cac5703c55fa61b423e2cfd504856eb0e5698308 Mon Sep 17 00:00:00 2001 From: Tingfeng Lan Date: Wed, 25 Jun 2025 02:44:36 -0400 Subject: [PATCH 03/26] Add ZenFlow (zero stage 1&2) integration in DeepSpeedEngine - Implement ZenFlow configuration and optimizer support in DeepSpeedEngine - Introduce methods for configuring ZenFlow parameters and handling selective updates - Enhance optimizer selection logic to accommodate ZenFlow optimizers - Update step function to manage ZenFlow-specific behaviors during training Signed-off-by: Tingfeng Lan Co-authored-by: Yusen Wu --- deepspeed/runtime/engine.py | 123 ++++++++++++++++++++++++++++++++++-- 1 file changed, 116 insertions(+), 7 deletions(-) diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 428fc0baf43a..5714612e520c 100755 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -27,6 +27,7 @@ from deepspeed.runtime.utils import see_memory_usage, DummyOptim from .zero.offload_config import OffloadDeviceEnum, OffloadStateTypeEnum from deepspeed.runtime.zero.stage_1_and_2 import DeepSpeedZeroOptimizer +from deepspeed.runtime.zero.zenflow.zenflow_stage_1_and_2 import ZenFlowZeroOptimizer from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus from deepspeed.runtime.zero.utils import is_zero_supported_optimizer, ZeRORuntimeException from deepspeed.runtime.zero.parameter_offload import DeepSpeedZeRoOffload @@ -322,6 +323,8 @@ def __init__(self, if not isinstance(model_parameters, list): model_parameters = list(model_parameters) + self._configure_zenflow() + if has_optimizer: self._configure_optimizer(optimizer, model_parameters) self._configure_lr_scheduler() @@ -1030,6 +1033,9 @@ def swap_tensor_config(self): def aio_config(self): return self._config.aio_config + def zenflow_config(self): + return self._config.zero_config.zenflow + def get_data_types(self): model_dtype = torch.float32 if self.fp16_enabled(): @@ -1069,6 +1075,51 @@ def _configure_lr_scheduler(self): log_dist(f'DeepSpeed LR Scheduler = {self.lr_scheduler}', ranks=[0]) + def 
_configure_zenflow(self): + + zenflow_config = self.zenflow_config() + if zenflow_config == None: + self.zenflow = False + return + + self.zenflow = True + select_strategy = zenflow_config.select_strategy + + if select_strategy == 'auto': + select_strategy = "epoch" + if isinstance(zenflow_config.select_interval, int): + raise Warning( + "If use auto select strategy, select_interval will be set to 1 and select_strategy will be set to epoch, thus select_interval would be overwritten." + ) + self.select_interval = 1 + else: + if isinstance(zenflow_config.select_interval, str): + raise ValueError("If don't use auto select strategy, select_interval must be a number.") + self.select_interval = zenflow_config.select_interval + + if isinstance(zenflow_config.update_interval, str): + self.auto_update = True + self.update_interval = 0 + else: + self.auto_update = False + self.update_interval = int(zenflow_config.update_interval) + + if select_strategy == 'epoch': + zenflow_config.steps_per_epoch = len(self.training_dataloader) + self.select_interval = self.select_interval * len(self.training_dataloader) + + if not self.auto_update and self.select_interval != 0 and self.select_interval < self.update_interval: + raise ValueError("Select interval must be greater or equal to update interval") + + self.overlap_step = zenflow_config.overlap_step + + self.full_warm_up_rounds = zenflow_config.full_warm_up_rounds + + self._config.gradient_accumulation_steps = self.update_interval + + def sync_selective_optimizer_lr(self): + self.optimizer._sync_selective_optimizer_lr() + def _configure_checkpointing(self, dist_init_required): self.checkpoint_engine = TorchCheckpointEngine() @@ -1449,10 +1500,14 @@ def _configure_basic_optimizer(self, model_parameters): optimizer = torch.optim.AdamW(model_parameters, **optimizer_parameters) else: if self.zero_use_cpu_optimizer(): - from deepspeed.ops.adam import DeepSpeedCPUAdam - optimizer = DeepSpeedCPUAdam(model_parameters, - **optimizer_parameters, - adamw_mode=effective_adam_w_mode) + from deepspeed.ops.adam import DeepSpeedCPUAdam, ZenFlowCPUAdam + CPUAdam = ZenFlowCPUAdam if self.zenflow else DeepSpeedCPUAdam + + zenflow_kwargs = {'overlap_step': self.overlap_step} if self.zenflow else {} + optimizer = CPUAdam(model_parameters, + **optimizer_parameters, + adamw_mode=effective_adam_w_mode, + **zenflow_kwargs) else: from deepspeed.ops.adam import FusedAdam @@ -1663,10 +1718,14 @@ def _configure_zero_optimizer(self, optimizer): if overlap_comm: logger.warning("Pipeline parallelism does not support overlapped communication, will be disabled.") overlap_comm = False - optimizer = DeepSpeedZeroOptimizer( + Stage1And2ZeroOptimizer = DeepSpeedZeroOptimizer if not self.zenflow else ZenFlowZeroOptimizer.create( + zenflow_config=self.zenflow_config()) + + optimizer = Stage1And2ZeroOptimizer( optimizer, self.param_names, timers=timers, + optimizer_params=self.optimizer_params(), static_loss_scale=self.loss_scale(), dynamic_loss_scale=self.dynamic_loss_scale(), dynamic_loss_args=self.dynamic_loss_scale_args(), @@ -1681,6 +1740,7 @@ def _configure_zero_optimizer(self, optimizer): reduce_scatter=self.zero_reduce_scatter(), overlap_comm=overlap_comm, offload_optimizer_config=self.zero_offload_optimizer(), + zenflow_config=self.zenflow_config(), mpu=self.mpu, postscale_gradients=self.postscale_gradients(), gradient_predivide_factor=self.gradient_predivide_factor(), @@ -2144,6 +2204,9 @@ def _backward_prologue(self, loss, scale_wrt_gas=True): if self.is_deepcompile_enabled(): 
deepcompile_backward_prologue(self.is_gradient_accumulation_boundary()) + if self.zenflow and self.auto_update: + self.optimizer.zenflow_state ^= 1 + return loss def _backward_epilogue(self): @@ -2230,8 +2293,21 @@ def is_gradient_accumulation_boundary(self): """ if self._is_gradient_accumulation_boundary is None: - return (self.micro_steps + 1) % \ - self.gradient_accumulation_steps() == 0 + if not self.zenflow: + return (self.micro_steps + 1) % \ + self.gradient_accumulation_steps() == 0 + elif not self.auto_update: + if (self.micro_steps + 1) < self.full_warm_up_rounds: + return True + else: + return ((self.micro_steps + 1 - self.full_warm_up_rounds) % self.gradient_accumulation_steps() == 0) \ + or (self.select_interval != 0 and (self.micro_steps + 1) % self.select_interval == 0) + else: + if (self.micro_steps + 1) <= self.full_warm_up_rounds: + return True + else: + return self.optimizer.zenflow_need_update[self.optimizer.zenflow_state ^ 1] \ + or (self.select_interval != 0 and (self.micro_steps + 1) % self.select_interval == 0) else: return self._is_gradient_accumulation_boundary @@ -2335,6 +2411,22 @@ def _take_model_step(self, lr_kwargs, block_eigenvalue={}): self.global_steps += 1 self.global_samples += self.train_batch_size() + def _take_selective_parameter_step(self): + self.optimizer.selective_optimizer_step() + + def _take_lr_scheduler_step(self, lr_kwargs): + if self.lr_scheduler is not None: + try: + self.lr_scheduler.step(**(lr_kwargs or {})) + except TypeError: + # XXX Hack to work with Megatron 2.0 and DeepSpeed pipelines. + # We don't currently have a way to specify lr_kwargs from + # pipe_engine.train_batch() + self.lr_scheduler.step(self.train_batch_size()) + + def _log_selective_optimizer_timers(self): + self.optimizer.log_selective_optimizer_timers() + def step(self, lr_kwargs=None): r"""Execute the weight update step after forward and backward propagation on effective_train_batch. @@ -2358,6 +2450,11 @@ def step(self, lr_kwargs=None): self._step_applied = False # assume False, will flip to True + if self.zenflow: + self.sync_selective_optimizer_lr() + if self.auto_update: + self.update_interval += 1 + # Update the model when we reach gradient accumulation boundaries if self.is_gradient_accumulation_boundary(): self.gas_boundary_ctr += 1 @@ -2379,6 +2476,18 @@ def step(self, lr_kwargs=None): report_progress = self.global_rank == 0 if self.global_rank else True + if self.zenflow: + if not self.is_gradient_accumulation_boundary(): + self._take_lr_scheduler_step(lr_kwargs) + self._log_selective_optimizer_timers() + else: + if self.micro_steps + 1 >= self.full_warm_up_rounds: + self._take_selective_parameter_step() + if self.auto_update: + if dist.get_rank() == 0: + print(f"Zenflow: This is an update iter. update_interval: {self.update_interval}") + self.update_interval = 0 + self.tput_timer.stop(global_step=self.is_gradient_accumulation_boundary(), report_speed=report_progress) self._stop_timers(self.engine_timers.step_timers) From 0e9a0c9e13fe4c65141093a5eb20da26d004cd8c Mon Sep 17 00:00:00 2001 From: Tingfeng Lan Date: Wed, 25 Jun 2025 03:23:40 -0400 Subject: [PATCH 04/26] Add unit tests for ZenFlowConfig - Introduce tests to validate the behavior of DeepSpeedZeroConfig with various configurations for ZenFlowConfig, including stage enumeration and offload optimizer settings. - Ensure proper coercion of dictionary inputs into ZenFlowConfig and validate error handling for incorrect types. 
- Test combined usage of offload_optimizer and zenflow configurations under stage 2. Signed-off-by: Tingfeng Lan --- tests/unit/runtime/zenflow/test_zf_config.py | 85 ++++++++++++++++++++ 1 file changed, 85 insertions(+) create mode 100644 tests/unit/runtime/zenflow/test_zf_config.py diff --git a/tests/unit/runtime/zenflow/test_zf_config.py b/tests/unit/runtime/zenflow/test_zf_config.py new file mode 100644 index 000000000000..c0811bc32423 --- /dev/null +++ b/tests/unit/runtime/zenflow/test_zf_config.py @@ -0,0 +1,85 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import pytest +from pydantic import ValidationError + +from deepspeed.runtime.zero.config import DeepSpeedZeroConfig, ZeroStageEnum +from deepspeed.runtime.zero.offload_config import DeepSpeedZeroOffloadOptimizerConfig, ZenFlowConfig + + +def test_stage_enum_accepts_int_and_enum(): + """`stage` can be passed as either an int or the ZeroStageEnum.""" + c1 = DeepSpeedZeroConfig(stage=2) + assert c1.stage == ZeroStageEnum.gradients + c2 = DeepSpeedZeroConfig(stage=ZeroStageEnum.weights) + assert c2.stage == ZeroStageEnum.weights + + +def test_offload_optimizer_config_from_dict(): + """A dict for offload_optimizer should be coerced into DeepSpeedZeroOffloadOptimizerConfig.""" + cfg = DeepSpeedZeroConfig(offload_optimizer={"device": "cpu", "pin_memory": True}) + assert isinstance(cfg.offload_optimizer, DeepSpeedZeroOffloadOptimizerConfig) + assert cfg.offload_optimizer.device == "cpu" + assert cfg.offload_optimizer.pin_memory is True + + +def test_invalid_offload_optimizer_type_raises(): + """Passing a non-dict to offload_optimizer must error out.""" + with pytest.raises(ValidationError): + DeepSpeedZeroConfig(offload_optimizer="not a dict") + + +def test_zenflow_config_from_dict(): + """A dict for zenflow should be coerced into ZenFlowConfig.""" + zenflow_payload = { + "topk_ratio": 0.25, + "select_strategy": "auto", + "select_interval": 4, + "update_interval": 8, + "full_warm_up_rounds": 1, + "overlap_step": True + } + cfg = DeepSpeedZeroConfig(zenflow=zenflow_payload) + assert isinstance(cfg.zenflow, ZenFlowConfig) + assert cfg.zenflow.topk_ratio == 0.25 + assert cfg.zenflow.select_strategy == "auto" + assert cfg.zenflow.select_interval == 4 + assert cfg.zenflow.update_interval == 8 + assert cfg.zenflow.full_warm_up_rounds == 1 + assert cfg.zenflow.overlap_step is True + + +def test_invalid_zenflow_type_raises(): + """Passing a non-dict to zenflow must error out.""" + with pytest.raises(ValidationError): + DeepSpeedZeroConfig(zenflow=123) + + +def test_offload_and_zenflow_combined(): + """ + offload_optimizer and zenflow can be used together under stage 2 + without validation errors. 
+ """ + payload = { + "stage": 2, + "offload_optimizer": { + "device": "cpu", + "pin_memory": True + }, + "zenflow": { + "topk_ratio": 0.3, + "select_strategy": "epoch", + "select_interval": 3, + "update_interval": 6, + "full_warm_up_rounds": 0, + "overlap_step": False + } + } + cfg = DeepSpeedZeroConfig(**payload) + assert isinstance(cfg.offload_optimizer, DeepSpeedZeroOffloadOptimizerConfig) + assert cfg.offload_optimizer.device == "cpu" + assert isinstance(cfg.zenflow, ZenFlowConfig) + assert cfg.zenflow.select_strategy == "epoch" From 3353e34a56ec55ef09a14967c60c23275e91da34 Mon Sep 17 00:00:00 2001 From: Tingfeng Lan Date: Thu, 26 Jun 2025 15:04:19 +0000 Subject: [PATCH 05/26] Fix initialization and update logic for ZenFlow optimizers - Fix initialization logic for ZenFlowCPUAdam - Fix gradient update issues in ZenFlowSelectiveAdamW Signed-off-by: Tingfeng Lan Signed-off-by: Yusen Wu Co-authored-by: Yusen Wu --- deepspeed/ops/adam/zenflow_cpu_adam.py | 2 +- deepspeed/ops/adam/zenflow_torch_adam.py | 176 +++-------------------- 2 files changed, 17 insertions(+), 161 deletions(-) diff --git a/deepspeed/ops/adam/zenflow_cpu_adam.py b/deepspeed/ops/adam/zenflow_cpu_adam.py index dedebc53559f..b4774bf4289d 100644 --- a/deepspeed/ops/adam/zenflow_cpu_adam.py +++ b/deepspeed/ops/adam/zenflow_cpu_adam.py @@ -12,7 +12,7 @@ class ZenFlowCPUAdam(DeepSpeedCPUAdam): def __init__(self, *args, overlap_step=False, **kwargs): super(ZenFlowCPUAdam, self).__init__(*args, **kwargs) self.overlap_step = overlap_step - if self.overlap_step: + if not self.overlap_step: print("ZenFlowCPUAdam initialized with overlap step.") self.step = self._sequential_step else: diff --git a/deepspeed/ops/adam/zenflow_torch_adam.py b/deepspeed/ops/adam/zenflow_torch_adam.py index f682427ec1de..25297f346f02 100644 --- a/deepspeed/ops/adam/zenflow_torch_adam.py +++ b/deepspeed/ops/adam/zenflow_torch_adam.py @@ -23,6 +23,8 @@ class ZenFlowSelectiveAdamW(torch.optim.AdamW): def __init__(self, *args, offload=False, bucket_size=5e8, **kwargs): super(ZenFlowSelectiveAdamW, self).__init__(*args, **kwargs) + + self.offload = offload if offload: self.step = self._step_with_offload @@ -128,6 +130,11 @@ def _step_without_offload(self): maximize=False, ) + for i, param in enumerate(group["params"]): + if hasattr(param, "selected_grad"): + if len(param.shape) != 1: + param.data[:, param.selected_indices] = params_with_grad[i] + for param in group["params"]: if hasattr(param, "temp_selected_param"): param.temp_selected_param = None @@ -214,6 +221,11 @@ def _group_step_without_offload(self, group_to_paramlist): maximize=False, ) + for i, param in enumerate(params): + if hasattr(param, "selected_grad"): + if len(param.shape) != 1: + param.data[:, param.selected_indices] = params_with_grad[i] + for param in params: param.selected_grad = None @@ -266,168 +278,12 @@ def _group_step_with_offload(self, group_to_paramlist): maximize=False, ) - self.copy_mv_to_cpu(params) - - for param in params: - param.selected_grad = None - - -class ZenFlowSelectiveAdamW_stage3(torch.optim.AdamW): - - def __init__(self, *args, **kwargs): - super(ZenFlowSelectiveAdamW_stage3, self).__init__(*args, **kwargs) - - @torch.no_grad() - def temp_copy_param(self, paramlist): - for param in paramlist: - if hasattr(param, "selected_grad"): - num_column, num_row = param.ds_shape if len(param.ds_shape) != 1 else (param.ds_shape[0], 1) - - if num_row != 1: - param_2d = param.ds_tensor.data.narrow(0, param.complete_column_offset, param.complete_numel).view( - 
param.complete_numel // num_row, num_row) - param.temp_selected_param = param_2d[param.selected_indices, :].clone().detach() - else: - param.temp_selected_param = param.ds_tensor.data.clone().detach() - - def clear_selected_mv(self): - print("clearing...") - for group in self.param_groups: - for param in group['params']: - state = self.state.setdefault(param, {}) - if len(state) == 0: - continue - state["exp_avg"].zero_() - state["exp_avg_sq"].zero_() - - @torch.no_grad() - def step(self): - for group in self.param_groups: - - params_with_grad: List[Tensor] = [] - grads: List[Tensor] = [] - exp_avgs: List[Tensor] = [] - exp_avg_sqs: List[Tensor] = [] - max_exp_avg_sqs: List[Tensor] = [] - state_steps: List[Tensor] = [] - amsgrad: bool = group["amsgrad"] - beta1, beta2 = cast(Tuple[float, float], group["betas"]) - - for param in group["params"]: + for i, param in enumerate(params): if hasattr(param, "selected_grad"): - num_column, num_row = param.ds_shape if len(param.ds_shape) != 1 else (param.ds_shape[0], 1) - if num_row != 1: - param_2d = param.ds_tensor.data.narrow(0, param.complete_column_offset, - param.complete_numel).view( - param.complete_numel // num_row, num_row) - selected_param = param_2d[param.selected_indices, :] - else: - selected_param = param.ds_tensor.data - if hasattr(param, 'temp_selected_param') and param.temp_selected_param is not None: - selected_param.copy_(param.temp_selected_param) - - state = self.state.setdefault(param, {}) - if len(state) == 0: - state["step"] = torch.zeros((), dtype=param.dtype, device=selected_param.device) - state["exp_avg"] = torch.zeros_like(selected_param) - state["exp_avg_sq"] = torch.zeros_like(selected_param) - if amsgrad: - state["max_exp_avg_sq"] = torch.zeros_like(selected_param) + if len(param.shape) != 1: + param.data[:, param.selected_indices] = params_with_grad[i] - params_with_grad.append(selected_param) - grads.append(param.selected_grad) - exp_avgs.append(state["exp_avg"]) - exp_avg_sqs.append(state["exp_avg_sq"]) - if amsgrad: - max_exp_avg_sqs.append(state["max_exp_avg_sq"]) - state_steps.append(state["step"]) - adamw( - params_with_grad, - grads, - exp_avgs, - exp_avg_sqs, - max_exp_avg_sqs, - state_steps, - amsgrad=amsgrad, - beta1=beta1, - beta2=beta2, - lr=group["lr"], - weight_decay=group["weight_decay"], - eps=group["eps"], - maximize=False, - ) - - for param in group["params"]: - if hasattr(param, "temp_selected_param"): - param.temp_selected_param = None - param.selected_grad = None - - @torch.no_grad() - def group_step(self, paramlist): - - group_to_paramlist = {} - for param in paramlist: - group_id = param.group_id - if group_id not in group_to_paramlist: - group_to_paramlist[group_id] = [] - group_to_paramlist[group_id].append(param) - - for group_id in sorted(group_to_paramlist.keys()): - params = group_to_paramlist[group_id] - group = self.param_groups[group_id] - - params_with_grad: List[Tensor] = [] - grads: List[Tensor] = [] - exp_avgs: List[Tensor] = [] - exp_avg_sqs: List[Tensor] = [] - max_exp_avg_sqs: List[Tensor] = [] - state_steps: List[Tensor] = [] - amsgrad: bool = group["amsgrad"] - beta1, beta2 = cast(Tuple[float, float], group["betas"]) - - for param in params: - if hasattr(param, "selected_grad"): - num_column, num_row = param.ds_shape if len(param.ds_shape) != 1 else (param.ds_shape[0], 1) - - if num_row != 1: - param_2d = param.ds_tensor.data.narrow(0, param.complete_column_offset, - param.complete_numel).view( - param.complete_numel // num_row, num_row) - selected_param = 
param_2d[param.selected_indices, :] - else: - selected_param = param.ds_tensor.data - - state = self.state.setdefault(param, {}) - if len(state) == 0: - state["step"] = torch.zeros((), dtype=param.dtype, device=selected_param.device) - state["exp_avg"] = torch.zeros_like(selected_param) - state["exp_avg_sq"] = torch.zeros_like(selected_param) - if amsgrad: - state["max_exp_avg_sq"] = torch.zeros_like(selected_param) - - params_with_grad.append(selected_param) - grads.append(param.selected_grad) - exp_avgs.append(state["exp_avg"]) - exp_avg_sqs.append(state["exp_avg_sq"]) - if amsgrad: - max_exp_avg_sqs.append(state["max_exp_avg_sq"]) - state_steps.append(state["step"]) - - adamw( - params_with_grad, - grads, - exp_avgs, - exp_avg_sqs, - max_exp_avg_sqs, - state_steps, - amsgrad=amsgrad, - beta1=beta1, - beta2=beta2, - lr=group["lr"], - weight_decay=group["weight_decay"], - eps=group["eps"], - maximize=False, - ) + self.copy_mv_to_cpu(params) for param in params: param.selected_grad = None From 28cdf89edfec1fcea96292bc979b37a5b8e9af19 Mon Sep 17 00:00:00 2001 From: Tingfeng Lan Date: Thu, 26 Jun 2025 15:11:04 +0000 Subject: [PATCH 06/26] Add unit tests for ZenFlowSelectiveAdamW optimizer - Introduce tests for ZenFlowSelectiveAdamW covering both offload and non-offload modes. - Validate step and group_step behavior with selected index updates and temporary parameter storage. - Ensure correct handling of 1D and 2D parameters, as well as proper gradient/state cleanup after updates. - Verify state increment logic and compatibility with PyTorch's native AdamW for numerical correctness. Signed-off-by: Tingfeng Lan Signed-off-by: Yusen Wu Co-authored-by: Yusen Wu --- tests/unit/ops/adam/test_zf_torch_adam.py | 153 ++++++++++++++++++++++ 1 file changed, 153 insertions(+) create mode 100644 tests/unit/ops/adam/test_zf_torch_adam.py diff --git a/tests/unit/ops/adam/test_zf_torch_adam.py b/tests/unit/ops/adam/test_zf_torch_adam.py new file mode 100644 index 000000000000..4db91e8d1c41 --- /dev/null +++ b/tests/unit/ops/adam/test_zf_torch_adam.py @@ -0,0 +1,153 @@ +# Copyright (c) Microsoft Corporation. 
+# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch +import numpy as np +from torch.nn import Parameter +from deepspeed.ops.adam import ZenFlowSelectiveAdamW + + +def make_param(shape, selected_indices=None): + param = Parameter(torch.randn(*shape)) + if selected_indices is not None: + param.selected_indices = selected_indices + param.selected_grad = torch.randn(param.shape[0], len(selected_indices)) + param.temp_selected_param = param.data[:, selected_indices].clone() + return param + + +def test_init_methods(): + opt1 = ZenFlowSelectiveAdamW([torch.nn.Parameter(torch.randn(2, 4))], lr=1e-3, offload=False) + assert opt1.step == opt1._step_without_offload + assert opt1.group_step == opt1._group_step_without_offload + opt2 = ZenFlowSelectiveAdamW([torch.nn.Parameter(torch.randn(2, 4))], lr=1e-3, offload=True) + assert opt2.step == opt2._step_with_offload + assert opt2.group_step == opt2._group_step_with_offload + + +def test_step_without_offload(): + param = make_param((4, 6), torch.tensor([1, 3, 4])) + param.requires_grad_(True) + opt = ZenFlowSelectiveAdamW([param], lr=1e-3, offload=False) + + old_selected = param.data[:, param.selected_indices].clone() + + opt.step() + + new_selected = param.data[:, param.selected_indices] + diff_norm = (old_selected - new_selected).abs().sum().item() + + assert diff_norm > 1e-5, "param was not updated" + assert param.temp_selected_param is None + assert param.selected_grad is None + + +def test_step_with_offload_bucket_flush(): + param1 = make_param((2, 4), torch.tensor([1, 2])) + param2 = make_param((2, 4), torch.tensor([0, 3])) + + param1.exp_avg = torch.zeros_like(param1.temp_selected_param) + param1.exp_avg_sq = torch.zeros_like(param1.temp_selected_param) + param1.exp_avg_cpu_data = param1.exp_avg.clone().cpu() + param1.exp_avg_sq_cpu_data = param1.exp_avg_sq.clone().cpu() + + param2.exp_avg = torch.zeros_like(param2.temp_selected_param) + param2.exp_avg_sq = torch.zeros_like(param2.temp_selected_param) + param2.exp_avg_cpu_data = param2.exp_avg.clone().cpu() + param2.exp_avg_sq_cpu_data = param2.exp_avg_sq.clone().cpu() + + opt = ZenFlowSelectiveAdamW([param1, param2], lr=1e-3, offload=True, bucket_size=1) + opt.step() + assert param1.temp_selected_param is None + assert param2.temp_selected_param is None + + +def test_clear_selected_mv(): + param = make_param((2, 4), torch.tensor([0, 2])) + opt = ZenFlowSelectiveAdamW([param], lr=1e-3, offload=False) + opt.step() + state = opt.state[param] + assert "exp_avg" in state + opt.clear_selected_mv() + assert state["exp_avg"].abs().sum() == 0 + + +def test_group_step_without_offload(): + param = make_param((2, 6), torch.tensor([0, 1, 3])) + opt = ZenFlowSelectiveAdamW([param], lr=1e-3, offload=False) + group_to_paramlist = {0: [param]} + opt._group_step_without_offload(group_to_paramlist) + assert param.selected_grad is None + + +def test_group_step_with_offload(): + param = make_param((2, 6), torch.tensor([0, 1, 3])) + opt = ZenFlowSelectiveAdamW([param], lr=1e-3, offload=True) + + state = opt.state.setdefault(param, {}) + state["step"] = torch.zeros((), dtype=param.dtype, device=param.device) + param.exp_avg = torch.zeros_like(param.data[:, param.selected_indices]) + param.exp_avg_sq = torch.zeros_like(param.data[:, param.selected_indices]) + param.exp_avg_cpu_data = param.exp_avg.clone().cpu() + param.exp_avg_sq_cpu_data = param.exp_avg_sq.clone().cpu() + + group_to_paramlist = {0: [param]} + opt._group_step_with_offload(group_to_paramlist) + assert param.selected_grad is None + 
+def test_1d_param_support(): + param = Parameter(torch.randn(10)) + param.selected_grad = torch.randn(10) + param.temp_selected_param = param.data.clone() + opt = ZenFlowSelectiveAdamW([param], lr=1e-3, offload=False) + opt.step() + assert param.temp_selected_param is None + assert param.selected_grad is None + +def test_state_increment(): + param = torch.nn.Parameter(torch.randn(2, 4)) + param.selected_indices = torch.arange(4) + param.selected_grad = torch.randn(2, 4) + param.temp_selected_param = param.data.clone() + + opt = ZenFlowSelectiveAdamW([param], lr=1e-3, offload=False) + opt.step() + step1 = opt.state[param]['step'].item() + + param.selected_grad = torch.randn(2, 4) + param.temp_selected_param = param.data.clone() + param.selected_indices = torch.arange(4) + + opt.step() + step2 = opt.state[param]['step'].item() + assert step2 == step1 + 1 + +def _compare_with_torch_adamw(param, zenflow_opt, atol=1e-4): + torch_param = torch.nn.Parameter(param.detach().clone()) + torch_opt = torch.optim.AdamW([torch_param], lr=zenflow_opt.param_groups[0]['lr']) + + for _ in range(10): + grad = torch.randn_like(param) + param.selected_indices = torch.arange(param.shape[1]) + param.selected_grad = grad + param.temp_selected_param = param.data.clone() + + torch_param.grad = grad.clone() + + zenflow_opt.step() + torch_opt.step() + + np.testing.assert_allclose( + torch_param.data.cpu().numpy(), + param.data.cpu().numpy(), + atol=atol, + err_msg="Mismatch with torch.AdamW" + ) + +def test_against_torch_adamw(): + param = torch.nn.Parameter(torch.randn(2, 4)) + param.selected_indices = torch.arange(4) + opt = ZenFlowSelectiveAdamW([param], lr=1e-3, offload=False) + _compare_with_torch_adamw(param, opt) \ No newline at end of file From f534d5e308f5dc518824b5f411e5598803ee940a Mon Sep 17 00:00:00 2001 From: Tingfeng Lan Date: Thu, 26 Jun 2025 20:21:59 -0400 Subject: [PATCH 07/26] Add ZenFlow tutorial documentation - Introduce a new tutorial for ZenFlow, detailing its configuration and usage in DeepSpeed. Signed-off-by: Tingfeng Lan Co-authored-by: Yusen Wu --- docs/_tutorials/zenflow.md | 69 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 69 insertions(+) create mode 100644 docs/_tutorials/zenflow.md diff --git a/docs/_tutorials/zenflow.md b/docs/_tutorials/zenflow.md new file mode 100644 index 000000000000..9bb24281d4a5 --- /dev/null +++ b/docs/_tutorials/zenflow.md @@ -0,0 +1,69 @@ +--- +title: "ZenFlow" +tags: training +--- + +ZenFlow is an extension of ZeRO-Offload that decouples and asynchronously updates gradients during training. It reduces CPU-induced stalls when using offload optimizers, enabling smoother and faster training. Like ZeRO-Offload, ZenFlow requires no code changes—only configuration updates in your DeepSpeed JSON file. + +We recommend that you read the tutorials on [Getting Started](/getting-started/) and [ZeRO](/tutorials/zero/) before stepping through this tutorial. ZenFlow builds on top of [ZeRO-Offload](/tutorials/zero-offload/), so shared setup details can be found there. 
+
+## Configuration Changes
+
+To enable ZenFlow, simply add a `zenflow` section under the existing `zero_optimization` block in your DeepSpeed config:
+
+```json
+{
+    "zero_optimization": {
+        "stage": 2,
+        "offload_optimizer": {
+            "device": "cpu",
+            "pin_memory": true
+        },
+        "zenflow": {
+            "topk_ratio": 0.05,
+            "select_strategy": "auto",
+            "select_interval": "auto",
+            "update_interval": 4,
+            "full_warm_up_rounds": 0,
+            "overlap_step": true
+        }
+    }
+}
+```
+
+Each field in the `zenflow` block controls selective gradient update behavior:
+
+- `topk_ratio`: Fraction of the most important gradient columns to update on GPU (e.g., 0.05 means the top 5% of columns by importance).
+- `select_strategy`: Strategy for selecting important gradients (`"auto"`, `"step"`, or `"epoch"`).
+- `select_interval`: How often to re-select the important gradient indices (`"auto"` or an integer such as 1).
+- `update_interval`: How often to apply the accumulated unimportant gradients to the model (`"auto"` or an integer such as 4, meaning every 4 steps).
+- `full_warm_up_rounds`: Number of initial steps with full gradient updates before selection begins.
+- `overlap_step`: Whether to overlap the CPU-side optimizer step with the next step's forward/backward computation (`true` enables it).
+
+---
+
+**Recommended**: Use `"auto"` for `select_strategy`, `select_interval`, and `update_interval` to enable adaptive behavior with minimal tuning.
+
+You can continue using the same training setup and launch script as in the [ZeRO-Offload tutorial](/tutorials/zero-offload/), since ZenFlow builds directly on top of ZeRO-Offload.
+
+## Quick Start: Fine-tuning Example
+
+A complete fine-tuning example using ZenFlow is available in [DeepSpeedExamples](https://github.com/microsoft/DeepSpeedExamples): [ZenFlow Fine-Tuning on GLUE](https://github.com/deepspeedai/DeepSpeedExamples/tree/master/training/DeepSpeed-ZenFlow).
+
+This example shows how to fine-tune a GPT model on the GLUE benchmark with:
+- CPU optimizer offload
+- ZenFlow asynchronous updates
+
+To run the example:
+
+```bash
+cd DeepSpeedExamples/training/DeepSpeed-ZenFlow
+bash finetune_gpt_glue.sh
+```
+
+Refer to the `README.md` in the folder for setup instructions, dataset preparation, and configuration details.
+
+---
+
+Congratulations! You have successfully enabled ZenFlow for stall-free offloading.
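+
+For reference, the JSON above can also be passed programmatically as a Python dict. The snippet below is only a minimal sketch of that wiring, not the script from the example repository: `model`, `train_dataset`, the batch size, and the learning rate are placeholders, and the loss extraction assumes a HuggingFace-style model that returns an object with a `.loss` field.
+
+```python
+import deepspeed
+
+ds_config = {
+    "train_micro_batch_size_per_gpu": 8,                      # placeholder batch size
+    "optimizer": {"type": "Adam", "params": {"lr": 2e-5}},    # placeholder learning rate
+    "fp16": {"enabled": True},
+    "zero_optimization": {
+        "stage": 2,
+        "offload_optimizer": {"device": "cpu", "pin_memory": True},
+        "zenflow": {
+            "topk_ratio": 0.05,
+            "select_strategy": "auto",
+            "select_interval": "auto",
+            "update_interval": 4,
+            "full_warm_up_rounds": 0,
+            "overlap_step": True
+        }
+    }
+}
+
+# `model` and `train_dataset` are assumed to come from your own training script.
+engine, _, train_loader, _ = deepspeed.initialize(model=model,
+                                                  model_parameters=model.parameters(),
+                                                  training_data=train_dataset,
+                                                  config=ds_config)
+
+for batch in train_loader:
+    batch = {k: v.to(engine.device) for k, v in batch.items()}
+    loss = engine(**batch).loss   # assumes the wrapped model returns a loss
+    engine.backward(loss)         # gradients are reduced; unimportant ones are accumulated/offloaded
+    engine.step()                 # selective GPU update plus the (optionally overlapped) CPU optimizer step
+```
+
+Passing `training_data` lets DeepSpeed build the dataloader itself, which is convenient here because the `"auto"` selection strategy derives its per-epoch re-selection interval from the dataloader length.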
From 80ad488918efd20a69dd12de99f00c04388c663f Mon Sep 17 00:00:00 2001 From: Tingfeng Lan Date: Thu, 26 Jun 2025 20:38:18 -0400 Subject: [PATCH 08/26] Format code Signed-off-by: Tingfeng Lan --- deepspeed/ops/adam/__init__.py | 2 +- deepspeed/ops/adam/zenflow_torch_adam.py | 2 +- .../zero/zenflow/zenflow_stage_1_and_2.py | 6 +++++- tests/unit/ops/adam/test_zf_torch_adam.py | 18 ++++++++++-------- 4 files changed, 17 insertions(+), 11 deletions(-) diff --git a/deepspeed/ops/adam/__init__.py b/deepspeed/ops/adam/__init__.py index 039106a1fb84..82dfa114ac9a 100755 --- a/deepspeed/ops/adam/__init__.py +++ b/deepspeed/ops/adam/__init__.py @@ -6,4 +6,4 @@ from .cpu_adam import DeepSpeedCPUAdam from .fused_adam import FusedAdam from .zenflow_cpu_adam import ZenFlowCPUAdam -from .zenflow_torch_adam import ZenFlowSelectiveAdamW, ZenFlowSelectiveAdamW_stage3 +from .zenflow_torch_adam import ZenFlowSelectiveAdamW diff --git a/deepspeed/ops/adam/zenflow_torch_adam.py b/deepspeed/ops/adam/zenflow_torch_adam.py index 25297f346f02..af30698a0327 100644 --- a/deepspeed/ops/adam/zenflow_torch_adam.py +++ b/deepspeed/ops/adam/zenflow_torch_adam.py @@ -23,7 +23,7 @@ class ZenFlowSelectiveAdamW(torch.optim.AdamW): def __init__(self, *args, offload=False, bucket_size=5e8, **kwargs): super(ZenFlowSelectiveAdamW, self).__init__(*args, **kwargs) - + self.offload = offload if offload: diff --git a/deepspeed/runtime/zero/zenflow/zenflow_stage_1_and_2.py b/deepspeed/runtime/zero/zenflow/zenflow_stage_1_and_2.py index c1489848cb13..022736c91c1e 100644 --- a/deepspeed/runtime/zero/zenflow/zenflow_stage_1_and_2.py +++ b/deepspeed/runtime/zero/zenflow/zenflow_stage_1_and_2.py @@ -109,8 +109,12 @@ def __init__(self, @classmethod def create(cls, zenflow_config): if zenflow_config.overlap_step: + # print("Yes!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!") + print("Yes!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!ZenFlowZeroOptimizerParallel") return ZenFlowZeroOptimizerParallel else: + # print("No!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!") + print("No!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!ZenFlowZeroOptimizerSequential") return ZenFlowZeroOptimizerSequential def _configure_zenflow(self, zenflow_config): @@ -284,7 +288,7 @@ def process_selected_fp32_groups_grad(self, tensor, total_size): param: The parameter to process param_id: ID of the parameter """ - + print("Yes!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!process_selected_fp32_groups_grad") curr_size = 0 curr_grad_buffer_size = 0 curr_sum_buffer_size = 0 diff --git a/tests/unit/ops/adam/test_zf_torch_adam.py b/tests/unit/ops/adam/test_zf_torch_adam.py index 4db91e8d1c41..76915d1bb32c 100644 --- a/tests/unit/ops/adam/test_zf_torch_adam.py +++ b/tests/unit/ops/adam/test_zf_torch_adam.py @@ -52,7 +52,7 @@ def test_step_with_offload_bucket_flush(): param1.exp_avg_sq = torch.zeros_like(param1.temp_selected_param) param1.exp_avg_cpu_data = param1.exp_avg.clone().cpu() param1.exp_avg_sq_cpu_data = param1.exp_avg_sq.clone().cpu() - + param2.exp_avg = torch.zeros_like(param2.temp_selected_param) param2.exp_avg_sq = torch.zeros_like(param2.temp_selected_param) param2.exp_avg_cpu_data = param2.exp_avg.clone().cpu() @@ -97,6 +97,7 @@ def test_group_step_with_offload(): opt._group_step_with_offload(group_to_paramlist) assert param.selected_grad is None + def test_1d_param_support(): param = Parameter(torch.randn(10)) param.selected_grad = torch.randn(10) @@ -106,6 +107,7 @@ def test_1d_param_support(): assert param.temp_selected_param is None assert param.selected_grad is None + def test_state_increment(): param = 
torch.nn.Parameter(torch.randn(2, 4)) param.selected_indices = torch.arange(4) @@ -124,6 +126,7 @@ def test_state_increment(): step2 = opt.state[param]['step'].item() assert step2 == step1 + 1 + def _compare_with_torch_adamw(param, zenflow_opt, atol=1e-4): torch_param = torch.nn.Parameter(param.detach().clone()) torch_opt = torch.optim.AdamW([torch_param], lr=zenflow_opt.param_groups[0]['lr']) @@ -139,15 +142,14 @@ def _compare_with_torch_adamw(param, zenflow_opt, atol=1e-4): zenflow_opt.step() torch_opt.step() - np.testing.assert_allclose( - torch_param.data.cpu().numpy(), - param.data.cpu().numpy(), - atol=atol, - err_msg="Mismatch with torch.AdamW" - ) + np.testing.assert_allclose(torch_param.data.cpu().numpy(), + param.data.cpu().numpy(), + atol=atol, + err_msg="Mismatch with torch.AdamW") + def test_against_torch_adamw(): param = torch.nn.Parameter(torch.randn(2, 4)) param.selected_indices = torch.arange(4) opt = ZenFlowSelectiveAdamW([param], lr=1e-3, offload=False) - _compare_with_torch_adamw(param, opt) \ No newline at end of file + _compare_with_torch_adamw(param, opt) From 9c05ccba281a6d30e642bd714c28a60f9bb96a4b Mon Sep 17 00:00:00 2001 From: Tingfeng Lan Date: Thu, 26 Jun 2025 21:23:44 -0400 Subject: [PATCH 09/26] Fix check_grad_overflow parameter in ZenFlowZeroOptimizer Signed-off-by: Tingfeng Lan --- deepspeed/runtime/zero/zenflow/zenflow_stage_1_and_2.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/deepspeed/runtime/zero/zenflow/zenflow_stage_1_and_2.py b/deepspeed/runtime/zero/zenflow/zenflow_stage_1_and_2.py index 022736c91c1e..b487ed5fd298 100644 --- a/deepspeed/runtime/zero/zenflow/zenflow_stage_1_and_2.py +++ b/deepspeed/runtime/zero/zenflow/zenflow_stage_1_and_2.py @@ -79,7 +79,8 @@ def __init__(self, round_robin_gradients=False, has_moe_layers=False, fp16_master_weights_and_gradients=False, - elastic_checkpoint=False): + elastic_checkpoint=False, + check_grad_overflow=True): super().__init__(init_optimizer, param_names, timers, optimizer_params, static_loss_scale, dynamic_loss_scale, dynamic_loss_args, verbose, contiguous_gradients, reduce_bucket_size, From da80ff758fba55b03c591f1a134fe258ba58a065 Mon Sep 17 00:00:00 2001 From: Tingfeng Lan Date: Thu, 26 Jun 2025 23:37:50 -0400 Subject: [PATCH 10/26] Refactor ZenFlowZeroOptimizer methods to include communication data type - Updated methods to accept communication_data_type as a parameter for better handling of IPG buckets. - Removed debug print statements to clean up the code. 
Signed-off-by: Tingfeng Lan Co-authored-by: Yusen Wu --- .../zero/zenflow/zenflow_stage_1_and_2.py | 33 ++++++++++--------- 1 file changed, 17 insertions(+), 16 deletions(-) diff --git a/deepspeed/runtime/zero/zenflow/zenflow_stage_1_and_2.py b/deepspeed/runtime/zero/zenflow/zenflow_stage_1_and_2.py index b487ed5fd298..25dc44a288ec 100644 --- a/deepspeed/runtime/zero/zenflow/zenflow_stage_1_and_2.py +++ b/deepspeed/runtime/zero/zenflow/zenflow_stage_1_and_2.py @@ -110,12 +110,8 @@ def __init__(self, @classmethod def create(cls, zenflow_config): if zenflow_config.overlap_step: - # print("Yes!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!") - print("Yes!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!ZenFlowZeroOptimizerParallel") return ZenFlowZeroOptimizerParallel else: - # print("No!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!") - print("No!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!ZenFlowZeroOptimizerSequential") return ZenFlowZeroOptimizerSequential def _configure_zenflow(self, zenflow_config): @@ -182,7 +178,7 @@ def sync_fp32_param_from_gpu(self): fp32_partition.copy_(bit16_partitions[partition_id].to(dtype=fp32_partition.dtype, device=fp32_partition.device)) - def update_selected_channels(self, tensor, total_size): + def update_selected_channels(self, tensor, total_size, communication_data_type): curr_size = 0 curr_index_buffer_size = 0 rank_and_offsets = [] @@ -194,7 +190,8 @@ def update_selected_channels(self, tensor, total_size): self.index_buffer = torch.empty(total_size, dtype=torch.int32, device='cuda') # count = 0 - for i, param_idx_in_group, param_id in self.params_in_ipg_bucket: + bucket = self.ipg_buckets[communication_data_type] + for i, param_idx_in_group, param_id in bucket.params: param = self.bit16_groups[i][param_idx_in_group] if len(param.shape) == 1: @@ -255,7 +252,7 @@ def update_selected_channels(self, tensor, total_size): index_slice = self.index_buffer.narrow(0, offset, num_select) dist.broadcast(index_slice, src=src_rank, group=process_group) - for i, param_idx_in_group, param_id in self.params_in_ipg_bucket: + for i, param_idx_in_group, param_id in bucket.params: param = self.bit16_groups[i][param_idx_in_group] if len(param.shape) == 1: @@ -281,7 +278,7 @@ def update_selected_channels(self, tensor, total_size): self.index_buffer = None - def process_selected_fp32_groups_grad(self, tensor, total_size): + def _process_selected_fp32_groups_grad(self, tensor, total_size, communication_data_type): """ Process gradients for selected columns in FP32 groups @@ -289,7 +286,7 @@ def process_selected_fp32_groups_grad(self, tensor, total_size): param: The parameter to process param_id: ID of the parameter """ - print("Yes!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!process_selected_fp32_groups_grad") + curr_size = 0 curr_grad_buffer_size = 0 curr_sum_buffer_size = 0 @@ -309,7 +306,8 @@ def process_selected_fp32_groups_grad(self, tensor, total_size): group_to_paramlist = {} # count = 0 - for i, param_idx_in_group, param_id in self.params_in_ipg_bucket: + bucket = self.ipg_buckets[communication_data_type] + for i, param_idx_in_group, param_id in bucket.params: param = self.bit16_groups[i][param_idx_in_group] if not hasattr(param, 'selected_indices'): @@ -389,7 +387,7 @@ def process_selected_fp32_groups_grad(self, tensor, total_size): sum_slice = self.sum_buffer.narrow(0, sum_offset, sum_num) dist.broadcast(sum_slice, src=src_rank, group=process_group) - for i, param_idx_in_group, param_id in self.params_in_ipg_bucket: + for i, param_idx_in_group, param_id in bucket.params: param = self.bit16_groups[i][param_idx_in_group] 
selected_grad = None @@ -450,7 +448,7 @@ def process_selected_fp32_groups_grad(self, tensor, total_size): if self.auto_update: self.sum_buffer = None - def average_tensor(self, tensor): + def average_tensor(self, tensor: torch.Tensor, communication_data_type: torch.dtype): if self.overlap_comm: stream = self.reduction_stream if not get_accelerator().resolves_data_dependency(): @@ -478,12 +476,13 @@ def average_tensor(self, tensor): process_group = self.dp_process_group # count = 0 - for i, param_idx_in_group, param_id in self.params_in_ipg_bucket: + bucket = self.ipg_buckets[communication_data_type] + for i, param_idx_in_group, param_id in bucket.params: param = self.bit16_groups[i][param_idx_in_group] process_group = self.dp_process_group - if self.ipg_bucket_has_moe_params: + if bucket.has_moe_params: process_group = self.expert_dp_process_group[param.group_name] if is_moe_param( param) else self.dp_process_group @@ -546,12 +545,14 @@ def average_tensor(self, tensor): for bucket_key in buckets: if self.use_multi_rank_bucket_allreduce: self.allreduce_and_scatter(buckets[bucket_key], + communication_data_type, numel_per_bucket=self.reduce_bucket_size, divide=False, process_group=bucket_key) else: dst, process_group = bucket_key self.allreduce_no_retain(buckets[bucket_key], + communication_data_type, numel_per_bucket=self.reduce_bucket_size, rank=dst, divide=False, @@ -560,7 +561,7 @@ def average_tensor(self, tensor): if self.is_zenflow_select_boundary(): self.timers(SELECTIVE_OPTIMIZER_UPDATE_TIMER).start() # print("update selected") - self.update_selected_channels(tensor, curr_column_size) + self.update_selected_channels(tensor, curr_column_size, communication_data_type) self.timers(SELECTIVE_OPTIMIZER_UPDATE_TIMER).stop() elif self.zenflow: self.timers(SELECTIVE_OPTIMIZER_UPDATE_TIMER).start() @@ -568,7 +569,7 @@ def average_tensor(self, tensor): if self.zenflow and self.micro_step >= self.full_warm_up_rounds: self.timers(SELECTIVE_OPTIMIZER_PROCESS_TIMER).start() - self.process_selected_fp32_groups_grad(tensor, curr_selected_reduce_size) + self._process_selected_fp32_groups_grad(tensor, curr_selected_reduce_size, communication_data_type) self.timers(SELECTIVE_OPTIMIZER_PROCESS_TIMER).stop() def backward(self, loss, retain_graph=False): From fee24ffbd86a1b6c23f9946b029b9c796e45ddab Mon Sep 17 00:00:00 2001 From: Tingfeng Lan Date: Fri, 27 Jun 2025 23:42:18 -0400 Subject: [PATCH 11/26] Refactor ZenFlow integration in DeepSpeedEngine - Move `_configure_zenflow` logic to a standalone `configure_zenflow()` function in `zenflow_utils.py` - Refactor ZenFlow placement to decouple it from ZeRO internals Signed-off-by: Tingfeng Lan --- deepspeed/runtime/engine.py | 5 +- .../runtime/{zero => }/zenflow/__init__.py | 0 deepspeed/runtime/zenflow/zenflow_config.py | 62 +++++++++++++ .../zenflow/zenflow_stage_1_and_2.py | 0 deepspeed/runtime/zenflow/zenflow_utils.py | 88 +++++++++++++++++++ deepspeed/runtime/zero/config.py | 3 +- deepspeed/runtime/zero/offload_config.py | 55 +----------- deepspeed/runtime/zero/stage_1_and_2.py | 4 +- .../runtime/zero/zenflow/zenflow_utils.py | 42 --------- tests/unit/runtime/zenflow/test_zf_config.py | 3 +- 10 files changed, 162 insertions(+), 100 deletions(-) rename deepspeed/runtime/{zero => }/zenflow/__init__.py (100%) create mode 100644 deepspeed/runtime/zenflow/zenflow_config.py rename deepspeed/runtime/{zero => }/zenflow/zenflow_stage_1_and_2.py (100%) create mode 100644 deepspeed/runtime/zenflow/zenflow_utils.py delete mode 100644
deepspeed/runtime/zero/zenflow/zenflow_utils.py diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 546c0e6384ed..0699c5b665f0 100755 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -27,11 +27,12 @@ from deepspeed.runtime.utils import see_memory_usage, DummyOptim from .zero.offload_config import OffloadDeviceEnum, OffloadStateTypeEnum from deepspeed.runtime.zero.stage_1_and_2 import DeepSpeedZeroOptimizer -from deepspeed.runtime.zero.zenflow.zenflow_stage_1_and_2 import ZenFlowZeroOptimizer +from deepspeed.runtime.zenflow.zenflow_stage_1_and_2 import ZenFlowZeroOptimizer from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus from deepspeed.runtime.zero.utils import is_zero_supported_optimizer, ZeRORuntimeException from deepspeed.runtime.zero.parameter_offload import DeepSpeedZeRoOffload from deepspeed.runtime.zero.config import ZERO_OPTIMIZATION +from deepspeed.runtime.zenflow.zenflow_utils import configure_zenflow from deepspeed.runtime.fp16.fused_optimizer import FP16_Optimizer from deepspeed.runtime.fp16.unfused_optimizer import FP16_UnfusedOptimizer @@ -334,6 +335,8 @@ def __init__(self, if self.torch_autocast_enabled(): init_autocast_params(self, self.torch_autocast_dtype(), self.torch_autocast_lower_precision_safe_modules()) + configure_zenflow(self) + if has_optimizer: self._configure_optimizer(optimizer, model_parameters) self._configure_lr_scheduler() diff --git a/deepspeed/runtime/zero/zenflow/__init__.py b/deepspeed/runtime/zenflow/__init__.py similarity index 100% rename from deepspeed/runtime/zero/zenflow/__init__.py rename to deepspeed/runtime/zenflow/__init__.py diff --git a/deepspeed/runtime/zenflow/zenflow_config.py b/deepspeed/runtime/zenflow/zenflow_config.py new file mode 100644 index 000000000000..416607c6b387 --- /dev/null +++ b/deepspeed/runtime/zenflow/zenflow_config.py @@ -0,0 +1,62 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from pydantic import Field, model_validator +from typing import Optional, Union + +from deepspeed.runtime.config_utils import DeepSpeedConfigModel + + +class ZenFlowConfig(DeepSpeedConfigModel): + """Configuration options for ZenFlow optimization module.""" + + topk_ratio: float = Field(0.1, ge=0.0, le=1.0) + """Ratio of top-k important gradient columns to retain (range: 0.0 to 1.0).""" + + select_strategy: str = "auto" + """Strategy for selecting important gradient indices. + Options: "auto", "step", or "epoch".""" + + select_interval: Union[str, int] = "auto" + """Interval at which to reselect important gradient indices. + Can be "auto" or a fixed integer step/epoch interval.""" + + update_interval: Union[str, int] = "auto" + """Interval for applying accumulated unimportant gradients to model parameters. + Can be "auto" or a fixed integer step interval.""" + + overlap_step: bool = False + """Whether to overlap CPU-side optimizer steps with forward/backward computation.""" + + offload: bool = False + """Whether to offload selective optimizer states to CPU to save memory.""" + + auto_ratio: float = Field(0.99, ge=0.0, le=1.0) + """Threshold used in the "auto" strategy to determine update_interval.""" + + full_warm_up_rounds: int = 0 + """Number of initial rounds during which all gradients are fully updated (no selection).""" + + steps_per_epoch: Optional[int] = Field( + default=None, + description= + "Number of steps per epoch. 
This field is initialized during execution and should not be set by users.", + exclude=True) + + @model_validator(mode="after") + def validate_fields(self): + if self.select_strategy not in ["auto", "step", "epoch"]: + raise ValueError('select_strategy must be one of "auto", "step", or "epoch"') + + if isinstance(self.select_interval, str) and self.select_interval != "auto": + raise ValueError('If select_interval is a string, it must be "auto"') + + if isinstance(self.update_interval, str) and self.update_interval != "auto": + raise ValueError('If update_interval is a string, it must be "auto"') + + if not isinstance(self.full_warm_up_rounds, int): + raise ValueError('full_warm_up_rounds must be an integer') + + return self diff --git a/deepspeed/runtime/zero/zenflow/zenflow_stage_1_and_2.py b/deepspeed/runtime/zenflow/zenflow_stage_1_and_2.py similarity index 100% rename from deepspeed/runtime/zero/zenflow/zenflow_stage_1_and_2.py rename to deepspeed/runtime/zenflow/zenflow_stage_1_and_2.py diff --git a/deepspeed/runtime/zenflow/zenflow_utils.py b/deepspeed/runtime/zenflow/zenflow_utils.py new file mode 100644 index 000000000000..4be8e6f2f894 --- /dev/null +++ b/deepspeed/runtime/zenflow/zenflow_utils.py @@ -0,0 +1,88 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from deepspeed.runtime.engine import DeepSpeedEngine + + +def _flatten_dense_tensors(tensors): + """Flatten dense tensors into a contiguous 1D buffer. Assume tensors are of + same dense type. + + Since inputs are dense, the resulting tensor will be a concatenated 1D + buffer. Element-wise operation on this buffer will be equivalent to + operating individually. + + Args: + tensors (Iterable[Tensor]): dense tensors to flatten. + + Returns: + A contiguous 1D buffer containing input tensors. + """ + transposed_tensors = [t.transpose(0, 1).contiguous() if t.dim() == 2 else t for t in tensors] + return torch._C._nn.flatten_dense_tensors(transposed_tensors) + + +def _unflatten_dense_tensors(flat, tensors): + """View a flat buffer using the sizes of tensors. Assume that tensors are of + same dense type, and that flat is given by _flatten_dense_tensors. + + Args: + flat (Tensor): flattened dense tensors to unflatten. + tensors (Iterable[Tensor]): dense tensors whose sizes will be used to + unflatten flat. + + Returns: + Unflattened dense tensors with sizes same as tensors and values from + flat. + """ + transposed_tensors = [t.transpose(0, 1) if t.dim() == 2 else t for t in tensors] + unflat = torch._C._nn.unflatten_dense_tensors(flat, transposed_tensors) + return [t.transpose(0, 1) if t.dim() == 2 else t for t in unflat] + + +def configure_zenflow(engine: "DeepSpeedEngine") -> None: + zenflow_config = engine.zenflow_config() + if zenflow_config == None: + engine.zenflow = False + return + + engine.zenflow = True + select_strategy = zenflow_config.select_strategy + + if select_strategy == 'auto': + select_strategy = "epoch" + if isinstance(zenflow_config.select_interval, int): + raise Warning( + "If use auto select strategy, select_interval will be set to 1 and select_strategy will be set to epoch, thus select_interval would be overwritten." 
+ ) + engine.select_interval = 1 + else: + if isinstance(zenflow_config.select_interval, str): + raise ValueError("If don't use auto select strategy, select_interval must be a number.") + engine.select_interval = zenflow_config.select_interval + + if isinstance(zenflow_config.update_interval, str): + engine.auto_update = True + engine.update_interval = 0 + else: + engine.auto_update = False + engine.update_interval = int(zenflow_config.update_interval) + + if select_strategy == 'epoch': + zenflow_config.steps_per_epoch = len(engine.training_dataloader) + engine.select_interval = engine.select_interval * len(engine.training_dataloader) + + if not engine.auto_update and engine.select_interval != 0 and engine.select_interval < engine.update_interval: + raise ValueError("Select interval must be greater or equal to update interval") + + engine.overlap_step = zenflow_config.overlap_step + + engine.full_warm_up_rounds = zenflow_config.full_warm_up_rounds + + engine._config.gradient_accumulation_steps = engine.update_interval diff --git a/deepspeed/runtime/zero/config.py b/deepspeed/runtime/zero/config.py index 5739f07cc535..3fbd3cf5e467 100644 --- a/deepspeed/runtime/zero/config.py +++ b/deepspeed/runtime/zero/config.py @@ -9,7 +9,8 @@ from pydantic import Field, model_validator from deepspeed.runtime.config_utils import get_scalar_param, pp_int, DeepSpeedConfigModel from deepspeed.utils import logger -from .offload_config import DeepSpeedZeroOffloadParamConfig, DeepSpeedZeroOffloadOptimizerConfig, OffloadDeviceEnum, ZenFlowConfig +from .offload_config import DeepSpeedZeroOffloadParamConfig, DeepSpeedZeroOffloadOptimizerConfig, OffloadDeviceEnum +from deepspeed.runtime.zenflow.zenflow_config import ZenFlowConfig # ZeRO optimization. By default, this optimization is not enabled. # Users have to configure the desired optimization (0 means disabled) in params.json as below example: diff --git a/deepspeed/runtime/zero/offload_config.py b/deepspeed/runtime/zero/offload_config.py index 1400c09e7514..ca35d7a7d169 100644 --- a/deepspeed/runtime/zero/offload_config.py +++ b/deepspeed/runtime/zero/offload_config.py @@ -6,7 +6,7 @@ from enum import Enum from pathlib import Path from pydantic import Field, model_validator -from typing import Optional, Union +from typing import Optional from deepspeed.runtime.config_utils import DeepSpeedConfigModel, pp_int @@ -100,59 +100,6 @@ def set_pipeline(self): return self -class ZenFlowConfig(DeepSpeedConfigModel): - """Configuration options for ZenFlow optimization module.""" - - topk_ratio: float = Field(0.1, ge=0.0, le=1.0) - """Ratio of top-k important gradient columns to retain (range: 0.0 to 1.0).""" - - select_strategy: str = "auto" - """Strategy for selecting important gradient indices. - Options: "auto", "step", or "epoch".""" - - select_interval: Union[str, int] = "auto" - """Interval at which to reselect important gradient indices. - Can be "auto" or a fixed integer step/epoch interval.""" - - update_interval: Union[str, int] = "auto" - """Interval for applying accumulated unimportant gradients to model parameters. 
- Can be "auto" or a fixed integer step interval.""" - - overlap_step: bool = False - """Whether to overlap CPU-side optimizer steps with forward/backward computation.""" - - offload: bool = False - """Whether to offload selective optimizer states to CPU to save memory.""" - - auto_ratio: float = Field(0.99, ge=0.0, le=1.0) - """Threshold used in the "auto" strategy to determine update_interval.""" - - full_warm_up_rounds: int = 0 - """Number of initial rounds during which all gradients are fully updated (no selection).""" - - steps_per_epoch: Optional[int] = Field( - default=None, - description= - "Number of steps per epoch. This field is initialized during execution and should not be set by users.", - exclude=True) - - @model_validator(mode="after") - def validate_fields(self): - if self.select_strategy not in ["auto", "step", "epoch"]: - raise ValueError('select_strategy must be one of "auto", "step", or "epoch"') - - if isinstance(self.select_interval, str) and self.select_interval != "auto": - raise ValueError('If select_interval is a string, it must be "auto"') - - if isinstance(self.update_interval, str) and self.update_interval != "auto": - raise ValueError('If update_interval is a string, it must be "auto"') - - if not isinstance(self.full_warm_up_rounds, int): - raise ValueError('full_warm_up_rounds must be an integer') - - return self - - class OffloadStateTypeEnum(str, Enum): """ Enum for internal buffer types """ optim_states = "optim_states" diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py index e2336a3998da..e2b5c41fcd64 100755 --- a/deepspeed/runtime/zero/stage_1_and_2.py +++ b/deepspeed/runtime/zero/stage_1_and_2.py @@ -11,7 +11,7 @@ from typing import List, Dict from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors -from deepspeed.runtime.zero.zenflow import zenflow_utils +from deepspeed.runtime.zenflow import zenflow_utils from deepspeed.runtime.base_optimizer import ZeROOptimizer from deepspeed.runtime.fp16.loss_scaler import CreateLossScaler @@ -1933,6 +1933,8 @@ def _optimizer_step(self, group_no): if self.torch_autocast_gradscaler: self.torch_autocast_gradscaler.step(self.optimizer) self.torch_autocast_gradscaler.update() + elif self.zenflow: + self.zenflow_cpu_optimizer_step(group_no) else: self.optimizer.step() self.optimizer.param_groups = original_param_groups diff --git a/deepspeed/runtime/zero/zenflow/zenflow_utils.py b/deepspeed/runtime/zero/zenflow/zenflow_utils.py deleted file mode 100644 index a544aa3531bd..000000000000 --- a/deepspeed/runtime/zero/zenflow/zenflow_utils.py +++ /dev/null @@ -1,42 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# SPDX-License-Identifier: Apache-2.0 - -# DeepSpeed Team - -import torch - - -def _flatten_dense_tensors(tensors): - """Flatten dense tensors into a contiguous 1D buffer. Assume tensors are of - same dense type. - - Since inputs are dense, the resulting tensor will be a concatenated 1D - buffer. Element-wise operation on this buffer will be equivalent to - operating individually. - - Args: - tensors (Iterable[Tensor]): dense tensors to flatten. - - Returns: - A contiguous 1D buffer containing input tensors. - """ - transposed_tensors = [t.transpose(0, 1).contiguous() if t.dim() == 2 else t for t in tensors] - return torch._C._nn.flatten_dense_tensors(transposed_tensors) - - -def _unflatten_dense_tensors(flat, tensors): - """View a flat buffer using the sizes of tensors. 
Assume that tensors are of - same dense type, and that flat is given by _flatten_dense_tensors. - - Args: - flat (Tensor): flattened dense tensors to unflatten. - tensors (Iterable[Tensor]): dense tensors whose sizes will be used to - unflatten flat. - - Returns: - Unflattened dense tensors with sizes same as tensors and values from - flat. - """ - transposed_tensors = [t.transpose(0, 1) if t.dim() == 2 else t for t in tensors] - unflat = torch._C._nn.unflatten_dense_tensors(flat, transposed_tensors) - return [t.transpose(0, 1) if t.dim() == 2 else t for t in unflat] diff --git a/tests/unit/runtime/zenflow/test_zf_config.py b/tests/unit/runtime/zenflow/test_zf_config.py index c0811bc32423..0ef96e82ce94 100644 --- a/tests/unit/runtime/zenflow/test_zf_config.py +++ b/tests/unit/runtime/zenflow/test_zf_config.py @@ -7,7 +7,8 @@ from pydantic import ValidationError from deepspeed.runtime.zero.config import DeepSpeedZeroConfig, ZeroStageEnum -from deepspeed.runtime.zero.offload_config import DeepSpeedZeroOffloadOptimizerConfig, ZenFlowConfig +from deepspeed.runtime.zenflow.zenflow_config import ZenFlowConfig +from deepspeed.runtime.zero.offload_config import DeepSpeedZeroOffloadOptimizerConfig def test_stage_enum_accepts_int_and_enum(): From a528fd470cebac7077e26b4dafc633b04c66827f Mon Sep 17 00:00:00 2001 From: Tingfeng Lan Date: Sat, 28 Jun 2025 01:17:44 -0400 Subject: [PATCH 12/26] Refactor ZenFlow function calls in DeepSpeedEngine - Simplify the `_configure_zenflow` method by assigning it a lambda function that calls `configure_zenflow(self)`. - Update the optimizer's selective learning rate synchronization to directly reference `self.optimizer._sync_selective_optimizer_lr()`. Signed-off-by: Tingfeng Lan --- deepspeed/runtime/engine.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 0699c5b665f0..0287dd16cd93 100755 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -330,12 +330,10 @@ def __init__(self, if not isinstance(model_parameters, list): model_parameters = list(model_parameters) - self._configure_zenflow() - if self.torch_autocast_enabled(): init_autocast_params(self, self.torch_autocast_dtype(), self.torch_autocast_lower_precision_safe_modules()) - - configure_zenflow(self) + self._configure_zenflow = lambda: configure_zenflow(self) + self._configure_zenflow() if has_optimizer: self._configure_optimizer(optimizer, model_parameters) @@ -2460,7 +2458,7 @@ def step(self, lr_kwargs=None): self._step_applied = False # assume False, will flip to True if self.zenflow: - self.sync_selective_optimizer_lr() + self.optimizer._sync_selective_optimizer_lr() if self.auto_update: self.update_interval += 1 From f7bc35d7dabcc0cf1f29ffffc9b51cf8f78783af Mon Sep 17 00:00:00 2001 From: Yusen Wu Date: Tue, 1 Jul 2025 15:41:51 +0000 Subject: [PATCH 13/26] Fix bugs in ZenFlow + ZeRO Stage 1 and gradient reduction logic - Fixed the invocation of `reduce_gradients` in ZenFlow + ZeRO Stage 1 - Corrected the reduction logic in `extra_large_grad_reduc` to handle gradient aggregation properly - Fixed a bug where ZenFlow could not initialize if the user did not provide a dataset Signed-off-by: Yusen Wu --- deepspeed/ops/adam/zenflow_cpu_adam.py | 4 ++-- deepspeed/runtime/engine.py | 2 ++ deepspeed/runtime/zenflow/zenflow_stage_1_and_2.py | 11 +++++++---- deepspeed/runtime/zenflow/zenflow_utils.py | 7 +++++-- deepspeed/runtime/zero/stage_1_and_2.py | 10 ++++++++-- 5 files changed, 24 insertions(+), 10
deletions(-) diff --git a/deepspeed/ops/adam/zenflow_cpu_adam.py b/deepspeed/ops/adam/zenflow_cpu_adam.py index b4774bf4289d..27f0206da8e0 100644 --- a/deepspeed/ops/adam/zenflow_cpu_adam.py +++ b/deepspeed/ops/adam/zenflow_cpu_adam.py @@ -13,10 +13,10 @@ def __init__(self, *args, overlap_step=False, **kwargs): super(ZenFlowCPUAdam, self).__init__(*args, **kwargs) self.overlap_step = overlap_step if not self.overlap_step: - print("ZenFlowCPUAdam initialized with overlap step.") + print("ZenFlowCPUAdam initialized with normal step.") self.step = self._sequential_step else: - print("ZenFlowCPUAdam initialized with normal step.") + print("ZenFlowCPUAdam initialized with overlap step.") self.step = self._parallel_step @torch.no_grad() diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 0287dd16cd93..1455f6e7aa77 100755 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -2183,6 +2183,8 @@ def allreduce_gradients(self, bucket_size=MEMORY_OPT_ALLREDUCE_SIZE): else: grads = None self.buffered_allreduce_fallback(grads=grads, elements_per_buffer=bucket_size) + elif self.zenflow: + self.optimizer.reduce_gradients(pipeline_parallel=self.pipeline_parallelism) def _backward_prologue(self, loss, scale_wrt_gas=True): see_memory_usage("Engine before backward", force=self.memory_breakdown()) diff --git a/deepspeed/runtime/zenflow/zenflow_stage_1_and_2.py b/deepspeed/runtime/zenflow/zenflow_stage_1_and_2.py index 25dc44a288ec..0ced06936448 100644 --- a/deepspeed/runtime/zenflow/zenflow_stage_1_and_2.py +++ b/deepspeed/runtime/zenflow/zenflow_stage_1_and_2.py @@ -142,7 +142,10 @@ def _configure_zenflow(self, zenflow_config): self.update_interval = int(zenflow_config.update_interval) if self.select_strategy == 'epoch': - self.select_interval = self.select_interval * zenflow_config.steps_per_epoch + if zenflow_config.steps_per_epoch is not None: + self.select_interval = self.select_interval * zenflow_config.steps_per_epoch + else: + self.select_interval = 0 if not self.auto_update and self.select_interval != 0 and self.select_interval < self.update_interval: raise ValueError("Select interval must be greater or equal to update interval") @@ -297,16 +300,16 @@ def _process_selected_fp32_groups_grad(self, tensor, total_size, communication_d rank = dist.get_rank(process_group) self.grad_buffer = torch.empty(total_size, dtype=self.dtype, device='cuda') - + + bucket = self.ipg_buckets[communication_data_type] if self.auto_update: - self.sum_buffer = torch.empty(len(self.params_in_ipg_bucket) + dist.get_world_size(group=process_group), + self.sum_buffer = torch.empty(len(bucket.params) + dist.get_world_size(group=process_group), dtype=torch.bfloat16, device='cuda') group_to_paramlist = {} # count = 0 - bucket = self.ipg_buckets[communication_data_type] for i, param_idx_in_group, param_id in bucket.params: param = self.bit16_groups[i][param_idx_in_group] diff --git a/deepspeed/runtime/zenflow/zenflow_utils.py b/deepspeed/runtime/zenflow/zenflow_utils.py index 4be8e6f2f894..cb518a6118e1 100644 --- a/deepspeed/runtime/zenflow/zenflow_utils.py +++ b/deepspeed/runtime/zenflow/zenflow_utils.py @@ -75,8 +75,11 @@ def configure_zenflow(engine: "DeepSpeedEngine") -> None: engine.update_interval = int(zenflow_config.update_interval) if select_strategy == 'epoch': - zenflow_config.steps_per_epoch = len(engine.training_dataloader) - engine.select_interval = engine.select_interval * len(engine.training_dataloader) + if engine.training_dataloader is not None: + 
zenflow_config.steps_per_epoch = len(engine.training_dataloader) + engine.select_interval = engine.select_interval * len(engine.training_dataloader) + else: + engine.select_interval = 0 if not engine.auto_update and engine.select_interval != 0 and engine.select_interval < engine.update_interval: raise ValueError("Select interval must be greater or equal to update interval") diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py index e2b5c41fcd64..78732798d60c 100755 --- a/deepspeed/runtime/zero/stage_1_and_2.py +++ b/deepspeed/runtime/zero/stage_1_and_2.py @@ -1456,7 +1456,13 @@ def reduce_ipg_grads(self): ) == param_id, "param in ipg bucket does not match extra-large param" extra_large_grad_reduc = self.get_gradient_for_reduction( self.extra_large_param_to_reduce[comm_dtype]) - self.average_tensor(extra_large_grad_reduc.view(-1), comm_dtype) + + extra_large_grad_reduc_for_average = extra_large_grad_reduc.view(-1) if not self.zenflow \ + else extra_large_grad_reduc.permute(*reversed(range(extra_large_grad_reduc.ndim))).contiguous().view(-1) + extra_large_grad_reduc.data = extra_large_grad_reduc_for_average.data.view_as(extra_large_grad_reduc) if (not self.zenflow or self.extra_large_param_to_reduce.dim() == 1) \ + else extra_large_grad_reduc_for_average.data.view_as(extra_large_grad_reduc.transpose(0, 1)) + + self.average_tensor(extra_large_grad_reduc_for_average, comm_dtype) del self.extra_large_param_to_reduce[comm_dtype] else: self.average_tensor(bucket.buffer[bucket.index].narrow(0, 0, bucket.elements), comm_dtype) @@ -1510,7 +1516,7 @@ def process_gradients(self, param, i): self.reduce_ready_partitions_and_remove_grads(param, i) def reduce_ready_partitions_and_remove_grads(self, param, i): - if self.partition_gradients or self.is_gradient_accumulation_boundary: + if self.partition_gradients or self.is_gradient_accumulation_boundary or self.zenflow: self.reduce_independent_p_g_buckets_and_remove_grads(param, i) def zero_reduced_gradients(self, partition_id, i): From 3638d7898189011e97fd67e64d81147418e060f8 Mon Sep 17 00:00:00 2001 From: Yusen Wu Date: Tue, 1 Jul 2025 15:45:19 +0000 Subject: [PATCH 14/26] Add unit tests for ZenFlow with ZeRO Stage 1 and 2 - Implemented single-GPU and distributed tests for ZenFlow with ZeRO Stage 1 and 2 - Covered various configurations of selective optimizer offloading, selection strategies (auto/step/epoch), update intervals, and warm-up rounds - Ensured ZenFlow can initialize and train under different parameter combinations Signed-off-by: Yusen Wu --- tests/unit/runtime/zenflow/test_zf.py | 113 ++++++++++++++++++++++++++ 1 file changed, 113 insertions(+) create mode 100644 tests/unit/runtime/zenflow/test_zf.py diff --git a/tests/unit/runtime/zenflow/test_zf.py b/tests/unit/runtime/zenflow/test_zf.py new file mode 100644 index 000000000000..2fde37084d32 --- /dev/null +++ b/tests/unit/runtime/zenflow/test_zf.py @@ -0,0 +1,113 @@ +# Copyright (c) Microsoft Corporation. 
+# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch +import pytest +import torch.distributed as dist +from deepspeed.accelerator import get_accelerator + +from unit.common import DistributedTest +from unit.simple_model import SimpleModel, random_dataset, random_dataloader +import deepspeed + +class BaseZenFlowTest: + hidden_dim = 10 + batch_size = 4 + grad_acc_steps = 1 + + def get_config_dict(self, stage, offload_selective_optimizer, select_strategy, select_interval, update_interval, full_warm_up_rounds): + config = { + "train_batch_size": self.batch_size, + "gradient_accumulation_steps": self.grad_acc_steps, + "steps_per_print": 1, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1e-4 + } + }, + "zero_optimization": { + "stage": stage, + "offload_optimizer": {"device": "cpu"}, + "overlap_comm": True, + "zenflow": { + "topk_ratio": 0.2, + "select_strategy": select_strategy, + "select_interval": select_interval, + "update_interval": update_interval, + "overlap_step": False, + "offload": offload_selective_optimizer, + "auto_ratio": 0.99, + "full_warm_up_rounds": full_warm_up_rounds, + } + }, + "zero_allow_untested_optimizer": True, + } + + if get_accelerator().is_bf16_supported(): + config["bf16"] = {"enabled": True} + return config + + def run_training(self, config_dict): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model = SimpleModel(self.hidden_dim).to(device) + train_dataset = random_dataset(total_samples=20, + hidden_dim=self.hidden_dim, + device=torch.device("cpu")) + model, optimizer, train_dataloader, _ = deepspeed.initialize(model=model, + model_parameters=model.parameters(), + config=config_dict, + training_data=train_dataset,) + + dist.barrier() + for step, batch in enumerate(train_dataloader): + inputs, labels = batch[0].to(device), batch[1].to(device) + loss = model(inputs, labels) + model.backward(loss) + model.step() + + model.destroy() + + def run_training_distributed(self, config_dict): + + model = SimpleModel(self.hidden_dim) + model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict) + train_dataloader = random_dataloader(model=model, total_samples=20, hidden_dim=self.hidden_dim, device=model.device) + + dist.barrier() + + for step, batch in enumerate(train_dataloader): + loss = model(batch[0], batch[1]) + model.backward(loss) + model.step() + model.destroy() + +@pytest.mark.parametrize("stage", [1, 2]) +@pytest.mark.parametrize("full_warm_up_rounds", [0, 3]) +@pytest.mark.parametrize("offload_selective_optimizer", [True, False]) +@pytest.mark.parametrize("select_strategy,select_interval,update_interval", [ + ("auto", "auto", "auto"), + ("step", 10, 3), + ("epoch", 1, 4), +]) +def test_zenflow_single_gpu(stage, offload_selective_optimizer, select_strategy, select_interval, update_interval, full_warm_up_rounds): + tester = BaseZenFlowTest() + config_dict = tester.get_config_dict(stage, offload_selective_optimizer, select_strategy, select_interval, update_interval, full_warm_up_rounds) + tester.run_training(config_dict) + +@pytest.mark.parametrize("stage", [1, 2]) +@pytest.mark.parametrize("full_warm_up_rounds", [0, 3]) +@pytest.mark.parametrize("offload_selective_optimizer", [True, False]) +@pytest.mark.parametrize("select_strategy,select_interval,update_interval", [ + ("auto", "auto", "auto"), + ("step", 10, 3), + ("epoch", 1, 4), +]) +class TestZenFlowDistributed(DistributedTest, BaseZenFlowTest): + world_size = 2 + + def test_zenflow_distributed(self, stage, 
offload_selective_optimizer, select_strategy, select_interval, update_interval, full_warm_up_rounds): + config_dict = self.get_config_dict(stage, offload_selective_optimizer, select_strategy, select_interval, update_interval, full_warm_up_rounds) + self.run_training_distributed(config_dict) \ No newline at end of file From 6d6833020a8e4c7bc5897906b007a535eab36733 Mon Sep 17 00:00:00 2001 From: Tingfeng Lan Date: Tue, 1 Jul 2025 21:12:08 -0400 Subject: [PATCH 15/26] Refactor ZenFlow integration using a separate engine file - Moved ZenFlow configuration logic to a new `engine.py` file, enhancing modularity. - Introduced new utility functions for ZenFlow: `configure_zenflow`, `is_zenflow_update_boundary`, `zenflow_step`, and `sync_zenflow_optimizer_lr`. - Removed redundant ZenFlow logic from `zenflow_utils.py`. Signed-off-by: Tingfeng Lan Co-authored-by: Yusen Wu --- deepspeed/runtime/engine.py | 52 ++------ deepspeed/runtime/zenflow/engine.py | 146 +++++++++++++++++++++ deepspeed/runtime/zenflow/zenflow_utils.py | 49 ------- 3 files changed, 157 insertions(+), 90 deletions(-) create mode 100644 deepspeed/runtime/zenflow/engine.py diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 1455f6e7aa77..75fd11c2da84 100755 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -32,7 +32,8 @@ from deepspeed.runtime.zero.utils import is_zero_supported_optimizer, ZeRORuntimeException from deepspeed.runtime.zero.parameter_offload import DeepSpeedZeRoOffload from deepspeed.runtime.zero.config import ZERO_OPTIMIZATION -from deepspeed.runtime.zenflow.zenflow_utils import configure_zenflow +from deepspeed.runtime.zenflow.engine import (configure_zenflow, zenflow_step, is_zenflow_update_boundary, + sync_zenflow_optimizer_lr) from deepspeed.runtime.fp16.fused_optimizer import FP16_Optimizer from deepspeed.runtime.fp16.unfused_optimizer import FP16_UnfusedOptimizer @@ -332,7 +333,12 @@ def __init__(self, if self.torch_autocast_enabled(): init_autocast_params(self, self.torch_autocast_dtype(), self.torch_autocast_lower_precision_safe_modules()) + self._configure_zenflow = lambda: configure_zenflow(self) + self._is_zenflow_update_boundary = lambda: is_zenflow_update_boundary(self) + self._zenflow_step = lambda lr_kwargs: zenflow_step(self, lr_kwargs) + self._sync_zenflow_optimizer_lr = lambda: sync_zenflow_optimizer_lr(self) + self._configure_zenflow() if has_optimizer: @@ -2302,21 +2308,10 @@ def is_gradient_accumulation_boundary(self): """ if self._is_gradient_accumulation_boundary is None: - if not self.zenflow: - return (self.micro_steps + 1) % \ - self.gradient_accumulation_steps() == 0 - elif not self.auto_update: - if (self.micro_steps + 1) < self.full_warm_up_rounds: - return True - else: - return ((self.micro_steps + 1 - self.full_warm_up_rounds) % self.gradient_accumulation_steps() == 0) \ - or (self.select_interval != 0 and (self.micro_steps + 1) % self.select_interval == 0) + if self.zenflow: + return self._is_zenflow_update_boundary() else: - if (self.micro_steps + 1) <= self.full_warm_up_rounds: - return True - else: - return self.optimizer.zenflow_need_update[self.optimizer.zenflow_state ^ 1] \ - or (self.select_interval != 0 and (self.micro_steps + 1) % self.select_interval == 0) + return (self.micro_steps + 1) % self.gradient_accumulation_steps() == 0 else: return self._is_gradient_accumulation_boundary @@ -2420,22 +2415,6 @@ def _take_model_step(self, lr_kwargs, block_eigenvalue={}): self.global_steps += 1 self.global_samples += self.train_batch_size() -
def _take_selective_parameter_step(self): - self.optimizer.selective_optimizer_step() - - def _take_lr_scheduler_step(self, lr_kwargs): - if self.lr_scheduler is not None: - try: - self.lr_scheduler.step(**(lr_kwargs or {})) - except TypeError: - # XXX Hack to work with Megatron 2.0 and DeepSpeed pipelines. - # We don't currently have a way to specify lr_kwargs from - # pipe_engine.train_batch() - self.lr_scheduler.step(self.train_batch_size()) - - def _log_selective_optimizer_timers(self): - self.optimizer.log_selective_optimizer_timers() - def step(self, lr_kwargs=None): r"""Execute the weight update step after forward and backward propagation on effective_train_batch. @@ -2489,16 +2468,7 @@ def step(self, lr_kwargs=None): report_progress = self.global_rank == 0 if self.global_rank else True if self.zenflow: - if not self.is_gradient_accumulation_boundary(): - self._take_lr_scheduler_step(lr_kwargs) - self._log_selective_optimizer_timers() - else: - if self.micro_steps + 1 >= self.full_warm_up_rounds: - self._take_selective_parameter_step() - if self.auto_update: - if dist.get_rank() == 0: - print(f"Zenflow: This is an update iter. update_interval: {self.update_interval}") - self.update_interval = 0 + self._zenflow_step(lr_kwargs) self.tput_timer.stop(global_step=self.is_gradient_accumulation_boundary(), report_speed=report_progress) diff --git a/deepspeed/runtime/zenflow/engine.py b/deepspeed/runtime/zenflow/engine.py new file mode 100644 index 000000000000..d45c87b2dd55 --- /dev/null +++ b/deepspeed/runtime/zenflow/engine.py @@ -0,0 +1,146 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from deepspeed import comm as dist +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from deepspeed.runtime.engine import DeepSpeedEngine + + +def configure_zenflow(engine: "DeepSpeedEngine") -> None: + """Configure ZenFlow-related scheduling parameters on the engine. + + This function initializes ZenFlow flags (e.g., `zenflow`, `auto_update`, + `select_interval`, etc.) based on the `zenflow_config` object. It handles + selection/update strategy resolution and performs basic validation. + + Args: + engine (DeepSpeedEngine): The DeepSpeed engine to configure. + """ + zenflow_config = engine.zenflow_config() + if zenflow_config == None: + engine.zenflow = False + return + + engine.zenflow = True + select_strategy = zenflow_config.select_strategy + + if select_strategy == 'auto': + select_strategy = "epoch" + if isinstance(zenflow_config.select_interval, int): + raise Warning( + "If use auto select strategy, select_interval will be set to 1 and select_strategy will be set to epoch, thus select_interval would be overwritten." 
+ ) + engine.select_interval = 1 + else: + if isinstance(zenflow_config.select_interval, str): + raise ValueError("If don't use auto select strategy, select_interval must be a number.") + engine.select_interval = zenflow_config.select_interval + + if isinstance(zenflow_config.update_interval, str): + engine.auto_update = True + engine.update_interval = 0 + else: + engine.auto_update = False + engine.update_interval = int(zenflow_config.update_interval) + + if select_strategy == 'epoch': + if engine.training_dataloader is not None: + zenflow_config.steps_per_epoch = len(engine.training_dataloader) + engine.select_interval = engine.select_interval * len(engine.training_dataloader) + else: + engine.select_interval = 0 + + if not engine.auto_update and engine.select_interval != 0 and engine.select_interval < engine.update_interval: + raise ValueError("Select interval must be greater or equal to update interval") + + engine.overlap_step = zenflow_config.overlap_step + + engine.full_warm_up_rounds = zenflow_config.full_warm_up_rounds + + engine._config.gradient_accumulation_steps = engine.update_interval + + +def is_zenflow_update_boundary(engine: "DeepSpeedEngine"): + """Determine whether the current step is an update boundary for ZenFlow. + + This function checks whether the engine should trigger an optimizer update + based on gradient accumulation, warmup phase, and selection/update intervals. + + Returns: + bool: True if this step is an update boundary, otherwise False. + """ + if engine.auto_update: + if (engine.micro_steps + 1) <= engine.full_warm_up_rounds: + return True + return (engine.optimizer.zenflow_need_update[engine.optimizer.zenflow_state ^ 1] + or (engine.select_interval != 0 and (engine.micro_steps + 1) % engine.select_interval == 0)) + else: + if (engine.micro_steps + 1) < engine.full_warm_up_rounds: + return True + return ((engine.micro_steps + 1 - engine.full_warm_up_rounds) % engine.gradient_accumulation_steps() == 0 + or (engine.select_interval != 0 and (engine.micro_steps + 1) % engine.select_interval == 0)) + + +def zenflow_step(engine: "DeepSpeedEngine", lr_kwargs): + """Main step logic for ZenFlow update scheduling. + + This function performs either: + - a selective optimizer update (if at accumulation boundary), + - or just a learning rate scheduler step and logging (if at accumulation iteration). + + Args: + engine (DeepSpeedEngine): The engine managing training state. + lr_kwargs (dict): Optional kwargs passed to the LR scheduler step. + """ + if engine.is_gradient_accumulation_boundary(): + if engine.micro_steps + 1 >= engine.full_warm_up_rounds: + _take_selective_parameter_step(engine) + if engine.auto_update: + if dist.get_rank() == 0: + print(f"Zenflow: This is an update iter. update_interval: {engine.update_interval}") + engine.update_interval = 0 + else: + _take_lr_scheduler_step(engine, lr_kwargs) + _log_selective_optimizer_timers(engine) + + +def _take_selective_parameter_step(engine: "DeepSpeedEngine"): + """ + Trigger a step on the selective optimizer. + """ + engine.optimizer.selective_optimizer_step() + + +def _take_lr_scheduler_step(engine: "DeepSpeedEngine", lr_kwargs): + """ + Take a step on the learning rate scheduler. + """ + if engine.lr_scheduler is not None: + try: + engine.lr_scheduler.step(**(lr_kwargs or {})) + except TypeError: + # XXX Hack to work with Megatron 2.0 and DeepSpeed pipelines. 
+ # We don't currently have a way to specify lr_kwargs from + # pipe_engine.train_batch() + engine.lr_scheduler.step(engine.train_batch_size()) + + +def _log_selective_optimizer_timers(engine): + """ + Log the selective optimizer timers. + """ + engine.optimizer.log_selective_optimizer_timers() + + +def sync_zenflow_optimizer_lr(engine: "DeepSpeedEngine"): + """ + Synchronize the learning rate of the selective optimizer. + If auto_update is enabled, increment the update interval. + """ + engine.optimizer._sync_selective_optimizer_lr() + if engine.auto_update: + engine.update_interval += 1 diff --git a/deepspeed/runtime/zenflow/zenflow_utils.py b/deepspeed/runtime/zenflow/zenflow_utils.py index cb518a6118e1..a544aa3531bd 100644 --- a/deepspeed/runtime/zenflow/zenflow_utils.py +++ b/deepspeed/runtime/zenflow/zenflow_utils.py @@ -4,10 +4,6 @@ # DeepSpeed Team import torch -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from deepspeed.runtime.engine import DeepSpeedEngine def _flatten_dense_tensors(tensors): @@ -44,48 +40,3 @@ def _unflatten_dense_tensors(flat, tensors): transposed_tensors = [t.transpose(0, 1) if t.dim() == 2 else t for t in tensors] unflat = torch._C._nn.unflatten_dense_tensors(flat, transposed_tensors) return [t.transpose(0, 1) if t.dim() == 2 else t for t in unflat] - - -def configure_zenflow(engine: "DeepSpeedEngine") -> None: - zenflow_config = engine.zenflow_config() - if zenflow_config == None: - engine.zenflow = False - return - - engine.zenflow = True - select_strategy = zenflow_config.select_strategy - - if select_strategy == 'auto': - select_strategy = "epoch" - if isinstance(zenflow_config.select_interval, int): - raise Warning( - "If use auto select strategy, select_interval will be set to 1 and select_strategy will be set to epoch, thus select_interval would be overwritten." 
- ) - engine.select_interval = 1 - else: - if isinstance(zenflow_config.select_interval, str): - raise ValueError("If don't use auto select strategy, select_interval must be a number.") - engine.select_interval = zenflow_config.select_interval - - if isinstance(zenflow_config.update_interval, str): - engine.auto_update = True - engine.update_interval = 0 - else: - engine.auto_update = False - engine.update_interval = int(zenflow_config.update_interval) - - if select_strategy == 'epoch': - if engine.training_dataloader is not None: - zenflow_config.steps_per_epoch = len(engine.training_dataloader) - engine.select_interval = engine.select_interval * len(engine.training_dataloader) - else: - engine.select_interval = 0 - - if not engine.auto_update and engine.select_interval != 0 and engine.select_interval < engine.update_interval: - raise ValueError("Select interval must be greater or equal to update interval") - - engine.overlap_step = zenflow_config.overlap_step - - engine.full_warm_up_rounds = zenflow_config.full_warm_up_rounds - - engine._config.gradient_accumulation_steps = engine.update_interval From 913f9a7daf0c0246397e526177a5156827019937 Mon Sep 17 00:00:00 2001 From: Tingfeng Lan Date: Tue, 1 Jul 2025 23:43:12 -0400 Subject: [PATCH 16/26] Fix missing `[comm_dtype]` and format code Signed-off-by: Tingfeng Lan --- deepspeed/runtime/zenflow/zenflow_stage_1_and_2.py | 2 +- deepspeed/runtime/zero/stage_1_and_2.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/deepspeed/runtime/zenflow/zenflow_stage_1_and_2.py b/deepspeed/runtime/zenflow/zenflow_stage_1_and_2.py index 0ced06936448..7ac092a86a20 100644 --- a/deepspeed/runtime/zenflow/zenflow_stage_1_and_2.py +++ b/deepspeed/runtime/zenflow/zenflow_stage_1_and_2.py @@ -300,7 +300,7 @@ def _process_selected_fp32_groups_grad(self, tensor, total_size, communication_d rank = dist.get_rank(process_group) self.grad_buffer = torch.empty(total_size, dtype=self.dtype, device='cuda') - + bucket = self.ipg_buckets[communication_data_type] if self.auto_update: self.sum_buffer = torch.empty(len(bucket.params) + dist.get_world_size(group=process_group), diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py index 78732798d60c..78937bacfa7a 100755 --- a/deepspeed/runtime/zero/stage_1_and_2.py +++ b/deepspeed/runtime/zero/stage_1_and_2.py @@ -1456,12 +1456,12 @@ def reduce_ipg_grads(self): ) == param_id, "param in ipg bucket does not match extra-large param" extra_large_grad_reduc = self.get_gradient_for_reduction( self.extra_large_param_to_reduce[comm_dtype]) - + extra_large_grad_reduc_for_average = extra_large_grad_reduc.view(-1) if not self.zenflow \ else extra_large_grad_reduc.permute(*reversed(range(extra_large_grad_reduc.ndim))).contiguous().view(-1) - extra_large_grad_reduc.data = extra_large_grad_reduc_for_average.data.view_as(extra_large_grad_reduc) if (not self.zenflow or self.extra_large_param_to_reduce.dim() == 1) \ + extra_large_grad_reduc.data = extra_large_grad_reduc_for_average.data.view_as(extra_large_grad_reduc) if (not self.zenflow or self.extra_large_param_to_reduce[comm_dtype].dim() == 1) \ else extra_large_grad_reduc_for_average.data.view_as(extra_large_grad_reduc.transpose(0, 1)) - + self.average_tensor(extra_large_grad_reduc_for_average, comm_dtype) del self.extra_large_param_to_reduce[comm_dtype] else: From bce0a7f828b8827671e9680aee5f7002a07ec6df Mon Sep 17 00:00:00 2001 From: Tingfeng Lan Date: Thu, 3 Jul 2025 00:33:17 -0400 Subject: [PATCH 17/26] Update CPUADAM 
core range calculation in zenflow_stage_1_and_2.py - Adjusted CPUADAM_CORE_START and CPUADAM_CORE_END to dynamically use the total number of CPU cores available. Signed-off-by: Tingfeng Lan Co-authored-by: Yusen Wu --- deepspeed/runtime/zenflow/zenflow_stage_1_and_2.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/deepspeed/runtime/zenflow/zenflow_stage_1_and_2.py b/deepspeed/runtime/zenflow/zenflow_stage_1_and_2.py index 7ac092a86a20..ed3239e801c9 100644 --- a/deepspeed/runtime/zenflow/zenflow_stage_1_and_2.py +++ b/deepspeed/runtime/zenflow/zenflow_stage_1_and_2.py @@ -656,8 +656,9 @@ def zenflow_optimizer_process(pipe, curr_rank, total_rank, param_groups, shared_ os.environ["CUDA_VISIBLE_DEVICES"] = "" disable_accelerator() - CPUADAM_CORE_START = 65 - CPUADAM_CORE_END = 112 + TOTAL_CORES = os.cpu_count() + CPUADAM_CORE_START = 0 + CPUADAM_CORE_END = TOTAL_CORES TOTAL_CORES = CPUADAM_CORE_END - CPUADAM_CORE_START cores_per_rank = TOTAL_CORES // total_rank From 0ef3fafb8d88e8db3e7eaa8b948e75a9db723296 Mon Sep 17 00:00:00 2001 From: Yusen Wu Date: Wed, 9 Jul 2025 04:23:11 +0000 Subject: [PATCH 18/26] Fix bugs in ZenFlow unit tests - Changed backend for single-GPU tests to avoid triggering mpi4py initialization - Prevents test crashes in environments without MPI installed Signed-off-by: Yusen Wu --- tests/unit/runtime/zenflow/test_zf.py | 61 ++++++++++++--------------- 1 file changed, 28 insertions(+), 33 deletions(-) diff --git a/tests/unit/runtime/zenflow/test_zf.py b/tests/unit/runtime/zenflow/test_zf.py index 2fde37084d32..f131da8bf924 100644 --- a/tests/unit/runtime/zenflow/test_zf.py +++ b/tests/unit/runtime/zenflow/test_zf.py @@ -3,21 +3,22 @@ # DeepSpeed Team -import torch import pytest -import torch.distributed as dist +import deepspeed.comm as dist from deepspeed.accelerator import get_accelerator from unit.common import DistributedTest -from unit.simple_model import SimpleModel, random_dataset, random_dataloader +from unit.simple_model import SimpleModel, random_dataloader import deepspeed + class BaseZenFlowTest: hidden_dim = 10 batch_size = 4 grad_acc_steps = 1 - def get_config_dict(self, stage, offload_selective_optimizer, select_strategy, select_interval, update_interval, full_warm_up_rounds): + def get_config_dict(self, stage, offload_selective_optimizer, select_strategy, select_interval, update_interval, + full_warm_up_rounds): config = { "train_batch_size": self.batch_size, "gradient_accumulation_steps": self.grad_acc_steps, @@ -30,7 +31,9 @@ def get_config_dict(self, stage, offload_selective_optimizer, select_strategy, s }, "zero_optimization": { "stage": stage, - "offload_optimizer": {"device": "cpu"}, + "offload_optimizer": { + "device": "cpu" + }, "overlap_comm": True, "zenflow": { "topk_ratio": 0.2, @@ -50,31 +53,14 @@ def get_config_dict(self, stage, offload_selective_optimizer, select_strategy, s config["bf16"] = {"enabled": True} return config - def run_training(self, config_dict): - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - model = SimpleModel(self.hidden_dim).to(device) - train_dataset = random_dataset(total_samples=20, - hidden_dim=self.hidden_dim, - device=torch.device("cpu")) - model, optimizer, train_dataloader, _ = deepspeed.initialize(model=model, - model_parameters=model.parameters(), - config=config_dict, - training_data=train_dataset,) - - dist.barrier() - for step, batch in enumerate(train_dataloader): - inputs, labels = batch[0].to(device), batch[1].to(device) - loss = model(inputs, labels) - 
model.backward(loss) - model.step() - - model.destroy() - def run_training_distributed(self, config_dict): model = SimpleModel(self.hidden_dim) model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict) - train_dataloader = random_dataloader(model=model, total_samples=20, hidden_dim=self.hidden_dim, device=model.device) + train_dataloader = random_dataloader(model=model, + total_samples=20, + hidden_dim=self.hidden_dim, + device=model.device) dist.barrier() @@ -84,6 +70,7 @@ def run_training_distributed(self, config_dict): model.step() model.destroy() + @pytest.mark.parametrize("stage", [1, 2]) @pytest.mark.parametrize("full_warm_up_rounds", [0, 3]) @pytest.mark.parametrize("offload_selective_optimizer", [True, False]) @@ -92,10 +79,16 @@ def run_training_distributed(self, config_dict): ("step", 10, 3), ("epoch", 1, 4), ]) -def test_zenflow_single_gpu(stage, offload_selective_optimizer, select_strategy, select_interval, update_interval, full_warm_up_rounds): - tester = BaseZenFlowTest() - config_dict = tester.get_config_dict(stage, offload_selective_optimizer, select_strategy, select_interval, update_interval, full_warm_up_rounds) - tester.run_training(config_dict) +class TestZenFlowSingleGPU(DistributedTest, BaseZenFlowTest): + world_size = 1 + + def test_zenflow_single_gpu(self, stage, offload_selective_optimizer, select_strategy, select_interval, + update_interval, full_warm_up_rounds): + tester = BaseZenFlowTest() + config_dict = tester.get_config_dict(stage, offload_selective_optimizer, select_strategy, select_interval, + update_interval, full_warm_up_rounds) + tester.run_training_distributed(config_dict) + @pytest.mark.parametrize("stage", [1, 2]) @pytest.mark.parametrize("full_warm_up_rounds", [0, 3]) @@ -108,6 +101,8 @@ def test_zenflow_single_gpu(stage, offload_selective_optimizer, select_strategy, class TestZenFlowDistributed(DistributedTest, BaseZenFlowTest): world_size = 2 - def test_zenflow_distributed(self, stage, offload_selective_optimizer, select_strategy, select_interval, update_interval, full_warm_up_rounds): - config_dict = self.get_config_dict(stage, offload_selective_optimizer, select_strategy, select_interval, update_interval, full_warm_up_rounds) - self.run_training_distributed(config_dict) \ No newline at end of file + def test_zenflow_distributed(self, stage, offload_selective_optimizer, select_strategy, select_interval, + update_interval, full_warm_up_rounds): + config_dict = self.get_config_dict(stage, offload_selective_optimizer, select_strategy, select_interval, + update_interval, full_warm_up_rounds) + self.run_training_distributed(config_dict) From 8d6b6f34876fadc88187ad53dfc8c93c474c4ba7 Mon Sep 17 00:00:00 2001 From: Tingfeng Lan Date: Tue, 15 Jul 2025 21:43:14 -0400 Subject: [PATCH 19/26] Fix: Add PyTorch version check for ZenFlow configuration Signed-off-by: Tingfeng Lan --- deepspeed/runtime/zenflow/engine.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/deepspeed/runtime/zenflow/engine.py b/deepspeed/runtime/zenflow/engine.py index d45c87b2dd55..8eaa2914f671 100644 --- a/deepspeed/runtime/zenflow/engine.py +++ b/deepspeed/runtime/zenflow/engine.py @@ -5,6 +5,7 @@ from deepspeed import comm as dist from typing import TYPE_CHECKING +from deepspeed.utils.torch import required_torch_version if TYPE_CHECKING: from deepspeed.runtime.engine import DeepSpeedEngine @@ -24,6 +25,10 @@ def configure_zenflow(engine: "DeepSpeedEngine") -> None: if zenflow_config == None: engine.zenflow = False return + if 
not required_torch_version(min_version=2.0): + raise ValueError( + "Please use PyTorch 2.0 or later to enable ZenFlow. Alternatively, omit `zenflow` config in the config file to fall back to the default ZeRO-Offload optimizer." + ) engine.zenflow = True select_strategy = zenflow_config.select_strategy From 891ac093ac5cd10eccc0bef9e1f2d25594f5655b Mon Sep 17 00:00:00 2001 From: Tingfeng Lan Date: Wed, 16 Jul 2025 19:25:11 -0400 Subject: [PATCH 20/26] Enhance ZenFlow compatibility checks for PyTorch version Signed-off-by: Tingfeng Lan --- deepspeed/ops/adam/zenflow_torch_adam.py | 49 +++++++++++++++++++----- deepspeed/runtime/zenflow/engine.py | 4 +- 2 files changed, 41 insertions(+), 12 deletions(-) diff --git a/deepspeed/ops/adam/zenflow_torch_adam.py b/deepspeed/ops/adam/zenflow_torch_adam.py index af30698a0327..a915c8e13d09 100644 --- a/deepspeed/ops/adam/zenflow_torch_adam.py +++ b/deepspeed/ops/adam/zenflow_torch_adam.py @@ -7,21 +7,45 @@ from typing import cast, List, Optional, Tuple, Union from torch import Tensor -from torch.optim.optimizer import ( - _default_to_fused_or_foreach, - _disable_dynamo_if_unsupported, - _get_capturable_supported_devices, - _get_value, - _stack_if_compiling, - _view_as_real, - DeviceDict, - Optimizer, -) +from deepspeed.utils.torch import required_torch_version + +# Check if we have PyTorch >= 2.0 for ZenFlow features +_ZENFLOW_AVAILABLE = required_torch_version(min_version=2.1) + +if _ZENFLOW_AVAILABLE: + try: + from torch.optim.optimizer import ( + _default_to_fused_or_foreach, + _disable_dynamo_if_unsupported, + _get_capturable_supported_devices, + _get_value, + _stack_if_compiling, + _view_as_real, + DeviceDict, + Optimizer, + ) + except ImportError as e: + print(f"[WARNING] ZenFlow disabled: torch internal optimizer symbols could not be imported: {e}") + _ZENFLOW_AVAILABLE = False +else: + # safe disable dynamo if unsupported + def _disable_dynamo_if_unsupported(**kwargs): + + def wrapper(fn): + return fn + + return wrapper + + _ZENFLOW_AVAILABLE = False class ZenFlowSelectiveAdamW(torch.optim.AdamW): def __init__(self, *args, offload=False, bucket_size=5e8, **kwargs): + if not _ZENFLOW_AVAILABLE: + raise RuntimeError("ZenFlow features are not available with PyTorch < 2.0. " + "Please upgrade to PyTorch 2.0+ to use ZenFlow, or omit 'zenflow' " + "from your DeepSpeed configuration to use the default ZeRO-Offload optimizer.") super(ZenFlowSelectiveAdamW, self).__init__(*args, **kwargs) self.offload = offload @@ -688,6 +712,11 @@ def adamw( See :class:`~torch.optim.AdamW` for details. """ + if not _ZENFLOW_AVAILABLE: + raise RuntimeError("ZenFlow adamw function is not available with PyTorch < 2.0. 
" + "Please upgrade to PyTorch 2.0+ to use ZenFlow, or omit 'zenflow' " + "from your DeepSpeed configuration to use the default ZeRO-Offload optimizer.") + if not torch._utils.is_compiling() and not all(isinstance(t, torch.Tensor) for t in state_steps): raise RuntimeError("API has changed, `state_steps` argument must contain a list of singleton tensors") diff --git a/deepspeed/runtime/zenflow/engine.py b/deepspeed/runtime/zenflow/engine.py index 8eaa2914f671..1d46f7c63e63 100644 --- a/deepspeed/runtime/zenflow/engine.py +++ b/deepspeed/runtime/zenflow/engine.py @@ -25,9 +25,9 @@ def configure_zenflow(engine: "DeepSpeedEngine") -> None: if zenflow_config == None: engine.zenflow = False return - if not required_torch_version(min_version=2.0): + if not required_torch_version(min_version=2.1): raise ValueError( - "Please use PyTorch 2.0 or later to enable ZenFlow. Alternatively, omit `zenflow` config in the config file to fall back to the default ZeRO-Offload optimizer." + "Please use PyTorch 2.1 or later to enable ZenFlow. Alternatively, omit `zenflow` config in the config file to fall back to the default ZeRO-Offload optimizer." ) engine.zenflow = True From d2d1a06e91aaa4e922134699e8780edf6ab53da9 Mon Sep 17 00:00:00 2001 From: Yusen Wu Date: Fri, 1 Aug 2025 10:26:03 -0400 Subject: [PATCH 21/26] Fix bugs in ZenFlow unit tests when using CPU Torch - Fixed errors encountered when running the cpu-torch-latest workflow - Skip ZenFlow tests when device is CPU, as ZenFlow does not support CPU-only Torch Signed-off-by: Yusen Wu --- tests/unit/runtime/zenflow/test_zf.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/unit/runtime/zenflow/test_zf.py b/tests/unit/runtime/zenflow/test_zf.py index f131da8bf924..e38881b765a4 100644 --- a/tests/unit/runtime/zenflow/test_zf.py +++ b/tests/unit/runtime/zenflow/test_zf.py @@ -55,6 +55,9 @@ def get_config_dict(self, stage, offload_selective_optimizer, select_strategy, s def run_training_distributed(self, config_dict): + if get_accelerator().device_name() == "cpu": + return + model = SimpleModel(self.hidden_dim) model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict) train_dataloader = random_dataloader(model=model, From f3b227694e664f0e1f0d056ed97bcdc71727ad2f Mon Sep 17 00:00:00 2001 From: Tingfeng Lan Date: Sat, 2 Aug 2025 13:04:27 -0400 Subject: [PATCH 22/26] Added TODO comments to indicate the need for removing ZenFlow-specific calls from the vanilla ZeroOptimizer. 
Signed-off-by: Tingfeng Lan --- deepspeed/runtime/zero/stage_1_and_2.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py index 78937bacfa7a..1fbea6d9e3a8 100755 --- a/deepspeed/runtime/zero/stage_1_and_2.py +++ b/deepspeed/runtime/zero/stage_1_and_2.py @@ -171,6 +171,7 @@ def __init__(self, self.cpu_offload = False self.cpu_offload_pin_memory = False + # TODO: Remove zenflow-specific call from vanilla ZeroOptimizer, try to isolate zenflow-specific code into sub-class zenflow_zero_optimizer self.zenflow = True if zenflow_config is not None else False if dist.get_rank() == 0: @@ -1939,6 +1940,7 @@ def _optimizer_step(self, group_no): if self.torch_autocast_gradscaler: self.torch_autocast_gradscaler.step(self.optimizer) self.torch_autocast_gradscaler.update() + # TODO: Remove zenflow-specific call from vanilla ZeroOptimizer elif self.zenflow: self.zenflow_cpu_optimizer_step(group_no) else: From bbb6f7441e351bf4970c16abdd6e4d41d7012248 Mon Sep 17 00:00:00 2001 From: Tingfeng Lan Date: Sat, 2 Aug 2025 13:15:04 -0400 Subject: [PATCH 23/26] Fix formatting in test_zf.py Signed-off-by: Tingfeng Lan Co-authored-by: Yusen Wu --- tests/unit/runtime/zenflow/test_zf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/runtime/zenflow/test_zf.py b/tests/unit/runtime/zenflow/test_zf.py index e38881b765a4..57d58d02facb 100644 --- a/tests/unit/runtime/zenflow/test_zf.py +++ b/tests/unit/runtime/zenflow/test_zf.py @@ -57,7 +57,7 @@ def run_training_distributed(self, config_dict): if get_accelerator().device_name() == "cpu": return - + model = SimpleModel(self.hidden_dim) model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict) train_dataloader = random_dataloader(model=model, From 9f4fb58541022743333862a2d5ace691f451a26a Mon Sep 17 00:00:00 2001 From: Tingfeng Lan Date: Sat, 2 Aug 2025 13:14:10 -0400 Subject: [PATCH 24/26] Update docs/_tutorials/zenflow.md Signed-off-by: Olatunji Ruwase Signed-off-by: Tingfeng Lan Co-authored-by: Olatunji Ruwase --- docs/_tutorials/zenflow.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/_tutorials/zenflow.md b/docs/_tutorials/zenflow.md index 9bb24281d4a5..cd645bd03c6f 100644 --- a/docs/_tutorials/zenflow.md +++ b/docs/_tutorials/zenflow.md @@ -3,7 +3,7 @@ title: "ZenFlow" tags: training --- -ZenFlow is an extension of ZeRO-Offload that decouples and asynchronously updates gradients during training. It reduces CPU-induced stalls when using offload optimizers, enabling smoother and faster training. Like ZeRO-Offload, ZenFlow requires no code changes—only configuration updates in your DeepSpeed JSON file. +ZenFlow is an extension of ZeRO-Offload that decouples and asynchronously updates gradients during training. It reduces CPU-induced stalls when using offload optimizers, enabling smoother and faster training. Like ZeRO-Offload, ZenFlow requires no code changes, only configuration updates in your DeepSpeed JSON file. We recommend that you read the tutorials on [Getting Started](/getting-started/) and [ZeRO](/tutorials/zero/) before stepping through this tutorial. ZenFlow builds on top of [ZeRO-Offload](/tutorials/zero-offload/), so shared setup details can be found there. From 938e8a3c728c16b3fd56e5312f820cddcc120215 Mon Sep 17 00:00:00 2001 From: Tingfeng Lan Date: Sun, 10 Aug 2025 14:00:00 -0400 Subject: [PATCH 25/26] Fix copyrights. 
Signed-off-by: Tingfeng Lan Co-authored-by: Olatunji Ruwase --- deepspeed/ops/adam/__init__.py | 2 +- deepspeed/ops/adam/zenflow_cpu_adam.py | 2 +- deepspeed/ops/adam/zenflow_torch_adam.py | 2 +- deepspeed/runtime/zenflow/__init__.py | 2 +- deepspeed/runtime/zenflow/engine.py | 2 +- deepspeed/runtime/zenflow/zenflow_config.py | 2 +- deepspeed/runtime/zenflow/zenflow_stage_1_and_2.py | 2 +- deepspeed/runtime/zenflow/zenflow_utils.py | 2 +- tests/unit/ops/adam/test_zf_torch_adam.py | 2 +- tests/unit/runtime/zenflow/test_zf.py | 2 +- tests/unit/runtime/zenflow/test_zf_config.py | 2 +- 11 files changed, 11 insertions(+), 11 deletions(-) diff --git a/deepspeed/ops/adam/__init__.py b/deepspeed/ops/adam/__init__.py index 82dfa114ac9a..5c657db4f270 100755 --- a/deepspeed/ops/adam/__init__.py +++ b/deepspeed/ops/adam/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) Microsoft Corporation. +# Copyright (c) DeepSpeed Team. # SPDX-License-Identifier: Apache-2.0 # DeepSpeed Team diff --git a/deepspeed/ops/adam/zenflow_cpu_adam.py b/deepspeed/ops/adam/zenflow_cpu_adam.py index 27f0206da8e0..0809d7a0f7e0 100644 --- a/deepspeed/ops/adam/zenflow_cpu_adam.py +++ b/deepspeed/ops/adam/zenflow_cpu_adam.py @@ -1,4 +1,4 @@ -# Copyright (c) Microsoft Corporation. +# Copyright (c) DeepSpeed Team. # SPDX-License-Identifier: Apache-2.0 # DeepSpeed Team diff --git a/deepspeed/ops/adam/zenflow_torch_adam.py b/deepspeed/ops/adam/zenflow_torch_adam.py index a915c8e13d09..e3b415108065 100644 --- a/deepspeed/ops/adam/zenflow_torch_adam.py +++ b/deepspeed/ops/adam/zenflow_torch_adam.py @@ -1,4 +1,4 @@ -# Copyright (c) Microsoft Corporation. +# Copyright (c) DeepSpeed Team. # SPDX-License-Identifier: Apache-2.0 # DeepSpeed Team diff --git a/deepspeed/runtime/zenflow/__init__.py b/deepspeed/runtime/zenflow/__init__.py index 208299fb8c50..6f5f5619004b 100644 --- a/deepspeed/runtime/zenflow/__init__.py +++ b/deepspeed/runtime/zenflow/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) Microsoft Corporation. +# Copyright (c) DeepSpeed Team. # SPDX-License-Identifier: Apache-2.0 # DeepSpeed Team diff --git a/deepspeed/runtime/zenflow/engine.py b/deepspeed/runtime/zenflow/engine.py index 1d46f7c63e63..2236d097169b 100644 --- a/deepspeed/runtime/zenflow/engine.py +++ b/deepspeed/runtime/zenflow/engine.py @@ -1,4 +1,4 @@ -# Copyright (c) Microsoft Corporation. +# Copyright (c) DeepSpeed Team. # SPDX-License-Identifier: Apache-2.0 # DeepSpeed Team diff --git a/deepspeed/runtime/zenflow/zenflow_config.py b/deepspeed/runtime/zenflow/zenflow_config.py index 416607c6b387..a11a7dd7e68e 100644 --- a/deepspeed/runtime/zenflow/zenflow_config.py +++ b/deepspeed/runtime/zenflow/zenflow_config.py @@ -1,4 +1,4 @@ -# Copyright (c) Microsoft Corporation. +# Copyright (c) DeepSpeed Team. # SPDX-License-Identifier: Apache-2.0 # DeepSpeed Team diff --git a/deepspeed/runtime/zenflow/zenflow_stage_1_and_2.py b/deepspeed/runtime/zenflow/zenflow_stage_1_and_2.py index ed3239e801c9..c8fda278868d 100644 --- a/deepspeed/runtime/zenflow/zenflow_stage_1_and_2.py +++ b/deepspeed/runtime/zenflow/zenflow_stage_1_and_2.py @@ -1,4 +1,4 @@ -# Copyright (c) Microsoft Corporation. +# Copyright (c) DeepSpeed Team. # SPDX-License-Identifier: Apache-2.0 # DeepSpeed Team diff --git a/deepspeed/runtime/zenflow/zenflow_utils.py b/deepspeed/runtime/zenflow/zenflow_utils.py index a544aa3531bd..4d2fcaaa4b86 100644 --- a/deepspeed/runtime/zenflow/zenflow_utils.py +++ b/deepspeed/runtime/zenflow/zenflow_utils.py @@ -1,4 +1,4 @@ -# Copyright (c) Microsoft Corporation. 
+# Copyright (c) DeepSpeed Team. # SPDX-License-Identifier: Apache-2.0 # DeepSpeed Team diff --git a/tests/unit/ops/adam/test_zf_torch_adam.py b/tests/unit/ops/adam/test_zf_torch_adam.py index 76915d1bb32c..c7163ffe2f09 100644 --- a/tests/unit/ops/adam/test_zf_torch_adam.py +++ b/tests/unit/ops/adam/test_zf_torch_adam.py @@ -1,4 +1,4 @@ -# Copyright (c) Microsoft Corporation. +# Copyright (c) DeepSpeed Team. # SPDX-License-Identifier: Apache-2.0 # DeepSpeed Team diff --git a/tests/unit/runtime/zenflow/test_zf.py b/tests/unit/runtime/zenflow/test_zf.py index 57d58d02facb..3294902bef67 100644 --- a/tests/unit/runtime/zenflow/test_zf.py +++ b/tests/unit/runtime/zenflow/test_zf.py @@ -1,4 +1,4 @@ -# Copyright (c) Microsoft Corporation. +# Copyright (c) DeepSpeed Team. # SPDX-License-Identifier: Apache-2.0 # DeepSpeed Team diff --git a/tests/unit/runtime/zenflow/test_zf_config.py b/tests/unit/runtime/zenflow/test_zf_config.py index 0ef96e82ce94..647b7f82f2e9 100644 --- a/tests/unit/runtime/zenflow/test_zf_config.py +++ b/tests/unit/runtime/zenflow/test_zf_config.py @@ -1,4 +1,4 @@ -# Copyright (c) Microsoft Corporation. +# Copyright (c) DeepSpeed Team. # SPDX-License-Identifier: Apache-2.0 # DeepSpeed Team From 8951fa0882894c9088ce1e0bab7298654142eab7 Mon Sep 17 00:00:00 2001 From: Tingfeng Lan Date: Sun, 10 Aug 2025 14:34:39 -0400 Subject: [PATCH 26/26] Remove CUDA specific code. Signed-off-by: Tingfeng Lan Co-authored-by: Guokai Ma --- deepspeed/runtime/zenflow/zenflow_stage_1_and_2.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/deepspeed/runtime/zenflow/zenflow_stage_1_and_2.py b/deepspeed/runtime/zenflow/zenflow_stage_1_and_2.py index c8fda278868d..cbb818d8072f 100644 --- a/deepspeed/runtime/zenflow/zenflow_stage_1_and_2.py +++ b/deepspeed/runtime/zenflow/zenflow_stage_1_and_2.py @@ -190,7 +190,7 @@ def update_selected_channels(self, tensor, total_size, communication_data_type): process_group = self.dp_process_group rank = dist.get_rank(process_group) - self.index_buffer = torch.empty(total_size, dtype=torch.int32, device='cuda') + self.index_buffer = torch.empty(total_size, dtype=torch.int32, device=get_accelerator().current_device_name()) # count = 0 bucket = self.ipg_buckets[communication_data_type] @@ -299,13 +299,13 @@ def _process_selected_fp32_groups_grad(self, tensor, total_size, communication_d process_group = self.dp_process_group rank = dist.get_rank(process_group) - self.grad_buffer = torch.empty(total_size, dtype=self.dtype, device='cuda') + self.grad_buffer = torch.empty(total_size, dtype=self.dtype, device=get_accelerator().current_device_name()) bucket = self.ipg_buckets[communication_data_type] if self.auto_update: self.sum_buffer = torch.empty(len(bucket.params) + dist.get_world_size(group=process_group), dtype=torch.bfloat16, - device='cuda') + device=get_accelerator().current_device_name()) group_to_paramlist = {} @@ -653,7 +653,6 @@ def disable_accelerator(): def zenflow_optimizer_process(pipe, curr_rank, total_rank, param_groups, shared_overlap_grad_map, shared_stale_param_map): - os.environ["CUDA_VISIBLE_DEVICES"] = "" disable_accelerator() TOTAL_CORES = os.cpu_count()
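
Usage note: as the tutorial text updated in PATCH 24 states, ZenFlow is enabled through DeepSpeed configuration alone, with no model-code changes. The snippet below is a minimal, illustrative sketch of such a configuration; the top-level "zenflow" block and its "select_strategy"/"overlap_step" fields mirror names that appear in this series (configure_zenflow and ZenFlowCPUAdam), but their exact schema and values are assumptions, and the surrounding training settings are placeholders.

# Hypothetical sketch: enabling ZenFlow purely via the DeepSpeed config.
# The "zenflow" sub-keys and values shown here are illustrative, not the
# authoritative config schema for this patch series.
import torch
import deepspeed

ds_config = {
    "train_batch_size": 8,
    "optimizer": {
        "type": "Adam",
        "params": {"lr": 1e-3}
    },
    "zero_optimization": {
        "stage": 2,
        # ZenFlow extends ZeRO-Offload, so the optimizer is offloaded to CPU.
        "offload_optimizer": {"device": "cpu"}
    },
    "zenflow": {
        "select_strategy": "auto",  # illustrative; configure_zenflow reads zenflow_config.select_strategy
        "overlap_step": True        # illustrative; ZenFlowCPUAdam accepts an overlap_step flag
    }
}

model = torch.nn.Linear(32, 32)  # stand-in for a real model
engine, optimizer, _, _ = deepspeed.initialize(model=model,
                                               model_parameters=model.parameters(),
                                               config=ds_config)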