|
| 1 | +# Copyright (c) Microsoft Corporation. |
| 2 | +# SPDX-License-Identifier: Apache-2.0 |
| 3 | + |
| 4 | +# DeepSpeed Team |
| 5 | + |
| 6 | +import argparse |
| 7 | +import torch |
| 8 | +import deepspeed.comm as dist |
| 9 | +import time |
| 10 | + |
| 11 | +import deepspeed |
| 12 | + |
| 13 | +class SimpleModel(torch.nn.Module): |
| 14 | + |
| 15 | + def __init__(self, hidden_dim, empty_grad=False, nlayers=1): |
| 16 | + super(SimpleModel, self).__init__() |
| 17 | + self.linears = torch.nn.ModuleList([torch.nn.Linear(hidden_dim, hidden_dim) for _ in range(nlayers)]) |
| 18 | + if empty_grad: |
| 19 | + self.linear2 = torch.nn.Linear(hidden_dim, hidden_dim) |
| 20 | + self.cross_entropy_loss = torch.nn.CrossEntropyLoss() |
| 21 | + |
| 22 | + def forward(self, x, y): |
| 23 | + for l in self.linears: |
| 24 | + x = l(x) |
| 25 | + return self.cross_entropy_loss(x, y) |
| 26 | + |
| 27 | + |
| 28 | +def random_dataset(total_samples, hidden_dim, device, dtype): |
| 29 | + train_data = torch.randn(total_samples, hidden_dim, device=device, dtype=dtype) |
| 30 | + train_label = torch.empty(total_samples, dtype=torch.long, device=device).random_(hidden_dim) |
| 31 | + train_dataset = torch.utils.data.TensorDataset(train_data, train_label) |
| 32 | + return train_dataset |
| 33 | + |
| 34 | + |
| 35 | +def random_dataloader(model, total_samples, hidden_dim, device, dtype): |
| 36 | + batch_size = model.train_micro_batch_size_per_gpu() |
| 37 | + train_dataset = random_dataset(total_samples, hidden_dim, device, dtype=dtype) |
| 38 | + train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size) |
| 39 | + return train_loader |
| 40 | + |
| 41 | + |
| 42 | +def run_model(model, config_dict, hidden_dim, dtype, pin_memory, topk_ratio, update_interval, overlap_step, iteration): |
| 43 | + |
| 44 | + model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict) |
| 45 | + |
| 46 | + |
| 47 | + data_loader = random_dataloader(model=model, |
| 48 | + total_samples=iteration, |
| 49 | + hidden_dim=hidden_dim, |
| 50 | + device=model.device, |
| 51 | + dtype=dtype) |
| 52 | + |
| 53 | + time_step_list = [] |
| 54 | + accumulation_step_time_list = [] |
| 55 | + update_step_time_list = [] |
| 56 | + |
| 57 | + dist.barrier() |
| 58 | + for i, batch in enumerate(data_loader): |
| 59 | + step_start_time = time.time() |
| 60 | + loss = model(batch[0], batch[1]) |
| 61 | + model.backward(loss) |
| 62 | + model.step() |
| 63 | + step_end_time = time.time() |
| 64 | + step_time = step_end_time - step_start_time |
| 65 | + if dist.get_rank() == 0: |
| 66 | + print(f"Step {i} time: {step_time*1000:.2f}ms") |
| 67 | + if i >= update_interval: |
| 68 | + time_step_list.append(step_time) |
| 69 | + if (i + 1) % update_interval == 0: |
| 70 | + update_step_time_list.append(step_time) |
| 71 | + else: |
| 72 | + accumulation_step_time_list.append(step_time) |
| 73 | + |
| 74 | + if dist.get_rank() == 0: |
| 75 | + with open("zenflow_report.log", "a") as f: |
| 76 | + msg = f"{1 if pin_memory else 0}," \ |
| 77 | + f"{topk_ratio}," \ |
| 78 | + f"{update_interval}," \ |
| 79 | + f"{overlap_step}," \ |
| 80 | + f"{sum(accumulation_step_time_list) / len(accumulation_step_time_list):.2f}," \ |
| 81 | + f"{sum(update_step_time_list) / len(update_step_time_list):.2f}" |
| 82 | + f.write(f"{msg}\n") |
| 83 | + print(f"[Summary] pin_memory={pin_memory} topk_ratio={topk_ratio} update_interval={update_interval} overlap_step={overlap_step} avg_accumulation_step={sum(accumulation_step_time_list) * 1000 / len(accumulation_step_time_list):.2f}ms avg_update_step={sum(update_step_time_list) * 1000 / len(update_step_time_list):.2f}ms") |
| 84 | + |
| 85 | + model.destroy() |
| 86 | + |
| 87 | +def main(): |
| 88 | + parser = argparse.ArgumentParser() |
| 89 | + parser.add_argument("--nlayers", type=int, default=1) |
| 90 | + parser.add_argument("--hidden_dim", type=int, default=1024) |
| 91 | + parser.add_argument("--dtype", choices=['torch.bfloat16', 'torch.float16', 'torch.float32'], default='torch.bfloat16') |
| 92 | + parser.add_argument("--iteration", type=int, default=5) |
| 93 | + parser.add_argument("--local_rank", type=int, default=-1) |
| 94 | + |
| 95 | + parser.add_argument("--pin_memory_opts", type=int, required=True) |
| 96 | + parser.add_argument("--topk_ratios", type=float, required=True) |
| 97 | + parser.add_argument("--update_intervals", type=int, required=True) |
| 98 | + parser.add_argument("--overlap_steps", type=int, required=True) |
| 99 | + |
| 100 | + # Optional: explicitly receive master_port (though deepspeed handles it via env) |
| 101 | + parser.add_argument("--master_port", type=int, default=None) |
| 102 | + |
| 103 | + args = parser.parse_args() |
| 104 | + dtype = eval(args.dtype) |
| 105 | + |
| 106 | + |
| 107 | + pin_memory = bool(args.pin_memory_opts) |
| 108 | + topk_ratio = args.topk_ratios |
| 109 | + update_interval = args.update_intervals |
| 110 | + overlap_step = bool(args.overlap_steps) |
| 111 | + total_iteration = args.iteration * update_interval |
| 112 | + |
| 113 | + config_dict = { |
| 114 | + "train_micro_batch_size_per_gpu": 1, |
| 115 | + "optimizer": { |
| 116 | + "type": "Adam", |
| 117 | + "params": { |
| 118 | + "lr": 1e-6 |
| 119 | + } |
| 120 | + }, |
| 121 | + "zero_optimization": { |
| 122 | + "stage": 2, |
| 123 | + "offload_optimizer": { |
| 124 | + "device": "cpu", |
| 125 | + "pin_memory": pin_memory |
| 126 | + }, |
| 127 | + "zenflow": { |
| 128 | + "topk_ratio": topk_ratio, |
| 129 | + "update_interval": update_interval, |
| 130 | + "full_warm_up_rounds": 0, |
| 131 | + "overlap_step": overlap_step |
| 132 | + }, |
| 133 | + }, |
| 134 | + "wall_clock_breakdown": True, |
| 135 | + "zero_allow_untested_optimizer": True |
| 136 | + } |
| 137 | + |
| 138 | + if dtype == torch.float16: |
| 139 | + config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8} |
| 140 | + elif dtype == torch.bfloat16: |
| 141 | + config_dict["bf16"] = {"enabled": True} |
| 142 | + |
| 143 | + model = SimpleModel(args.hidden_dim, nlayers=args.nlayers) |
| 144 | + run_model(model, config_dict, args.hidden_dim, dtype, |
| 145 | + pin_memory, topk_ratio, update_interval, overlap_step, |
| 146 | + total_iteration) |
| 147 | + |
| 148 | + |
| 149 | +if __name__ == "__main__": |
| 150 | + main() |
0 commit comments