Skip to content

Commit ac7ea5c

Browse files
committed
Add benchmark scripts and README for ZenFlow
- Introduced `zf_benchmark.py` for model offloading benchmarking with DeepSpeed. - Added `output_table.py` to parse and display benchmark results in a tabular format. - Created `run_benchmark.sh` to automate benchmark runs with various configurations. Signed-off-by: Tingfeng Lan <[email protected]>
1 parent b99d653 commit ac7ea5c

File tree

5 files changed

+312
-0
lines changed

5 files changed

+312
-0
lines changed
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
# ZenFlow Benchmark Example
2+
3+
4+
Please install DeepSpeed via pip install deepspeed if you haven't already done so.
5+
6+
```bash
7+
pip install -r requirements.txt
8+
```
9+
10+
11+
The script `zf_benchmark.py ` demonstrates how to offload the state of a model. Here is the example usage.
12+
13+
```python
14+
$ deepspeed --num_gpus=4 zf_benchmark.py --hidden_dim 4096 --nlayers 4 --iteration 5 --pin_memory_opts 1 --topk_ratios 0.1 --update_intervals 2 --overlap_steps
15+
...
16+
time (ms) | selective_optimizer_update: 19.20 | selective_optimizer_process: 28.80 | selective_optimizer_sync: 0.05
17+
time (ms) | fwd_microstep: 54.76 | bwd_microstep: 122.95 | bwd_inner_microstep: 12.22 | bwd_allreduce_microstep: 103.64 | step_microstep: 0.34
18+
Step 0 time: 178.66ms
19+
time (ms) | optimizer_allgather: 26.19 | optimizer_gradients: 26.06 | optimizer_step: 128.20
20+
time (ms) | selective_optimizer_update: 0.00 | selective_optimizer_process: 0.57 | selective_optimizer_step: 1.48 | selective_optimizer_sync: 0.00
21+
time (ms) | fwd_microstep: 0.38 | bwd_microstep: 57.88 | bwd_inner_microstep: 1.06 | bwd_allreduce_microstep: 56.50 | step_microstep: 183.27
22+
time (ms) | fwd: 55.15 | bwd: 180.82 | bwd_inner: 13.28 | bwd_allreduce: 160.15 | step: 183.61
23+
Step 1 time: 242.16ms
24+
time (ms) | selective_optimizer_update: 0.00 | selective_optimizer_process: 1.58 | selective_optimizer_step: 0.00 | selective_optimizer_sync: 0.00
25+
time (ms) | fwd_microstep: 0.30 | bwd_microstep: 16.73 | bwd_inner_microstep: 1.39 | bwd_allreduce_microstep: 14.96 | step_microstep: 0.20
26+
Step 2 time: 17.60ms
27+
time (ms) | optimizer_allgather: 0.65 | optimizer_gradients: 16.95 | optimizer_step: 108.45
28+
time (ms) | selective_optimizer_update: 0.00 | selective_optimizer_process: 0.56 | selective_optimizer_step: 1.42 | selective_optimizer_sync: 0.00
29+
time (ms) | fwd_microstep: 0.29 | bwd_microstep: 36.65 | bwd_inner_microstep: 0.95 | bwd_allreduce_microstep: 35.51 | step_microstep: 128.57
30+
time (ms) | fwd: 0.59 | bwd: 53.39 | bwd_inner: 2.33 | bwd_allreduce: 50.48 | step: 128.77
31+
Step 3 time: 166.10ms
32+
time (ms) | selective_optimizer_update: 0.00 | selective_optimizer_process: 1.57 | selective_optimizer_step: 0.00 | selective_optimizer_sync: 0.00
33+
time (ms) | fwd_microstep: 0.31 | bwd_microstep: 15.47 | bwd_inner_microstep: 1.33 | bwd_allreduce_microstep: 13.97 | step_microstep: 0.23
34+
...
35+
[Summary] pin_memory=False topk_ratio=0.1 update_interval=2 overlap_step=False avg_accumulation_step=16.77ms avg_update_step=171.38ms
36+
```
37+
38+
`run_benchmark.sh` shows how to run the script with different configurations. The script outputs the time for offloading and loading the states.
39+
40+
```python
41+
$ ./run_benchmark.sh
42+
...
43+
+---------+--------------+--------------+-------------------+----------------+-------------+------------+-----------+-----------+--------------------------------+
44+
| trial | pin_memory | topk_ratio | update_interval | overlap_step | num_steps | avg_step | avg_bwd | avg_fwd | avg_selective_optimizer_step |
45+
|---------+--------------+--------------+-------------------+----------------+-------------+------------+-----------+-----------+--------------------------------|
46+
| 1 | False | 0.1 | 2 | False | 30 | 24.0153 | 12.8377 | 1.91733 | 0.247 |
47+
| 1 | False | 0.1 | 2 | True | 28 | 805.425 | 22.5604 | 1.96821 | 0.345714 |
48+
| 1 | False | 0.1 | 4 | False | 50 | 14.2108 | 10.9072 | 1.2436 | 0.1484 |
49+
| 1 | False | 0.1 | 4 | True | 48 | 459.326 | 16.0385 | 1.30125 | 0.221667 |
50+
| 1 | False | 0.2 | 2 | False | 30 | 22.6567 | 12.6463 | 2.421 | 0.346 |
51+
| 1 | False | 0.2 | 2 | True | 28 | 817.919 | 22.1079 | 2.06179 | 0.450714 |
52+
| 1 | False | 0.2 | 4 | False | 50 | 14.12 | 9.4714 | 1.1766 | 0.2072 |
53+
| 1 | False | 0.2 | 4 | True | 48 | 471.339 | 15.945 | 1.2675 | 0.262292 |...
54+
```
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
import re
2+
from collections import defaultdict
3+
import pandas as pd
4+
from tabulate import tabulate
5+
6+
def parse_log_file(log_file_path):
7+
with open(log_file_path, 'r') as f:
8+
lines = f.readlines()
9+
10+
# Regex patterns
11+
trial_header_re = re.compile(
12+
r"\[Trial (\d+)] pin_memory=(\d), topk=([\d.]+), update=(\d+), overlap_step=(\d+) \(MASTER_PORT=\d+\)"
13+
)
14+
time_metrics_re = re.compile(r"\|\s*([^:|]+):\s*([\d.]+)")
15+
16+
trials = []
17+
current_config = None
18+
current_step_metrics = []
19+
20+
def finalize_trial():
21+
if current_config and current_step_metrics:
22+
# Get all unique keys
23+
all_keys = set()
24+
for step in current_step_metrics:
25+
all_keys.update(step.keys())
26+
# Aggregate and average
27+
agg = {k: 0.0 for k in all_keys}
28+
for step in current_step_metrics:
29+
for k in all_keys:
30+
agg[k] += step.get(k, 0.0)
31+
avg = {f"avg_{k}": agg[k] / len(current_step_metrics) for k in all_keys}
32+
trials.append({**current_config, **avg, "num_steps": len(current_step_metrics)})
33+
34+
for line in lines:
35+
header_match = trial_header_re.search(line)
36+
if header_match:
37+
finalize_trial()
38+
trial_id, pin_memory, topk, update, overlap = header_match.groups()
39+
current_config = {
40+
"trial": int(trial_id),
41+
"pin_memory": bool(int(pin_memory)),
42+
"topk_ratio": float(topk),
43+
"update_interval": int(update),
44+
"overlap_step": bool(int(overlap))
45+
}
46+
current_step_metrics = []
47+
continue
48+
49+
if "[Rank 0]" in line and "time (ms)" in line:
50+
metrics = {k.strip(): float(v) for k, v in time_metrics_re.findall(line)}
51+
current_step_metrics.append(metrics)
52+
53+
finalize_trial()
54+
return pd.DataFrame(trials)
55+
56+
if __name__ == "__main__":
57+
58+
log_file = "zf_benchmark.log"
59+
df = parse_log_file(log_file)
60+
df = df.sort_values(by=["topk_ratio", "update_interval", "overlap_step", "pin_memory"])
61+
cols_to_display = [
62+
"trial", "topk_ratio", "update_interval", "overlap_step", "pin_memory", "num_steps",
63+
"avg_step", "avg_bwd", "avg_fwd", "avg_selective_optimizer_step"
64+
]
65+
print(tabulate(df[cols_to_display], headers="keys", tablefmt="psql", showindex=False))
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
torch>=2.5.1
2+
deepspeed>=0.16.0
3+
datasets>=2.14.1
4+
transformers>=4.37.2
5+
numpy>=1.21.0
6+
tabulate
7+
pandas
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
#!/bin/bash
2+
3+
NGPUS=2
4+
HIDDEN_SIZE=4096
5+
NUM_LAYERS=4
6+
TRIALS=1
7+
8+
PIN_MEMORY_OPTS=(0 1)
9+
TOPK_RATIOS=(0.1 0.2)
10+
UPDATE_INTERVALS=(2 4)
11+
OVERLAP_STEPS=(1 0)
12+
13+
for pin_memory in "${PIN_MEMORY_OPTS[@]}"; do
14+
for topk in "${TOPK_RATIOS[@]}"; do
15+
for update in "${UPDATE_INTERVALS[@]}"; do
16+
for overlap in "${OVERLAP_STEPS[@]}"; do
17+
for ((trial=0; trial<$TRIALS; trial++)); do
18+
# Generate a random port between 20000 and 65000
19+
MASTER_PORT=$((20000 + RANDOM % 45000))
20+
echo "[Trial $((trial+1))] pin_memory=$pin_memory, topk=$topk, update=$update, overlap_step=$overlap (MASTER_PORT=$MASTER_PORT)" | tee -a zf_benchmark.log
21+
deepspeed --master_port $MASTER_PORT \
22+
--num_gpus=$NGPUS \
23+
zf_benchmark.py \
24+
--hidden_dim $HIDDEN_SIZE \
25+
--nlayers $NUM_LAYERS \
26+
--iteration 5 \
27+
--pin_memory_opts $pin_memory \
28+
--topk_ratios $topk \
29+
--update_intervals $update \
30+
--overlap_steps $overlap | tee -a zf_benchmark.log
31+
done
32+
done
33+
done
34+
done
35+
done
36+
python output_table.py
Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
# DeepSpeed Team
5+
6+
import argparse
7+
import torch
8+
import deepspeed.comm as dist
9+
import time
10+
11+
import deepspeed
12+
13+
class SimpleModel(torch.nn.Module):
14+
15+
def __init__(self, hidden_dim, empty_grad=False, nlayers=1):
16+
super(SimpleModel, self).__init__()
17+
self.linears = torch.nn.ModuleList([torch.nn.Linear(hidden_dim, hidden_dim) for _ in range(nlayers)])
18+
if empty_grad:
19+
self.linear2 = torch.nn.Linear(hidden_dim, hidden_dim)
20+
self.cross_entropy_loss = torch.nn.CrossEntropyLoss()
21+
22+
def forward(self, x, y):
23+
for l in self.linears:
24+
x = l(x)
25+
return self.cross_entropy_loss(x, y)
26+
27+
28+
def random_dataset(total_samples, hidden_dim, device, dtype):
29+
train_data = torch.randn(total_samples, hidden_dim, device=device, dtype=dtype)
30+
train_label = torch.empty(total_samples, dtype=torch.long, device=device).random_(hidden_dim)
31+
train_dataset = torch.utils.data.TensorDataset(train_data, train_label)
32+
return train_dataset
33+
34+
35+
def random_dataloader(model, total_samples, hidden_dim, device, dtype):
36+
batch_size = model.train_micro_batch_size_per_gpu()
37+
train_dataset = random_dataset(total_samples, hidden_dim, device, dtype=dtype)
38+
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size)
39+
return train_loader
40+
41+
42+
def run_model(model, config_dict, hidden_dim, dtype, pin_memory, topk_ratio, update_interval, overlap_step, iteration):
43+
44+
model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict)
45+
46+
47+
data_loader = random_dataloader(model=model,
48+
total_samples=iteration,
49+
hidden_dim=hidden_dim,
50+
device=model.device,
51+
dtype=dtype)
52+
53+
time_step_list = []
54+
accumulation_step_time_list = []
55+
update_step_time_list = []
56+
57+
dist.barrier()
58+
for i, batch in enumerate(data_loader):
59+
step_start_time = time.time()
60+
loss = model(batch[0], batch[1])
61+
model.backward(loss)
62+
model.step()
63+
step_end_time = time.time()
64+
step_time = step_end_time - step_start_time
65+
if dist.get_rank() == 0:
66+
print(f"Step {i} time: {step_time*1000:.2f}ms")
67+
if i >= update_interval:
68+
time_step_list.append(step_time)
69+
if (i + 1) % update_interval == 0:
70+
update_step_time_list.append(step_time)
71+
else:
72+
accumulation_step_time_list.append(step_time)
73+
74+
if dist.get_rank() == 0:
75+
with open("zenflow_report.log", "a") as f:
76+
msg = f"{1 if pin_memory else 0}," \
77+
f"{topk_ratio}," \
78+
f"{update_interval}," \
79+
f"{overlap_step}," \
80+
f"{sum(accumulation_step_time_list) / len(accumulation_step_time_list):.2f}," \
81+
f"{sum(update_step_time_list) / len(update_step_time_list):.2f}"
82+
f.write(f"{msg}\n")
83+
print(f"[Summary] pin_memory={pin_memory} topk_ratio={topk_ratio} update_interval={update_interval} overlap_step={overlap_step} avg_accumulation_step={sum(accumulation_step_time_list) * 1000 / len(accumulation_step_time_list):.2f}ms avg_update_step={sum(update_step_time_list) * 1000 / len(update_step_time_list):.2f}ms")
84+
85+
model.destroy()
86+
87+
def main():
88+
parser = argparse.ArgumentParser()
89+
parser.add_argument("--nlayers", type=int, default=1)
90+
parser.add_argument("--hidden_dim", type=int, default=1024)
91+
parser.add_argument("--dtype", choices=['torch.bfloat16', 'torch.float16', 'torch.float32'], default='torch.bfloat16')
92+
parser.add_argument("--iteration", type=int, default=5)
93+
parser.add_argument("--local_rank", type=int, default=-1)
94+
95+
parser.add_argument("--pin_memory_opts", type=int, required=True)
96+
parser.add_argument("--topk_ratios", type=float, required=True)
97+
parser.add_argument("--update_intervals", type=int, required=True)
98+
parser.add_argument("--overlap_steps", type=int, required=True)
99+
100+
# Optional: explicitly receive master_port (though deepspeed handles it via env)
101+
parser.add_argument("--master_port", type=int, default=None)
102+
103+
args = parser.parse_args()
104+
dtype = eval(args.dtype)
105+
106+
107+
pin_memory = bool(args.pin_memory_opts)
108+
topk_ratio = args.topk_ratios
109+
update_interval = args.update_intervals
110+
overlap_step = bool(args.overlap_steps)
111+
total_iteration = args.iteration * update_interval
112+
113+
config_dict = {
114+
"train_micro_batch_size_per_gpu": 1,
115+
"optimizer": {
116+
"type": "Adam",
117+
"params": {
118+
"lr": 1e-6
119+
}
120+
},
121+
"zero_optimization": {
122+
"stage": 2,
123+
"offload_optimizer": {
124+
"device": "cpu",
125+
"pin_memory": pin_memory
126+
},
127+
"zenflow": {
128+
"topk_ratio": topk_ratio,
129+
"update_interval": update_interval,
130+
"full_warm_up_rounds": 0,
131+
"overlap_step": overlap_step
132+
},
133+
},
134+
"wall_clock_breakdown": True,
135+
"zero_allow_untested_optimizer": True
136+
}
137+
138+
if dtype == torch.float16:
139+
config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8}
140+
elif dtype == torch.bfloat16:
141+
config_dict["bf16"] = {"enabled": True}
142+
143+
model = SimpleModel(args.hidden_dim, nlayers=args.nlayers)
144+
run_model(model, config_dict, args.hidden_dim, dtype,
145+
pin_memory, topk_ratio, update_interval, overlap_step,
146+
total_iteration)
147+
148+
149+
if __name__ == "__main__":
150+
main()

0 commit comments

Comments
 (0)