-
Notifications
You must be signed in to change notification settings - Fork 1.1k
add timers and performance metrics #688
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
15 commits
Select commit
Hold shift + click to select a range
7a82c7a
add timers
awan-10 684d2aa
Move import time and precommit
molly-smith 7b28f23
Match generate time to HE eval
molly-smith c7b3ac0
add flops counter/printer.
awan-10 a069ec5
Modify/clean up tflops func
molly-smith 4a22383
improve logging.
awan-10 3b2f9e3
Merge branch 'master' into amawa/add-timers-flops
awan-10 41fe15f
fix name.
awan-10 1b25121
fix format.
awan-10 3a69cb2
undo debugging/
awan-10 f5a4dc4
take Lev's feedback.
awan-10 d5fbde1
fix performance calculations.
awan-10 94c249f
Merge branch 'master' into amawa/add-timers-flops
lekurile d418550
update flops calculation (#702)
yaozhewei 7582c33
remove unused timer.
awan-10 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,126 @@ | ||
| # Copyright (c) Microsoft Corporation. | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
| # DeepSpeed Team | ||
|
|
||
| import torch | ||
|
|
||
|
|
||
| # This function can be used to print throughput for Step 1 and 2 only | ||
| def print_throughput(hf_model, args, e2e_time, rank=0): | ||
| if rank <= 0: | ||
| hf_config = hf_model.config | ||
| num_layers = getattr(hf_config, "num_hidden_layers", | ||
| getattr(hf_config, "n_layer", None)) | ||
| hidden_size = getattr(hf_config, "hidden_size", | ||
| getattr(hf_config, "n_embd", None)) | ||
| vocab_size = getattr(hf_config, "vocab_size", None) | ||
| assert all( | ||
| (num_layers, hidden_size, vocab_size) | ||
| ), "Could not determine number of layers, hidden size, and vocab size of the model" | ||
|
|
||
| gpus_per_model = torch.distributed.get_world_size() | ||
| seq_length = args.max_seq_len | ||
| batch_size = args.per_device_train_batch_size | ||
| samples_per_second = batch_size / e2e_time | ||
| checkpoint_activations_factor = 4 if args.gradient_checkpointing else 3 | ||
| hf_model._num_params = sum([ | ||
| p.ds_numel if hasattr(p, "ds_tensor") else p.numel() | ||
| for p in hf_model.parameters() | ||
| ]) | ||
| params_in_billions = hf_model._num_params / (1e9) | ||
|
|
||
| # Megatron paper's formula to calculate training flops | ||
| train_flops_per_iteration = ( | ||
| 24 * checkpoint_activations_factor * batch_size * seq_length * | ||
| num_layers * | ||
| (hidden_size**2)) * (1.0 + (seq_length / (6.0 * hidden_size)) + | ||
| (vocab_size / | ||
| (16.0 * num_layers * hidden_size))) | ||
|
|
||
| train_tflops = train_flops_per_iteration / (e2e_time * gpus_per_model * | ||
| (10**12)) | ||
|
|
||
| param_string = f"{params_in_billions:.3f} B" if params_in_billions != 0 else "NA" | ||
| print( | ||
| f"Model Parameters: {param_string}, Latency: {e2e_time:.2f}s, TFLOPs: {train_tflops:.2f}, Samples/sec: {samples_per_second:.2f}, Time/seq {e2e_time/batch_size:.2f}s, Batch Size: {batch_size}, Sequence Length: {seq_length}" | ||
| ) | ||
|
|
||
|
|
||
| # Enhanced version of the function above that provides calculations and printing for Step 3 | ||
| def print_throughput_step3(hf_model, | ||
| args, | ||
| e2e_time, | ||
| gen_exp_time, | ||
| train_time, | ||
| rank=0): | ||
| if rank <= 0: | ||
| hf_config = hf_model.config | ||
| num_layers = getattr(hf_config, "num_hidden_layers", | ||
| getattr(hf_config, "n_layer", None)) | ||
| hidden_size = getattr(hf_config, "hidden_size", | ||
| getattr(hf_config, "n_embd", None)) | ||
| vocab_size = getattr(hf_config, "vocab_size", None) | ||
| assert all( | ||
| (num_layers, hidden_size, vocab_size) | ||
| ), "Could not determine number of layers, hidden size, and vocab size of the model" | ||
|
|
||
| gpus_per_model = torch.distributed.get_world_size() | ||
| seq_length = args.max_answer_seq_len + args.max_prompt_seq_len | ||
| batch_size = args.per_device_generation_batch_size * args.generation_batches * args.ppo_epochs * gpus_per_model * 1 if args.unsupervised_dataset_name is None else 2 | ||
| samples_per_second = batch_size / e2e_time | ||
| checkpoint_activations_factor = 4 if args.actor_gradient_checkpointing else 3 | ||
| hf_model._num_params = sum([ | ||
| p.ds_numel if hasattr(p, "ds_tensor") else p.numel() | ||
| for p in hf_model.parameters() | ||
| ]) | ||
| params_in_billions = hf_model._num_params / (1e9) | ||
|
|
||
| # Megatron paper's formula to calculate training flops | ||
| train_flops_per_iteration = ( | ||
| 24 * checkpoint_activations_factor * batch_size * seq_length * | ||
| num_layers * | ||
| (hidden_size**2)) * (1.0 + (seq_length / (6.0 * hidden_size)) + | ||
| (vocab_size / | ||
| (16.0 * num_layers * hidden_size))) | ||
|
|
||
| train_tflops = train_flops_per_iteration / (train_time * | ||
| gpus_per_model * (10**12)) | ||
|
|
||
| gen_bs = args.per_device_generation_batch_size * gpus_per_model | ||
|
|
||
| # Modified formula for calculating flops in forward pass only | ||
| gen_flops_per_iteration = ( | ||
| 24 * gen_bs * seq_length * num_layers * | ||
| (hidden_size**2)) * (1.0 + (seq_length / (6.0 * hidden_size)) + | ||
| (vocab_size / | ||
| (16.0 * num_layers * hidden_size))) | ||
|
|
||
| gen_tflops = gen_flops_per_iteration / (gen_exp_time * gpus_per_model * | ||
| (10**12)) | ||
|
|
||
| if hf_config.torch_dtype == "float16": | ||
| num_bytes = 2 | ||
| elif hf_config.torch_dtype == "float32": | ||
| num_bytes = 4 | ||
| else: | ||
| num_bytes = 1 | ||
|
|
||
| gen_bw = (hf_model._num_params * | ||
| (num_bytes / 1e9)) / gen_exp_time * args.max_answer_seq_len | ||
|
|
||
| total_flops_per_iteration = train_flops_per_iteration + gen_flops_per_iteration * args.generation_batches | ||
| total_tflops = total_flops_per_iteration / (e2e_time * gpus_per_model * | ||
| (10**12)) | ||
|
|
||
| print( | ||
| f"End-to-End => Latency: {e2e_time:.2f}s, TFLOPs: {total_tflops:.2f}, Samples/sec: {samples_per_second:.2f}, Time/seq {e2e_time/batch_size:.2f}s, Batch Size: {batch_size}, Sequence Length: {seq_length}" | ||
| ) | ||
| print( | ||
| f"Generation => Latency: {gen_exp_time:.2f}s, TFLOPs: {gen_tflops:.2f}, BW: {gen_bw:.2f} GB/sec" | ||
| ) | ||
| print( | ||
| f"Training => Latency: {train_time:.2f}s, TFLOPs: {train_tflops:.2f}" | ||
| ) | ||
| param_string = f"{params_in_billions:.3f} B" if params_in_billions != 0 else "NA" | ||
| print(f"Parameters => {param_string}") |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.