Skip to content

🪪 Adds profiling decorators for GRPOTrainer #2889

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Feb 20, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions trl/extras/profiling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
import time

from transformers import is_wandb_available


if is_wandb_available():
import wandb


def profiling_decorator(func):
"""
Decorator to profile a function and log the time taken to execute it.
"""

@functools.wraps(func)
def wrapper(self, *args, **kwargs):
start_time = time.perf_counter()
result = func(self, *args, **kwargs)
end_time = time.perf_counter()
duration = end_time - start_time

if "wandb" in self.args.report_to and wandb.run is not None and self.accelerator.is_main_process:
wandb.log({f"profiling/Time taken: {self.__class__.__name__}.{func.__name__}": duration})
return result

return wrapper
5 changes: 5 additions & 0 deletions trl/trainer/grpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
from transformers.utils import is_peft_available

from ..data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template
from ..extras.profiling import profiling_decorator
from ..import_utils import is_vllm_available
from ..models import create_reference_model, prepare_deepspeed, unwrap_model_for_generation
from .callbacks import SyncRefModelCallback
Expand Down Expand Up @@ -476,6 +477,7 @@ def _get_eval_sampler(self, eval_dataset) -> Sampler:
return RepeatRandomSampler(eval_dataset, self.num_generations, seed=self.args.seed)

# Get the per-token log probabilities for the completions for the model and the reference model
@profiling_decorator
def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep):
# We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded
logits = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1).logits
Expand All @@ -487,6 +489,7 @@ def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep)
logits = logits[:, -logits_to_keep:]
return selective_log_softmax(logits, input_ids) # compute logprobs for the input tokens

@profiling_decorator
def _move_model_to_vllm(self):
with unwrap_model_for_generation(
self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
Expand Down Expand Up @@ -518,6 +521,7 @@ def _move_model_to_vllm(self):
if is_peft_model(unwrapped_model):
unwrapped_model.unmerge_adapter()

@profiling_decorator
def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]:
device = self.accelerator.device
prompts = [x["prompt"] for x in inputs]
Expand Down Expand Up @@ -699,6 +703,7 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s
"advantages": advantages,
}

@profiling_decorator
def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
if return_outputs:
raise ValueError("The GRPOTrainer does not support returning outputs")
Expand Down