Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 54 additions & 0 deletions training/DeepSpeed-ZenFlow/benchmark/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# ZenFlow Benchmark Example


Please install DeepSpeed via pip install deepspeed if you haven't already done so.

```bash
pip install -r requirements.txt
```


The script `zf_benchmark.py ` demonstrates how to offload the state of a model. Here is the example usage.

```python
$ deepspeed --num_gpus=4 zf_benchmark.py --hidden_dim 4096 --nlayers 4 --iteration 5 --pin_memory_opts 1 --topk_ratios 0.1 --update_intervals 2 --overlap_steps
...
time (ms) | selective_optimizer_update: 19.20 | selective_optimizer_process: 28.80 | selective_optimizer_sync: 0.05
time (ms) | fwd_microstep: 54.76 | bwd_microstep: 122.95 | bwd_inner_microstep: 12.22 | bwd_allreduce_microstep: 103.64 | step_microstep: 0.34
Step 0 time: 178.66ms
time (ms) | optimizer_allgather: 26.19 | optimizer_gradients: 26.06 | optimizer_step: 128.20
time (ms) | selective_optimizer_update: 0.00 | selective_optimizer_process: 0.57 | selective_optimizer_step: 1.48 | selective_optimizer_sync: 0.00
time (ms) | fwd_microstep: 0.38 | bwd_microstep: 57.88 | bwd_inner_microstep: 1.06 | bwd_allreduce_microstep: 56.50 | step_microstep: 183.27
time (ms) | fwd: 55.15 | bwd: 180.82 | bwd_inner: 13.28 | bwd_allreduce: 160.15 | step: 183.61
Step 1 time: 242.16ms
time (ms) | selective_optimizer_update: 0.00 | selective_optimizer_process: 1.58 | selective_optimizer_step: 0.00 | selective_optimizer_sync: 0.00
time (ms) | fwd_microstep: 0.30 | bwd_microstep: 16.73 | bwd_inner_microstep: 1.39 | bwd_allreduce_microstep: 14.96 | step_microstep: 0.20
Step 2 time: 17.60ms
time (ms) | optimizer_allgather: 0.65 | optimizer_gradients: 16.95 | optimizer_step: 108.45
time (ms) | selective_optimizer_update: 0.00 | selective_optimizer_process: 0.56 | selective_optimizer_step: 1.42 | selective_optimizer_sync: 0.00
time (ms) | fwd_microstep: 0.29 | bwd_microstep: 36.65 | bwd_inner_microstep: 0.95 | bwd_allreduce_microstep: 35.51 | step_microstep: 128.57
time (ms) | fwd: 0.59 | bwd: 53.39 | bwd_inner: 2.33 | bwd_allreduce: 50.48 | step: 128.77
Step 3 time: 166.10ms
time (ms) | selective_optimizer_update: 0.00 | selective_optimizer_process: 1.57 | selective_optimizer_step: 0.00 | selective_optimizer_sync: 0.00
time (ms) | fwd_microstep: 0.31 | bwd_microstep: 15.47 | bwd_inner_microstep: 1.33 | bwd_allreduce_microstep: 13.97 | step_microstep: 0.23
...
[Summary] pin_memory=False topk_ratio=0.1 update_interval=2 overlap_step=False avg_accumulation_step=16.77ms avg_update_step=171.38ms
```

`run_benchmark.sh` shows how to run the script with different configurations. The script outputs the time for offloading and loading the states.

```python
$ ./run_benchmark.sh
...
+---------+--------------+--------------+-------------------+----------------+-------------+------------+-----------+-----------+--------------------------------+
| trial | pin_memory | topk_ratio | update_interval | overlap_step | num_steps | avg_step | avg_bwd | avg_fwd | avg_selective_optimizer_step |
|---------+--------------+--------------+-------------------+----------------+-------------+------------+-----------+-----------+--------------------------------|
| 1 | False | 0.1 | 2 | False | 30 | 24.0153 | 12.8377 | 1.91733 | 0.247 |
| 1 | False | 0.1 | 2 | True | 28 | 805.425 | 22.5604 | 1.96821 | 0.345714 |
| 1 | False | 0.1 | 4 | False | 50 | 14.2108 | 10.9072 | 1.2436 | 0.1484 |
| 1 | False | 0.1 | 4 | True | 48 | 459.326 | 16.0385 | 1.30125 | 0.221667 |
| 1 | False | 0.2 | 2 | False | 30 | 22.6567 | 12.6463 | 2.421 | 0.346 |
| 1 | False | 0.2 | 2 | True | 28 | 817.919 | 22.1079 | 2.06179 | 0.450714 |
| 1 | False | 0.2 | 4 | False | 50 | 14.12 | 9.4714 | 1.1766 | 0.2072 |
| 1 | False | 0.2 | 4 | True | 48 | 471.339 | 15.945 | 1.2675 | 0.262292 |...
```
65 changes: 65 additions & 0 deletions training/DeepSpeed-ZenFlow/benchmark/output_table.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import re
from collections import defaultdict
import pandas as pd
from tabulate import tabulate

def parse_log_file(log_file_path):
with open(log_file_path, 'r') as f:
lines = f.readlines()

# Regex patterns
trial_header_re = re.compile(
r"\[Trial (\d+)] pin_memory=(\d), topk=([\d.]+), update=(\d+), overlap_step=(\d+) \(MASTER_PORT=\d+\)"
)
time_metrics_re = re.compile(r"\|\s*([^:|]+):\s*([\d.]+)")

trials = []
current_config = None
current_step_metrics = []

def finalize_trial():
if current_config and current_step_metrics:
# Get all unique keys
all_keys = set()
for step in current_step_metrics:
all_keys.update(step.keys())
# Aggregate and average
agg = {k: 0.0 for k in all_keys}
for step in current_step_metrics:
for k in all_keys:
agg[k] += step.get(k, 0.0)
avg = {f"avg_{k}": agg[k] / len(current_step_metrics) for k in all_keys}
trials.append({**current_config, **avg, "num_steps": len(current_step_metrics)})

for line in lines:
header_match = trial_header_re.search(line)
if header_match:
finalize_trial()
trial_id, pin_memory, topk, update, overlap = header_match.groups()
current_config = {
"trial": int(trial_id),
"pin_memory": bool(int(pin_memory)),
"topk_ratio": float(topk),
"update_interval": int(update),
"overlap_step": bool(int(overlap))
}
current_step_metrics = []
continue

if "[Rank 0]" in line and "time (ms)" in line:
metrics = {k.strip(): float(v) for k, v in time_metrics_re.findall(line)}
current_step_metrics.append(metrics)

finalize_trial()
return pd.DataFrame(trials)

if __name__ == "__main__":

log_file = "zf_benchmark.log"
df = parse_log_file(log_file)
df = df.sort_values(by=["topk_ratio", "update_interval", "overlap_step", "pin_memory"])
cols_to_display = [
"trial", "topk_ratio", "update_interval", "overlap_step", "pin_memory", "num_steps",
"avg_step", "avg_bwd", "avg_fwd", "avg_selective_optimizer_step"
]
print(tabulate(df[cols_to_display], headers="keys", tablefmt="psql", showindex=False))
7 changes: 7 additions & 0 deletions training/DeepSpeed-ZenFlow/benchmark/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
torch>=2.5.1
deepspeed>=0.16.0
datasets>=2.14.1
transformers>=4.37.2
numpy>=1.21.0
tabulate
pandas
36 changes: 36 additions & 0 deletions training/DeepSpeed-ZenFlow/benchmark/run_benchmark.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
#!/bin/bash

NGPUS=2
HIDDEN_SIZE=4096
NUM_LAYERS=4
TRIALS=1

PIN_MEMORY_OPTS=(0 1)
TOPK_RATIOS=(0.1 0.2)
UPDATE_INTERVALS=(2 4)
OVERLAP_STEPS=(1 0)

for pin_memory in "${PIN_MEMORY_OPTS[@]}"; do
for topk in "${TOPK_RATIOS[@]}"; do
for update in "${UPDATE_INTERVALS[@]}"; do
for overlap in "${OVERLAP_STEPS[@]}"; do
for ((trial=0; trial<$TRIALS; trial++)); do
# Generate a random port between 20000 and 65000
MASTER_PORT=$((20000 + RANDOM % 45000))
echo "[Trial $((trial+1))] pin_memory=$pin_memory, topk=$topk, update=$update, overlap_step=$overlap (MASTER_PORT=$MASTER_PORT)" | tee -a zf_benchmark.log
deepspeed --master_port $MASTER_PORT \
--num_gpus=$NGPUS \
zf_benchmark.py \
--hidden_dim $HIDDEN_SIZE \
--nlayers $NUM_LAYERS \
--iteration 5 \
--pin_memory_opts $pin_memory \
--topk_ratios $topk \
--update_intervals $update \
--overlap_steps $overlap | tee -a zf_benchmark.log
done
done
done
done
done
python output_table.py
150 changes: 150 additions & 0 deletions training/DeepSpeed-ZenFlow/benchmark/zf_benchmark.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import argparse
import torch
import deepspeed.comm as dist
import time

import deepspeed

class SimpleModel(torch.nn.Module):

def __init__(self, hidden_dim, empty_grad=False, nlayers=1):
super(SimpleModel, self).__init__()
self.linears = torch.nn.ModuleList([torch.nn.Linear(hidden_dim, hidden_dim) for _ in range(nlayers)])
if empty_grad:
self.linear2 = torch.nn.Linear(hidden_dim, hidden_dim)
self.cross_entropy_loss = torch.nn.CrossEntropyLoss()

def forward(self, x, y):
for l in self.linears:
x = l(x)
return self.cross_entropy_loss(x, y)


def random_dataset(total_samples, hidden_dim, device, dtype):
train_data = torch.randn(total_samples, hidden_dim, device=device, dtype=dtype)
train_label = torch.empty(total_samples, dtype=torch.long, device=device).random_(hidden_dim)
train_dataset = torch.utils.data.TensorDataset(train_data, train_label)
return train_dataset


def random_dataloader(model, total_samples, hidden_dim, device, dtype):
batch_size = model.train_micro_batch_size_per_gpu()
train_dataset = random_dataset(total_samples, hidden_dim, device, dtype=dtype)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size)
return train_loader


def run_model(model, config_dict, hidden_dim, dtype, pin_memory, topk_ratio, update_interval, overlap_step, iteration):

model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict)


data_loader = random_dataloader(model=model,
total_samples=iteration,
hidden_dim=hidden_dim,
device=model.device,
dtype=dtype)

time_step_list = []
accumulation_step_time_list = []
update_step_time_list = []

dist.barrier()
for i, batch in enumerate(data_loader):
step_start_time = time.time()
loss = model(batch[0], batch[1])
model.backward(loss)
model.step()
step_end_time = time.time()
step_time = step_end_time - step_start_time
if dist.get_rank() == 0:
print(f"Step {i} time: {step_time*1000:.2f}ms")
if i >= update_interval:
time_step_list.append(step_time)
if (i + 1) % update_interval == 0:
update_step_time_list.append(step_time)
else:
accumulation_step_time_list.append(step_time)

if dist.get_rank() == 0:
with open("zenflow_report.log", "a") as f:
msg = f"{1 if pin_memory else 0}," \
f"{topk_ratio}," \
f"{update_interval}," \
f"{overlap_step}," \
f"{sum(accumulation_step_time_list) / len(accumulation_step_time_list):.2f}," \
f"{sum(update_step_time_list) / len(update_step_time_list):.2f}"
f.write(f"{msg}\n")
print(f"[Summary] pin_memory={pin_memory} topk_ratio={topk_ratio} update_interval={update_interval} overlap_step={overlap_step} avg_accumulation_step={sum(accumulation_step_time_list) * 1000 / len(accumulation_step_time_list):.2f}ms avg_update_step={sum(update_step_time_list) * 1000 / len(update_step_time_list):.2f}ms")

model.destroy()

def main():
parser = argparse.ArgumentParser()
parser.add_argument("--nlayers", type=int, default=1)
parser.add_argument("--hidden_dim", type=int, default=1024)
parser.add_argument("--dtype", choices=['torch.bfloat16', 'torch.float16', 'torch.float32'], default='torch.bfloat16')
parser.add_argument("--iteration", type=int, default=5)
parser.add_argument("--local_rank", type=int, default=-1)

parser.add_argument("--pin_memory_opts", type=int, required=True)
parser.add_argument("--topk_ratios", type=float, required=True)
parser.add_argument("--update_intervals", type=int, required=True)
parser.add_argument("--overlap_steps", type=int, required=True)

# Optional: explicitly receive master_port (though deepspeed handles it via env)
parser.add_argument("--master_port", type=int, default=None)

args = parser.parse_args()
dtype = eval(args.dtype)


pin_memory = bool(args.pin_memory_opts)
topk_ratio = args.topk_ratios
update_interval = args.update_intervals
overlap_step = bool(args.overlap_steps)
total_iteration = args.iteration * update_interval

config_dict = {
"train_micro_batch_size_per_gpu": 1,
"optimizer": {
"type": "Adam",
"params": {
"lr": 1e-6
}
},
"zero_optimization": {
"stage": 2,
"offload_optimizer": {
"device": "cpu",
"pin_memory": pin_memory
},
"zenflow": {
"topk_ratio": topk_ratio,
"update_interval": update_interval,
"full_warm_up_rounds": 0,
"overlap_step": overlap_step
},
},
"wall_clock_breakdown": True,
"zero_allow_untested_optimizer": True
}

if dtype == torch.float16:
config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8}
elif dtype == torch.bfloat16:
config_dict["bf16"] = {"enabled": True}

model = SimpleModel(args.hidden_dim, nlayers=args.nlayers)
run_model(model, config_dict, args.hidden_dim, dtype,
pin_memory, topk_ratio, update_interval, overlap_step,
total_iteration)


if __name__ == "__main__":
main()
80 changes: 80 additions & 0 deletions training/DeepSpeed-ZenFlow/finetuning/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@

# ZenFlow Llama-2 Fine-Tuning Example

This project demonstrates how to fine-tune a [Llama-2](https://huggingface.co/meta-llama) model using [DeepSpeed](https://www.deepspeed.ai/) with **ZenFlow**, a stall-free offloading engine for large-scale model training.

## Quick Start

1. **Install dependencies**

```bash
pip install -r requirements.txt
```

2. **Configure training**

Edit `zf_config.json` to enable ZenFlow:

```json
"zero_optimization": {
"stage": 2,
"offload_optimizer": {
"device": "cpu",
"pin_memory": true
},
"zenflow": {
"topk_ratio": 0.1,
"update_interval": 4,
"full_warm_up_rounds": 0,
"overlap_step": true
}
}
```

3. **Run fine-tuning**

```bash
bash finetune_llama.sh
```

This runs LLaMA-2 fine-tuning using DeepSpeed + ZenFlow, saving checkpoints to `./alpaca_output`.

## Example Output

Below is a sample log showing step time and loss values. You can see significant speedup after the first full step:

```
ZenFlowCPUAdam initialized with overlap step.
Step 5, Loss: 1.2599, Time: 719.58ms
Step 6, Loss: 0.9847, Time: 702.81ms
Step 7, Loss: 0.6220, Time: 705.50ms
Step 8, Loss: 0.5173, Time: 1912.92ms
Step 9, Loss: 0.4557, Time: 890.60ms
Step 10, Loss: 0.3882, Time: 740.11ms
Step 11, Loss: 0.3627, Time: 731.95ms
Step 12, Loss: 0.3341, Time: 2221.18ms
Step 13, Loss: 0.2453, Time: 1061.80ms
```

ZenFlow reduces optimizer-induced stalls by overlapping CPU computation and GPU execution.

## Notes

- To change model, batch size, or epochs, modify `finetune_llama.sh`.
- All DeepSpeed and ZenFlow options are controlled via `zf_config.json`.

## Citation

To cite DeepSpeed Chat, please cite our [arxiv report](https://arxiv.org/abs/2505.12242):

```bib
@misc{lan2025zenflowenablingstallfreeoffloading,
title={ZenFlow: Enabling Stall-Free Offloading Training via Asynchronous Updates},
author={Tingfeng Lan and Yusen Wu and Bin Ma and Zhaoyuan Su and Rui Yang and Tekin Bicer and Dong Li and Yue Cheng},
year={2025},
eprint={2505.12242},
archivePrefix={arXiv},
primaryClass={cs.DC},
url={https://arxiv.org/abs/2505.12242},
}
```
Loading