Skip to content

Commit 605284a

Browse files
Merge pull request #859 from Puiching-Memory/Profiler
将torch.profiler保存至本地
1 parent 05caca2 commit 605284a

File tree

2 files changed

+35
-2
lines changed

2 files changed

+35
-2
lines changed

swanlab/log/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,7 @@
66
@IDE: vscode
77
@Description:
88
日志记录模块,在设计上swanlog作为一个独立的模块被使用
9-
FIXME: shit code
109
"""
11-
1210
from .log import SwanLog
1311

1412
swanlog: SwanLog = SwanLog("swanlab")

swanlab/log/profiler.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
"""
2+
@author: Puiching-Memory
3+
@file: profiler.py
4+
@time: 2025/6/8 17:47
5+
@description: 保存模型 profiler 日志
6+
"""
7+
8+
9+
def trace_handler(save_dir: str):
10+
"""
11+
trace_handler 是一个回调函数,用于处理 torch.profiler 的 trace 信息,并将其保存到文件中
12+
13+
examples
14+
-------
15+
>>> activities = [torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA]
16+
>>> with torch.profiler.profile(activities=activities,on_trace_ready=trace_handler()) as p:
17+
"""
18+
from . import swanlog
19+
import os
20+
21+
assert os.path.isdir(save_dir), RuntimeError(
22+
"Run directory not found. Please ensure the run directory is properly set."
23+
)
24+
25+
def handler_fn(prof) -> None:
26+
saved_path = os.path.join(save_dir, 'trace.json')
27+
if os.path.exists(saved_path):
28+
swanlog.warning(f"{saved_path} already exists, will be overwritten")
29+
os.remove(f"{saved_path}")
30+
else:
31+
swanlog.info(f"torch.profiler trace is saved to {saved_path}")
32+
33+
prof.export_chrome_trace(f"{saved_path}")
34+
35+
return handler_fn

0 commit comments

Comments
 (0)