Skip to content

Commit 280d353

Browse files
🌊 Add MLflow metrics in profiling context (#3400)
Co-authored-by: Quentin Gallouédec <[email protected]>
1 parent 13fa840 commit 280d353

File tree

1 file changed

+12
-3
lines changed

1 file changed

+12
-3
lines changed

trl/extras/profiling.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,17 +17,22 @@
1717
import time
1818
from collections.abc import Generator
1919

20-
from transformers import Trainer, is_wandb_available
20+
from transformers import Trainer
21+
from transformers.integrations import is_mlflow_available, is_wandb_available
2122

2223

2324
if is_wandb_available():
2425
import wandb
2526

27+
if is_mlflow_available():
28+
import mlflow
29+
2630

2731
@contextlib.contextmanager
2832
def profiling_context(trainer: Trainer, name: str) -> Generator[None, None, None]:
2933
"""
30-
A context manager function for profiling a block of code. Results are logged to Weights & Biases if enabled.
34+
A context manager function for profiling a block of code. Results are logged to Weights & Biases or MLflow
35+
depending on the trainer's configuration.
3136
3237
Args:
3338
trainer (`~transformers.Trainer`):
@@ -54,8 +59,12 @@ def some_method(self):
5459
end_time = time.perf_counter()
5560
duration = end_time - start_time
5661

62+
profiling_metrics = {f"profiling/Time taken: {trainer.__class__.__name__}.{name}": duration}
5763
if "wandb" in trainer.args.report_to and wandb.run is not None and trainer.accelerator.is_main_process:
58-
wandb.log({f"profiling/Time taken: {trainer.__class__.__name__}.{name}": duration})
64+
wandb.log(profiling_metrics)
65+
66+
if "mlflow" in trainer.args.report_to and mlflow.run is not None and trainer.accelerator.is_main_process:
67+
mlflow.log_metrics(profiling_metrics, step=trainer.state.global_step)
5968

6069

6170
def profiling_decorator(func: callable) -> callable:

0 commit comments

Comments
 (0)