@@ -29,6 +29,7 @@
from utils.ds_utils import get_train_ds_config
from utils.module.lora import convert_linear_layer_to_lora, convert_lora_to_linear_layer, only_optimize_lora_parameters, make_model_gradient_checkpointing_compatible
from utils.model.model_utils import create_hf_model
from utils.perf import print_throughput


def parse_args():
@@ -321,7 +322,9 @@ def evaluation(model, eval_dataloader):
f"Beginning of Epoch {epoch+1}/{args.num_train_epochs}, Total Micro Batches {len(train_dataloader)}",
args.global_rank)
model.train()
import time
for step, batch in enumerate(train_dataloader):
start = time.time()
batch = to_device(batch, device)
outputs = model(**batch, use_cache=False)
loss = outputs.loss
@@ -331,6 +334,10 @@ def evaluation(model, eval_dataloader):
                )
            model.backward(loss)
            model.step()
            end = time.time()
            # end - start spans one full micro-step on this rank:
            # forward, backward and optimizer step
            if torch.distributed.get_rank() == 0:
                print_throughput(model.model, args, end - start,
                                 args.global_rank)

        # Evaluate perplexity on the validation set.
        print_rank_0(
@@ -19,6 +19,7 @@
import argparse
import os
import random
import time
import torch
from torch.utils.data import DataLoader, RandomSampler
from torch.utils.data.distributed import DistributedSampler
@@ -42,6 +43,7 @@
from utils.data.data_utils import create_prompt_dataset, MiniDataset, DataCollatorRLHF, get_unsupervised_data
from utils.utils import print_rank_0, to_device, save_hf_format, set_random_seed, get_all_reduce_mean, moving_average, save_zero_three_model, load_hf_tokenizer
from utils.module.lora import convert_lora_to_linear_layer
from utils.perf import print_throughput_step3

writer = None

@@ -478,13 +480,9 @@ def main():
            args.global_rank)
        for step, (batch_prompt, batch_unsupervised) in enumerate(
                zip(prompt_train_dataloader, unsupervised_train_dataloader)):

            batch_prompt = to_device(batch_prompt, device)
            if batch_unsupervised is not None:
                batch_unsupervised = to_device(batch_unsupervised, device)
                unsup_dataset = unsup_mini_dataset.add(batch_unsupervised)
            else:
                unsup_dataset = unsup_mini_dataset.add(
                    [[None] * args.per_device_generation_batch_size])

            # prompts = batch_prompt['prompt']
            # length = prompts.size(-1)
            # if length > args.max_prompt_seq_len:
@@ -494,6 +492,15 @@ def main():
            out = trainer.generate_experience(batch_prompt['prompt'],
                                              batch_prompt['prompt_att_mask'],
                                              step)

            # everything from here to the end of the PPO updates is counted
            # as training time
            training_start = time.time()
            if batch_unsupervised is not None:
                batch_unsupervised = to_device(batch_unsupervised, device)
                unsup_dataset = unsup_mini_dataset.add(batch_unsupervised)
            else:
                unsup_dataset = unsup_mini_dataset.add(
                    [[None] * args.per_device_generation_batch_size])

            exp_dataset = exp_mini_dataset.add(out)

            if exp_dataset is not None:
@@ -526,16 +533,24 @@ def main():
                    random.shuffle(exp_dataset)
                    random.shuffle(unsup_dataset)

                end = time.time()
                training_time = end - training_start
                # an approximation: generation time is added back in, but
                # e.g. the reward-model forward pass is not included
                e2e_time = training_time + trainer.generate_time * args.generation_batches

                print_rank_0(
                    f'epoch: {epoch}|step: {step}|ppo_ep: {ppo_ep+1}|act_loss: {actor_loss_sum/inner_iter}|cri_loss: {critic_loss_sum/inner_iter}|unsuper_loss: {unsup_loss_sum/inner_iter}',
                    f'Epoch: {epoch} | Step: {step} | PPO Epoch: {ppo_ep+1} | Actor Loss: {actor_loss_sum/inner_iter} | Critic Loss: {critic_loss_sum/inner_iter} | Unsupervised Loss: {unsup_loss_sum/inner_iter}',
                    args.global_rank)
                print_throughput_step3(rlhf_engine.actor.model, args, e2e_time,
                                       trainer.generate_time, training_time,
                                       args.global_rank)
                average_reward = get_all_reduce_mean(average_reward).item()
                print_rank_0(
                    f"average reward score: {average_reward/inner_iter}",
                    f"Average reward score: {average_reward/inner_iter}",
                    args.global_rank)
                print_rank_0(
                    "-------------------------------------------------------------------------------------",
                    args.global_rank)

                if args.enable_tensorboard and torch.distributed.get_rank(
                ) == 0:
                    writer.add_scalar('reward',
@@ -6,6 +6,7 @@
import torch.nn.functional as F
import sys
import os
import time
import deepspeed
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus

@@ -65,6 +66,7 @@ def __init__(self, rlhf_engine, args):
        self.cliprange_value = 0.2
        self.gamma = 1.0
        self.lam = 0.95
        # wall-clock seconds spent in the most recent _generate_sequence
        # call; read by print_throughput_step3 in the step-3 main.py
        self.generate_time = 0.0

    def _generate_sequence(self, prompts, mask, step):

@@ -116,7 +118,9 @@ def _generate_sequence(self, prompts, mask, step):

    def generate_experience(self, prompts, mask, step):
        self.eval()
        # time only the sequence generation; the reward/critic/reference
        # forward passes below are not included in generate_time
        generate_start = time.time()
        seq = self._generate_sequence(prompts, mask, step)
        generate_end = time.time()
        self.train()

        pad_token_id = self.tokenizer.pad_token_id
@@ -134,6 +138,8 @@ def generate_experience(self, prompts, mask, step):
        logits = output.logits
        logits_ref = output_ref.logits

        self.generate_time = generate_end - generate_start

        return {
            'prompts': prompts,
            'logprobs': gather_log_probs(logits[:, :-1, :], seq[:, 1:]),
@@ -58,7 +58,6 @@ deepspeed --master_port 12346 main.py \
--critic_zero_stage $CRITIC_ZERO_STAGE \
--enable_ema \
--output_dir $OUTPUT \
--print_answers \
--enable_tensorboard \
--tensorboard_path $OUTPUT \
&> $OUTPUT/training.log
applications/DeepSpeed-Chat/training/utils/perf.py (new file: 126 additions, 0 deletions)
@@ -0,0 +1,126 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch


# This function can be used to print throughput for Steps 1 and 2 only
def print_throughput(hf_model, args, e2e_time, rank=0):
    if rank <= 0:
        hf_config = hf_model.config
        num_layers = getattr(hf_config, "num_hidden_layers",
                             getattr(hf_config, "n_layer", None))
        hidden_size = getattr(hf_config, "hidden_size",
                              getattr(hf_config, "n_embd", None))
        vocab_size = getattr(hf_config, "vocab_size", None)
        assert all(
            (num_layers, hidden_size, vocab_size)
        ), "Could not determine number of layers, hidden size, and vocab size of the model"

        gpus_per_model = torch.distributed.get_world_size()
        seq_length = args.max_seq_len
        batch_size = args.per_device_train_batch_size
        samples_per_second = batch_size / e2e_time
        checkpoint_activations_factor = 4 if args.gradient_checkpointing else 3
        hf_model._num_params = sum([
            p.ds_numel if hasattr(p, "ds_tensor") else p.numel()
            for p in hf_model.parameters()
        ])
        params_in_billions = hf_model._num_params / (1e9)

        # Megatron paper's formula to calculate training flops
        train_flops_per_iteration = (
            24 * checkpoint_activations_factor * batch_size * seq_length *
            num_layers * (hidden_size**2)) * (
                1.0 + (seq_length / (6.0 * hidden_size)) +
                (vocab_size / (16.0 * num_layers * hidden_size)))
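        # In the formula, 24 * B * s * l * h^2 approximates the forward-pass
        # matmul flops of the transformer stack; the factor of 3 (or 4 with
        # activation checkpointing) accounts for the backward pass (~2x the
        # forward) plus optional recomputation, and the bracketed terms add
        # the attention (s / 6h) and vocabulary-projection (V / 16lh) costs.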

        train_tflops = train_flops_per_iteration / (e2e_time * gpus_per_model *
                                                    (10**12))

        param_string = f"{params_in_billions:.3f} B" if params_in_billions != 0 else "NA"
        print(
            f"Model Parameters: {param_string}, Latency: {e2e_time:.2f}s, TFLOPs: {train_tflops:.2f}, Samples/sec: {samples_per_second:.2f}, Time/seq {e2e_time/batch_size:.2f}s, Batch Size: {batch_size}, Sequence Length: {seq_length}"
        )
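        # Rough sanity check with hypothetical OPT-1.3b-like dimensions
        # (num_layers=24, hidden_size=2048, vocab_size=50272, seq_length=512,
        # batch_size=8, no checkpointing): 24 * 3 * 8 * 512 * 24 * 2048**2 *
        # (1 + 512/(6*2048) + 50272/(16*24*2048)) ~= 3.3e13 flops, i.e.
        # ~4.1 TFLOPs/GPU at one iteration per second on 8 GPUs.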


# Enhanced version of the function above that provides calculations and printing for Step 3
def print_throughput_step3(hf_model,
                           args,
                           e2e_time,
                           gen_exp_time,
                           train_time,
                           rank=0):
    if rank <= 0:
        hf_config = hf_model.config
        num_layers = getattr(hf_config, "num_hidden_layers",
                             getattr(hf_config, "n_layer", None))
        hidden_size = getattr(hf_config, "hidden_size",
                              getattr(hf_config, "n_embd", None))
        vocab_size = getattr(hf_config, "vocab_size", None)
        assert all(
            (num_layers, hidden_size, vocab_size)
        ), "Could not determine number of layers, hidden size, and vocab size of the model"

        gpus_per_model = torch.distributed.get_world_size()
        seq_length = args.max_answer_seq_len + args.max_prompt_seq_len
        # the trained batch doubles when an unsupervised dataset is mixed in;
        # the conditional is parenthesized so the 1-vs-2 multiplier scales
        # the whole product rather than replacing it
        batch_size = args.per_device_generation_batch_size * args.generation_batches * \
            args.ppo_epochs * gpus_per_model * (
                1 if args.unsupervised_dataset_name is None else 2)
        samples_per_second = batch_size / e2e_time
        checkpoint_activations_factor = 4 if args.actor_gradient_checkpointing else 3
        hf_model._num_params = sum([
            p.ds_numel if hasattr(p, "ds_tensor") else p.numel()
            for p in hf_model.parameters()
        ])
        params_in_billions = hf_model._num_params / (1e9)

        # Megatron paper's formula to calculate training flops
        train_flops_per_iteration = (
            24 * checkpoint_activations_factor * batch_size * seq_length *
            num_layers * (hidden_size**2)) * (
                1.0 + (seq_length / (6.0 * hidden_size)) +
                (vocab_size / (16.0 * num_layers * hidden_size)))

        train_tflops = train_flops_per_iteration / (train_time *
                                                    gpus_per_model * (10**12))

        gen_bs = args.per_device_generation_batch_size * gpus_per_model
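        # gen_bs is the global generation batch size across all ranks; it
        # feeds the forward-only flops estimate below.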

        # Modified formula for calculating flops in forward pass only
        gen_flops_per_iteration = (
            24 * gen_bs * seq_length * num_layers *
            (hidden_size**2)) * (1.0 + (seq_length / (6.0 * hidden_size)) +
                                 (vocab_size /
                                  (16.0 * num_layers * hidden_size)))

        gen_tflops = gen_flops_per_iteration / (gen_exp_time * gpus_per_model *
                                                (10**12))
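        # This treats generation as a single full-sequence forward pass,
        # which roughly matches the total matmul work of incremental decoding
        # with a KV cache; gen_tflops is therefore an estimate, not a
        # measured value.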

        # hf_config.torch_dtype may be a torch.dtype or a plain string
        # depending on how the config was created, so normalize it before
        # comparing; fall back to 1 byte/param if the dtype is unknown
        model_dtype = str(hf_config.torch_dtype).replace("torch.", "")
        if model_dtype == "float16":
            num_bytes = 2
        elif model_dtype == "float32":
            num_bytes = 4
        else:
            num_bytes = 1

        # every decoded token streams all model weights from memory once, so
        # traffic ~= params * bytes_per_param * generated_tokens; dividing by
        # the generation time gives an effective bandwidth in GB/s
        gen_bw = (hf_model._num_params *
                  (num_bytes / 1e9)) / gen_exp_time * args.max_answer_seq_len

        total_flops_per_iteration = train_flops_per_iteration + gen_flops_per_iteration * args.generation_batches
        total_tflops = total_flops_per_iteration / (e2e_time * gpus_per_model *
                                                    (10**12))

        print(
            f"End-to-End => Latency: {e2e_time:.2f}s, TFLOPs: {total_tflops:.2f}, Samples/sec: {samples_per_second:.2f}, Time/seq {e2e_time/batch_size:.2f}s, Batch Size: {batch_size}, Sequence Length: {seq_length}"
        )
        print(
            f"Generation => Latency: {gen_exp_time:.2f}s, TFLOPs: {gen_tflops:.2f}, BW: {gen_bw:.2f} GB/sec"
        )
        print(
            f"Training => Latency: {train_time:.2f}s, TFLOPs: {train_tflops:.2f}"
        )
        param_string = f"{params_in_billions:.3f} B" if params_in_billions != 0 else "NA"
        print(f"Parameters => {param_string}")