Add Benchmarking and Fine-Tuning Support for ZenFlow #982
Merged: sfc-gh-truwase merged 5 commits into deepspeedai:master from Antlera:zenflow_z1_2_example (Aug 16, 2025)
Commits
ac7ea5c Add benchmark scripts and README for ZenFlow (Antlera)
0528aed Add Llama-2 fine-tuning scripts and configuration for ZenFlow (Antlera)
0309313 Add explanation tips for interpreting benchmark results in README (Antlera)
0b18cca Add guidance on step/latency interpretation (Antlera)
c4946a1 Merge branch 'master' into zenflow_z1_2_example (Antlera)
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,54 @@ | ||
| # ZenFlow Benchmark Example | ||
|
|
||
|
|
||
| Please install DeepSpeed via `pip install deepspeed` if you haven't already done so, then install the example's dependencies: | ||
|
|
||
| ```bash | ||
| pip install -r requirements.txt | ||
| ``` | ||
|
|
||
|
|
||
| The script `zf_benchmark.py` benchmarks ZenFlow with CPU optimizer offloading on a simple stack of linear layers and reports the time of each training step. Here is an example usage. | ||
|
|
||
| ```bash | ||
| $ deepspeed --num_gpus=4 zf_benchmark.py --hidden_dim 4096 --nlayers 4 --iteration 5 --pin_memory_opts 0 --topk_ratios 0.1 --update_intervals 2 --overlap_steps 0 | ||
| ... | ||
| time (ms) | selective_optimizer_update: 19.20 | selective_optimizer_process: 28.80 | selective_optimizer_sync: 0.05 | ||
| time (ms) | fwd_microstep: 54.76 | bwd_microstep: 122.95 | bwd_inner_microstep: 12.22 | bwd_allreduce_microstep: 103.64 | step_microstep: 0.34 | ||
| Step 0 time: 178.66ms | ||
| time (ms) | optimizer_allgather: 26.19 | optimizer_gradients: 26.06 | optimizer_step: 128.20 | ||
| time (ms) | selective_optimizer_update: 0.00 | selective_optimizer_process: 0.57 | selective_optimizer_step: 1.48 | selective_optimizer_sync: 0.00 | ||
| time (ms) | fwd_microstep: 0.38 | bwd_microstep: 57.88 | bwd_inner_microstep: 1.06 | bwd_allreduce_microstep: 56.50 | step_microstep: 183.27 | ||
| time (ms) | fwd: 55.15 | bwd: 180.82 | bwd_inner: 13.28 | bwd_allreduce: 160.15 | step: 183.61 | ||
| Step 1 time: 242.16ms | ||
| time (ms) | selective_optimizer_update: 0.00 | selective_optimizer_process: 1.58 | selective_optimizer_step: 0.00 | selective_optimizer_sync: 0.00 | ||
| time (ms) | fwd_microstep: 0.30 | bwd_microstep: 16.73 | bwd_inner_microstep: 1.39 | bwd_allreduce_microstep: 14.96 | step_microstep: 0.20 | ||
| Step 2 time: 17.60ms | ||
| time (ms) | optimizer_allgather: 0.65 | optimizer_gradients: 16.95 | optimizer_step: 108.45 | ||
| time (ms) | selective_optimizer_update: 0.00 | selective_optimizer_process: 0.56 | selective_optimizer_step: 1.42 | selective_optimizer_sync: 0.00 | ||
| time (ms) | fwd_microstep: 0.29 | bwd_microstep: 36.65 | bwd_inner_microstep: 0.95 | bwd_allreduce_microstep: 35.51 | step_microstep: 128.57 | ||
| time (ms) | fwd: 0.59 | bwd: 53.39 | bwd_inner: 2.33 | bwd_allreduce: 50.48 | step: 128.77 | ||
| Step 3 time: 166.10ms | ||
| time (ms) | selective_optimizer_update: 0.00 | selective_optimizer_process: 1.57 | selective_optimizer_step: 0.00 | selective_optimizer_sync: 0.00 | ||
| time (ms) | fwd_microstep: 0.31 | bwd_microstep: 15.47 | bwd_inner_microstep: 1.33 | bwd_allreduce_microstep: 13.97 | step_microstep: 0.23 | ||
| ... | ||
| [Summary] pin_memory=False topk_ratio=0.1 update_interval=2 overlap_step=False avg_accumulation_step=16.77ms avg_update_step=171.38ms | ||
| ``` | ||
|
|
||
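| `run_benchmark.sh` shows how to run the script with different configurations. It sweeps `pin_memory`, `topk_ratio`, `update_interval`, and `overlap_step`, appends all output to `zf_benchmark.log`, and then runs `output_table.py` to print the average timings (in ms) for each configuration. | ||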
|
|
||
| ```bash | ||
| $ ./run_benchmark.sh | ||
| ... | ||
| +---------+--------------+--------------+-------------------+----------------+-------------+------------+-----------+-----------+--------------------------------+ | ||
| | trial | pin_memory | topk_ratio | update_interval | overlap_step | num_steps | avg_step | avg_bwd | avg_fwd | avg_selective_optimizer_step | | ||
| |---------+--------------+--------------+-------------------+----------------+-------------+------------+-----------+-----------+--------------------------------| | ||
| | 1 | False | 0.1 | 2 | False | 30 | 24.0153 | 12.8377 | 1.91733 | 0.247 | | ||
| | 1 | False | 0.1 | 2 | True | 28 | 805.425 | 22.5604 | 1.96821 | 0.345714 | | ||
| | 1 | False | 0.1 | 4 | False | 50 | 14.2108 | 10.9072 | 1.2436 | 0.1484 | | ||
| | 1 | False | 0.1 | 4 | True | 48 | 459.326 | 16.0385 | 1.30125 | 0.221667 | | ||
| | 1 | False | 0.2 | 2 | False | 30 | 22.6567 | 12.6463 | 2.421 | 0.346 | | ||
| | 1 | False | 0.2 | 2 | True | 28 | 817.919 | 22.1079 | 2.06179 | 0.450714 | | ||
| | 1 | False | 0.2 | 4 | False | 50 | 14.12 | 9.4714 | 1.1766 | 0.2072 | | ||
| | 1 | False | 0.2 | 4 | True | 48 | 471.339 | 15.945 | 1.2675 | 0.262292 | | ||
| ... | ||
| ``` | ||
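
The `[Summary]` line reports accumulation steps and update steps separately. As a rough aid for comparing configurations, the two averages can be folded into a single per-step figure. The sketch below assumes, following the `(i + 1) % update_interval == 0` check in `zf_benchmark.py`, that one step in every `update_interval` is an update step and the rest are accumulation steps. The helper name `amortized_step_ms` is hypothetical and not part of the PR.

```python
def amortized_step_ms(avg_accumulation_ms: float, avg_update_ms: float, update_interval: int) -> float:
    """Average time per step over one cycle of (update_interval - 1) accumulation steps and 1 update step."""
    if update_interval < 1:
        raise ValueError("update_interval must be >= 1")
    cycle_ms = (update_interval - 1) * avg_accumulation_ms + avg_update_ms
    return cycle_ms / update_interval


# Values taken from the [Summary] line in the sample output above
summary = {"avg_accumulation_step": 16.77, "avg_update_step": 171.38, "update_interval": 2}
per_step = amortized_step_ms(summary["avg_accumulation_step"],
                             summary["avg_update_step"],
                             summary["update_interval"])
print(f"amortized step time: {per_step:.2f} ms")
```

This is only a summary statistic: it hides the step-to-step variation that the raw log shows.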
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,65 @@ | ||
| import re | ||
| | ||
| import pandas as pd | ||
| from tabulate import tabulate | ||
| | ||
| | ||
| def parse_log_file(log_file_path): | ||
|     with open(log_file_path, 'r') as f: | ||
|         lines = f.readlines() | ||
| | ||
|     # Regex patterns: the trial header echoed by run_benchmark.sh, and "| name: value" timing pairs | ||
|     trial_header_re = re.compile( | ||
|         r"\[Trial (\d+)] pin_memory=(\d), topk=([\d.]+), update=(\d+), overlap_step=(\d+) \(MASTER_PORT=\d+\)" | ||
|     ) | ||
|     time_metrics_re = re.compile(r"\|\s*([^:|]+):\s*([\d.]+)") | ||
| | ||
|     trials = [] | ||
|     current_config = None | ||
|     current_step_metrics = [] | ||
| | ||
|     def finalize_trial(): | ||
|         if current_config and current_step_metrics: | ||
|             # Get all unique keys | ||
|             all_keys = set() | ||
|             for step in current_step_metrics: | ||
|                 all_keys.update(step.keys()) | ||
|             # Aggregate and average; a metric missing from a record counts as 0.0 | ||
|             agg = {k: 0.0 for k in all_keys} | ||
|             for step in current_step_metrics: | ||
|                 for k in all_keys: | ||
|                     agg[k] += step.get(k, 0.0) | ||
|             avg = {f"avg_{k}": agg[k] / len(current_step_metrics) for k in all_keys} | ||
|             # num_steps counts parsed "time (ms)" lines, not training steps | ||
|             trials.append({**current_config, **avg, "num_steps": len(current_step_metrics)}) | ||
| | ||
|     for line in lines: | ||
|         header_match = trial_header_re.search(line) | ||
|         if header_match: | ||
|             # A new trial header closes the previous trial | ||
|             finalize_trial() | ||
|             trial_id, pin_memory, topk, update, overlap = header_match.groups() | ||
|             current_config = { | ||
|                 "trial": int(trial_id), | ||
|                 "pin_memory": bool(int(pin_memory)), | ||
|                 "topk_ratio": float(topk), | ||
|                 "update_interval": int(update), | ||
|                 "overlap_step": bool(int(overlap)) | ||
|             } | ||
|             current_step_metrics = [] | ||
|             continue | ||
| | ||
|         # Only rank 0's wall-clock breakdown lines are parsed | ||
|         if "[Rank 0]" in line and "time (ms)" in line: | ||
|             metrics = {k.strip(): float(v) for k, v in time_metrics_re.findall(line)} | ||
|             current_step_metrics.append(metrics) | ||
| | ||
|     finalize_trial() | ||
|     return pd.DataFrame(trials) | ||
| | ||
| | ||
| if __name__ == "__main__": | ||
| | ||
|     log_file = "zf_benchmark.log" | ||
|     df = parse_log_file(log_file) | ||
|     df = df.sort_values(by=["topk_ratio", "update_interval", "overlap_step", "pin_memory"]) | ||
|     cols_to_display = [ | ||
|         "trial", "topk_ratio", "update_interval", "overlap_step", "pin_memory", "num_steps", | ||
|         "avg_step", "avg_bwd", "avg_fwd", "avg_selective_optimizer_step" | ||
|     ] | ||
|     print(tabulate(df[cols_to_display], headers="keys", tablefmt="psql", showindex=False)) |
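
To make the parsing step concrete, here is a minimal sketch that applies the two regular expressions from `output_table.py` to one trial header and one timing line. The header follows the `echo` format in `run_benchmark.sh`. The `[Rank 0]` prefix and the port number `23456` are assumptions made for illustration, since the sample output in the README omits the prefix.

```python
import re

# Same patterns as in output_table.py
trial_header_re = re.compile(
    r"\[Trial (\d+)] pin_memory=(\d), topk=([\d.]+), update=(\d+), overlap_step=(\d+) \(MASTER_PORT=\d+\)"
)
time_metrics_re = re.compile(r"\|\s*([^:|]+):\s*([\d.]+)")

# Illustrative input lines (prefix and port are assumed, timings come from the README sample)
header = "[Trial 1] pin_memory=0, topk=0.1, update=2, overlap_step=1 (MASTER_PORT=23456)"
timing = "[Rank 0] time (ms) | fwd: 55.15 | bwd: 180.82 | bwd_inner: 13.28 | bwd_allreduce: 160.15 | step: 183.61"

# The header yields the trial configuration as strings
trial_id, pin_memory, topk, update, overlap = trial_header_re.search(header).groups()

# Each "| name: value" pair becomes one entry in the metrics dict
metrics = {name.strip(): float(value) for name, value in time_metrics_re.findall(timing)}

print(trial_id, pin_memory, topk, update, overlap)
print(metrics)
```

Because every matching line becomes its own record, a metric that is absent from a given line contributes 0.0 to that metric's average.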
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,7 @@ | ||
| torch>=2.5.1 | ||
| deepspeed>=0.16.0 | ||
| datasets>=2.14.1 | ||
| transformers>=4.37.2 | ||
| numpy>=1.21.0 | ||
| tabulate | ||
| pandas |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,36 @@ | ||
| #!/bin/bash | ||
|
|
||
| NGPUS=2 | ||
| HIDDEN_SIZE=4096 | ||
| NUM_LAYERS=4 | ||
| TRIALS=1 | ||
|
|
||
| PIN_MEMORY_OPTS=(0 1) | ||
| TOPK_RATIOS=(0.1 0.2) | ||
| UPDATE_INTERVALS=(2 4) | ||
| OVERLAP_STEPS=(1 0) | ||
|
|
||
| for pin_memory in "${PIN_MEMORY_OPTS[@]}"; do | ||
| for topk in "${TOPK_RATIOS[@]}"; do | ||
| for update in "${UPDATE_INTERVALS[@]}"; do | ||
| for overlap in "${OVERLAP_STEPS[@]}"; do | ||
| for ((trial=0; trial<$TRIALS; trial++)); do | ||
| # Pick a random port; $RANDOM is at most 32767, so this yields a port between 20000 and 52767 | ||
| MASTER_PORT=$((20000 + RANDOM % 45000)) | ||
| echo "[Trial $((trial+1))] pin_memory=$pin_memory, topk=$topk, update=$update, overlap_step=$overlap (MASTER_PORT=$MASTER_PORT)" | tee -a zf_benchmark.log | ||
| deepspeed --master_port $MASTER_PORT \ | ||
| --num_gpus=$NGPUS \ | ||
| zf_benchmark.py \ | ||
| --hidden_dim $HIDDEN_SIZE \ | ||
| --nlayers $NUM_LAYERS \ | ||
| --iteration 5 \ | ||
| --pin_memory_opts $pin_memory \ | ||
| --topk_ratios $topk \ | ||
| --update_intervals $update \ | ||
| --overlap_steps $overlap | tee -a zf_benchmark.log | ||
| done | ||
| done | ||
| done | ||
| done | ||
| done | ||
| python output_table.py |
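
The nested loops above amount to a Cartesian product over the four option lists. The following sketch enumerates the same sweep in Python, in the same order, without launching anything. The constant values are copied from `run_benchmark.sh`, while the `sweep` generator is a hypothetical helper used only for illustration.

```python
from itertools import product

# Option lists copied from run_benchmark.sh
PIN_MEMORY_OPTS = (0, 1)
TOPK_RATIOS = (0.1, 0.2)
UPDATE_INTERVALS = (2, 4)
OVERLAP_STEPS = (1, 0)
TRIALS = 1


def sweep():
    """Yield one dict per benchmark run, in the order the shell loops produce them."""
    grid = product(PIN_MEMORY_OPTS, TOPK_RATIOS, UPDATE_INTERVALS, OVERLAP_STEPS)
    for pin_memory, topk, update, overlap in grid:
        for trial in range(1, TRIALS + 1):
            yield {
                "trial": trial,
                "pin_memory": pin_memory,
                "topk_ratio": topk,
                "update_interval": update,
                "overlap_step": overlap,
            }


runs = list(sweep())
print(len(runs))  # 16 runs: 2 * 2 * 2 * 2 configurations, one trial each
print(runs[0])
```

Each run appends to the same `zf_benchmark.log`, so the log is worth clearing between sweeps if results from earlier runs should not appear in the table.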
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,150 @@ | ||
| # Copyright (c) Microsoft Corporation. | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| | ||
| # DeepSpeed Team | ||
| | ||
| import argparse | ||
| import torch | ||
| import deepspeed.comm as dist | ||
| import time | ||
| | ||
| import deepspeed | ||
| | ||
| | ||
| class SimpleModel(torch.nn.Module): | ||
|     """A stack of square linear layers followed by a cross-entropy loss.""" | ||
| | ||
|     def __init__(self, hidden_dim, empty_grad=False, nlayers=1): | ||
|         super().__init__() | ||
|         self.linears = torch.nn.ModuleList([torch.nn.Linear(hidden_dim, hidden_dim) for _ in range(nlayers)]) | ||
|         if empty_grad: | ||
|             # Extra layer that forward() never uses | ||
|             self.linear2 = torch.nn.Linear(hidden_dim, hidden_dim) | ||
|         self.cross_entropy_loss = torch.nn.CrossEntropyLoss() | ||
| | ||
|     def forward(self, x, y): | ||
|         for layer in self.linears: | ||
|             x = layer(x) | ||
|         return self.cross_entropy_loss(x, y) | ||
| | ||
| | ||
| def random_dataset(total_samples, hidden_dim, device, dtype): | ||
|     train_data = torch.randn(total_samples, hidden_dim, device=device, dtype=dtype) | ||
|     train_label = torch.empty(total_samples, dtype=torch.long, device=device).random_(hidden_dim) | ||
|     train_dataset = torch.utils.data.TensorDataset(train_data, train_label) | ||
|     return train_dataset | ||
| | ||
| | ||
| def random_dataloader(model, total_samples, hidden_dim, device, dtype): | ||
|     batch_size = model.train_micro_batch_size_per_gpu() | ||
|     train_dataset = random_dataset(total_samples, hidden_dim, device, dtype=dtype) | ||
|     train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size) | ||
|     return train_loader | ||
| | ||
| | ||
| def mean_ms(times_s): | ||
|     """Mean of a list of durations given in seconds, returned in milliseconds (NaN if the list is empty).""" | ||
|     return sum(times_s) * 1000 / len(times_s) if times_s else float("nan") | ||
| | ||
| | ||
| def run_model(model, config_dict, hidden_dim, dtype, pin_memory, topk_ratio, update_interval, overlap_step, iteration): | ||
| | ||
|     model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict) | ||
| | ||
|     data_loader = random_dataloader(model=model, | ||
|                                     total_samples=iteration, | ||
|                                     hidden_dim=hidden_dim, | ||
|                                     device=model.device, | ||
|                                     dtype=dtype) | ||
| | ||
|     time_step_list = [] | ||
|     accumulation_step_time_list = [] | ||
|     update_step_time_list = [] | ||
| | ||
|     dist.barrier() | ||
|     for i, batch in enumerate(data_loader): | ||
|         step_start_time = time.time() | ||
|         loss = model(batch[0], batch[1]) | ||
|         model.backward(loss) | ||
|         model.step() | ||
|         step_end_time = time.time() | ||
|         step_time = step_end_time - step_start_time | ||
|         if dist.get_rank() == 0: | ||
|             print(f"Step {i} time: {step_time*1000:.2f}ms") | ||
|         # The first update_interval steps are treated as warm-up and excluded from the averages | ||
|         if i >= update_interval: | ||
|             time_step_list.append(step_time) | ||
|             # Every update_interval-th step is an update step; the others only accumulate | ||
|             if (i + 1) % update_interval == 0: | ||
|                 update_step_time_list.append(step_time) | ||
|             else: | ||
|                 accumulation_step_time_list.append(step_time) | ||
| | ||
|     if dist.get_rank() == 0: | ||
|         # The accumulation list is empty when update_interval is 1, so mean_ms guards against division by zero | ||
|         avg_accumulation_ms = mean_ms(accumulation_step_time_list) | ||
|         avg_update_ms = mean_ms(update_step_time_list) | ||
|         with open("zenflow_report.log", "a") as f: | ||
|             # CSV row; both averages are written in milliseconds to match the printed summary | ||
|             msg = f"{1 if pin_memory else 0}," \ | ||
|                   f"{topk_ratio}," \ | ||
|                   f"{update_interval}," \ | ||
|                   f"{overlap_step}," \ | ||
|                   f"{avg_accumulation_ms:.2f}," \ | ||
|                   f"{avg_update_ms:.2f}" | ||
|             f.write(f"{msg}\n") | ||
|         print(f"[Summary] pin_memory={pin_memory} topk_ratio={topk_ratio} update_interval={update_interval} " | ||
|               f"overlap_step={overlap_step} avg_accumulation_step={avg_accumulation_ms:.2f}ms " | ||
|               f"avg_update_step={avg_update_ms:.2f}ms") | ||
| | ||
|     model.destroy() | ||
| | ||
| | ||
| def main(): | ||
|     parser = argparse.ArgumentParser() | ||
|     parser.add_argument("--nlayers", type=int, default=1) | ||
|     parser.add_argument("--hidden_dim", type=int, default=1024) | ||
|     parser.add_argument("--dtype", choices=['torch.bfloat16', 'torch.float16', 'torch.float32'], default='torch.bfloat16') | ||
|     parser.add_argument("--iteration", type=int, default=5) | ||
|     parser.add_argument("--local_rank", type=int, default=-1) | ||
| | ||
|     parser.add_argument("--pin_memory_opts", type=int, required=True) | ||
|     parser.add_argument("--topk_ratios", type=float, required=True) | ||
|     parser.add_argument("--update_intervals", type=int, required=True) | ||
|     parser.add_argument("--overlap_steps", type=int, required=True) | ||
| | ||
|     # Optional: explicitly receive master_port (though deepspeed handles it via env) | ||
|     parser.add_argument("--master_port", type=int, default=None) | ||
| | ||
|     args = parser.parse_args() | ||
|     # Resolve "torch.<name>" to the dtype object without eval() | ||
|     dtype = getattr(torch, args.dtype.split(".")[-1]) | ||
| | ||
|     pin_memory = bool(args.pin_memory_opts) | ||
|     topk_ratio = args.topk_ratios | ||
|     update_interval = args.update_intervals | ||
|     overlap_step = bool(args.overlap_steps) | ||
|     # --iteration counts update cycles, so the total step count scales with update_interval | ||
|     total_iteration = args.iteration * update_interval | ||
| | ||
|     config_dict = { | ||
|         "train_micro_batch_size_per_gpu": 1, | ||
|         "optimizer": { | ||
|             "type": "Adam", | ||
|             "params": { | ||
|                 "lr": 1e-6 | ||
|             } | ||
|         }, | ||
|         "zero_optimization": { | ||
|             "stage": 2, | ||
|             "offload_optimizer": { | ||
|                 "device": "cpu", | ||
|                 "pin_memory": pin_memory | ||
|             }, | ||
|             "zenflow": { | ||
|                 "topk_ratio": topk_ratio, | ||
|                 "update_interval": update_interval, | ||
|                 "full_warm_up_rounds": 0, | ||
|                 "overlap_step": overlap_step | ||
|             }, | ||
|         }, | ||
|         "wall_clock_breakdown": True, | ||
|         "zero_allow_untested_optimizer": True | ||
|     } | ||
| | ||
|     if dtype == torch.float16: | ||
|         config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8} | ||
|     elif dtype == torch.bfloat16: | ||
|         config_dict["bf16"] = {"enabled": True} | ||
| | ||
|     model = SimpleModel(args.hidden_dim, nlayers=args.nlayers) | ||
|     run_model(model, config_dict, args.hidden_dim, dtype, | ||
|               pin_memory, topk_ratio, update_interval, overlap_step, | ||
|               total_iteration) | ||
| | ||
| | ||
| if __name__ == "__main__": | ||
|     main() |
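
The timing loop in `run_model` sorts each step into one of three buckets. The sketch below reproduces only that bookkeeping, with no DeepSpeed dependency, so it is easy to see which step indices feed `avg_accumulation_step` and `avg_update_step`. The function name `classify_steps` is hypothetical, and the logic mirrors the `i >= update_interval` and `(i + 1) % update_interval == 0` checks in the script.

```python
def classify_steps(total_iterations: int, update_interval: int):
    """Split step indices into warm-up, accumulation, and update steps, as zf_benchmark.py does."""
    warmup, accumulation, update = [], [], []
    for i in range(total_iterations):
        if i < update_interval:
            # The first update_interval steps are excluded from the reported averages
            warmup.append(i)
        elif (i + 1) % update_interval == 0:
            update.append(i)
        else:
            accumulation.append(i)
    return warmup, accumulation, update


# --iteration 5 with --update_intervals 2 gives 5 * 2 = 10 steps in total
warmup, accumulation, update = classify_steps(total_iterations=10, update_interval=2)
print("warm-up:     ", warmup)        # [0, 1]
print("accumulation:", accumulation)  # [2, 4, 6, 8]
print("update:      ", update)        # [3, 5, 7, 9]
```

With `update_interval` set to 1 the accumulation bucket is empty, so any average taken over it needs to handle that case.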
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,80 @@ | ||
|
|
||
| # ZenFlow Llama-2 Fine-Tuning Example | ||
|
|
||
| This project demonstrates how to fine-tune a [Llama-2](https://huggingface.co/meta-llama) model using [DeepSpeed](https://www.deepspeed.ai/) with **ZenFlow**, a stall-free offloading engine for large-scale model training. | ||
|
|
||
| ## Quick Start | ||
|
|
||
| 1. **Install dependencies** | ||
|
|
||
| ```bash | ||
| pip install -r requirements.txt | ||
| ``` | ||
|
|
||
| 2. **Configure training** | ||
|
|
||
| Edit `zf_config.json` to enable ZenFlow: | ||
|
|
||
| ```json | ||
| "zero_optimization": { | ||
| "stage": 2, | ||
| "offload_optimizer": { | ||
| "device": "cpu", | ||
| "pin_memory": true | ||
| }, | ||
| "zenflow": { | ||
| "topk_ratio": 0.1, | ||
| "update_interval": 4, | ||
| "full_warm_up_rounds": 0, | ||
| "overlap_step": true | ||
| } | ||
| } | ||
| ``` | ||
|
|
||
| 3. **Run fine-tuning** | ||
|
|
||
| ```bash | ||
| bash finetune_llama.sh | ||
| ``` | ||
|
|
||
| This runs Llama-2 fine-tuning using DeepSpeed + ZenFlow, saving checkpoints to `./alpaca_output`. | ||
|
|
||
| ## Example Output | ||
|
|
||
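| Below is a sample log showing step time and loss values. Most steps take roughly 0.7 s, while every fourth step (steps 8 and 12 here) is noticeably slower, which lines up with the `update_interval` of 4 in the configuration above: | ||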
|
||
|
|
||
| ``` | ||
| ZenFlowCPUAdam initialized with overlap step. | ||
| Step 5, Loss: 1.2599, Time: 719.58ms | ||
| Step 6, Loss: 0.9847, Time: 702.81ms | ||
| Step 7, Loss: 0.6220, Time: 705.50ms | ||
| Step 8, Loss: 0.5173, Time: 1912.92ms | ||
| Step 9, Loss: 0.4557, Time: 890.60ms | ||
| Step 10, Loss: 0.3882, Time: 740.11ms | ||
| Step 11, Loss: 0.3627, Time: 731.95ms | ||
| Step 12, Loss: 0.3341, Time: 2221.18ms | ||
| Step 13, Loss: 0.2453, Time: 1061.80ms | ||
| ``` | ||
|
|
||
| ZenFlow reduces optimizer-induced stalls by overlapping CPU computation and GPU execution. | ||
|
|
||
| ## Notes | ||
|
|
||
| - To change model, batch size, or epochs, modify `finetune_llama.sh`. | ||
| - All DeepSpeed and ZenFlow options are controlled via `zf_config.json`. | ||
|
|
||
| ## Citation | ||
|
|
||
| To cite ZenFlow, please cite our [arXiv report](https://arxiv.org/abs/2505.12242): | ||
|
|
||
| ```bib | ||
| @misc{lan2025zenflowenablingstallfreeoffloading, | ||
| title={ZenFlow: Enabling Stall-Free Offloading Training via Asynchronous Updates}, | ||
| author={Tingfeng Lan and Yusen Wu and Bin Ma and Zhaoyuan Su and Rui Yang and Tekin Bicer and Dong Li and Yue Cheng}, | ||
| year={2025}, | ||
| eprint={2505.12242}, | ||
| archivePrefix={arXiv}, | ||
| primaryClass={cs.DC}, | ||
| url={https://arxiv.org/abs/2505.12242}, | ||
| } | ||
| ``` | ||
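
The sample log in the Example Output section can be summarized with a few lines of Python. This sketch parses the `Step N, Loss: X, Time: Yms` lines and compares the slow steps with the rest. It assumes, based on the sample alone, that steps whose number is a multiple of `update_interval` are the ones that carry the full update. That split is an assumption made for illustration, not something the PR states.

```python
import re
from statistics import mean

# Sample log copied from the Example Output section
SAMPLE_LOG = """\
Step 5, Loss: 1.2599, Time: 719.58ms
Step 6, Loss: 0.9847, Time: 702.81ms
Step 7, Loss: 0.6220, Time: 705.50ms
Step 8, Loss: 0.5173, Time: 1912.92ms
Step 9, Loss: 0.4557, Time: 890.60ms
Step 10, Loss: 0.3882, Time: 740.11ms
Step 11, Loss: 0.3627, Time: 731.95ms
Step 12, Loss: 0.3341, Time: 2221.18ms
Step 13, Loss: 0.2453, Time: 1061.80ms
"""

STEP_RE = re.compile(r"Step (\d+), Loss: ([\d.]+), Time: ([\d.]+)ms")


def split_step_times(log_text: str, update_interval: int):
    """Return (update_times, other_times) in ms, assuming update steps are multiples of update_interval."""
    update_times, other_times = [], []
    for step, _loss, time_ms in STEP_RE.findall(log_text):
        bucket = update_times if int(step) % update_interval == 0 else other_times
        bucket.append(float(time_ms))
    return update_times, other_times


update_times, other_times = split_step_times(SAMPLE_LOG, update_interval=4)
print(f"update steps: {mean(update_times):.1f} ms over {len(update_times)} steps")
print(f"other steps:  {mean(other_times):.1f} ms over {len(other_times)} steps")
```

The same function can be pointed at a longer training log to check whether the slow steps keep recurring at the configured interval.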