17
17
import time
18
18
from collections .abc import Generator
19
19
20
- from transformers import Trainer , is_wandb_available
20
+ from transformers import Trainer
21
+ from transformers .integrations import is_mlflow_available , is_wandb_available
21
22
22
23
23
24
if is_wandb_available ():
24
25
import wandb
25
26
27
+ if is_mlflow_available ():
28
+ import mlflow
29
+
26
30
27
31
@contextlib .contextmanager
28
32
def profiling_context (trainer : Trainer , name : str ) -> Generator [None , None , None ]:
29
33
"""
30
- A context manager function for profiling a block of code. Results are logged to Weights & Biases if enabled.
34
+ A context manager function for profiling a block of code. Results are logged to Weights & Biases or MLflow
35
+ depending on the trainer's configuration.
31
36
32
37
Args:
33
38
trainer (`~transformers.Trainer`):
@@ -54,8 +59,12 @@ def some_method(self):
54
59
end_time = time .perf_counter ()
55
60
duration = end_time - start_time
56
61
62
+ profiling_metrics = {f"profiling/Time taken: { trainer .__class__ .__name__ } .{ name } " : duration }
57
63
if "wandb" in trainer .args .report_to and wandb .run is not None and trainer .accelerator .is_main_process :
58
- wandb .log ({f"profiling/Time taken: { trainer .__class__ .__name__ } .{ name } " : duration })
64
+ wandb .log (profiling_metrics )
65
+
66
+ if "mlflow" in trainer .args .report_to and mlflow .run is not None and trainer .accelerator .is_main_process :
67
+ mlflow .log_metrics (profiling_metrics , step = trainer .state .global_step )
59
68
60
69
61
70
def profiling_decorator (func : callable ) -> callable :
0 commit comments