-
Notifications
You must be signed in to change notification settings - Fork 154
Feat/converter-tensorboard #548
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
8 commits
Select commit
Hold shift + click to select a range
b7dedd6
init
Zeyi-Lin d88ca67
total logic
Zeyi-Lin a459f07
Update tfb_converter.py
Zeyi-Lin 7a57487
add click
Zeyi-Lin 4e76902
Update main.py
Zeyi-Lin d6f4a08
Merge branch 'main' into feat/converter-tensorboard
Zeyi-Lin 34d52bf
use swl
Zeyi-Lin 00e1af7
use click argument
Zeyi-Lin File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .tfb import TFBConverter |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .tfb_converter import TFBConverter |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,116 @@ | ||
import os | ||
from PIL import Image | ||
import io | ||
import numpy as np | ||
|
||
try: | ||
import tensorflow as tf | ||
except ImportError as e: | ||
raise TypeError( | ||
"Tensorboard Converter requires tensorflow when process tfevents file. Install with 'pip install tensorflow'." | ||
) | ||
|
||
|
||
def get_tf_events_tags_type(tf_event_path: str) -> dict:
    """Map every tag found in a TFEvent file to the kind of data it carries.

    The type of a tag is decided by the first summary value of that tag that
    matches one of the supported kinds. Example result:
    ``{"tag1": "scalar", "tag2": "image", "tag3": "audio", "tag4": "text"}``.

    Args:
        tf_event_path: Path to a single tfevents file.

    Returns:
        A dict whose keys are tag names and whose values are one of
        ``"scalar"``, ``"image"``, ``"audio"`` or ``"text"``. Tags whose values
        match none of these kinds are left out.

    Raises:
        FileNotFoundError: If ``tf_event_path`` does not exist.
    """
    # Validate explicitly: an `assert` would be stripped under `python -O`.
    if not os.path.exists(tf_event_path):
        raise FileNotFoundError(f"TFEvent file does not exist: {tf_event_path}")

    # tag name -> detected type
    tags = {}

    for event in tf.compat.v1.train.summary_iterator(tf_event_path):
        for value in event.summary.value:
            # Only the first classifiable occurrence of a tag decides its type.
            if value.tag in tags:
                continue
            if value.HasField("simple_value"):
                tags[value.tag] = "scalar"
            elif value.HasField("image"):
                tags[value.tag] = "image"
            elif value.HasField("audio"):
                tags[value.tag] = "audio"
            elif value.HasField("tensor") and value.tensor.dtype == tf.string:
                # String tensors are treated as text summaries.
                tags[value.tag] = "text"

    return tags
|
||
|
||
def _decode_audio_samples(encoded: bytes):
    """Return the PCM samples of an encoded audio payload as a 1-D array."""
    import wave  # local import: only needed when audio summaries are present

    # WAV stores little-endian PCM; 8-bit samples are unsigned.
    sample_dtypes = {1: "u1", 2: "<i2", 4: "<i4"}
    try:
        with wave.open(io.BytesIO(encoded), "rb") as wav_file:
            dtype = sample_dtypes.get(wav_file.getsampwidth())
            if dtype is not None:
                frames = wav_file.readframes(wav_file.getnframes())
                return np.frombuffer(frames, dtype=dtype)
    except (wave.Error, EOFError):
        # Not a WAV container that `wave` can parse; use the fallback below.
        pass

    # Legacy behaviour: reinterpret the raw bytes as int16 samples.
    return np.frombuffer(encoded, dtype=np.int16)


def get_tf_events_tags_data(tf_event_path: str, tags: dict) -> dict:
    """Collect the logged records of every requested tag from a TFEvent file.

    Args:
        tf_event_path: Path to the tfevents file.
        tags: Mapping of tag name to tag type, e.g.
            ``{"loss": "scalar", "np_image": "image"}``.

    Returns:
        A dict keyed by tag. Each value is a list of
        ``(step, value, wall_time)`` tuples in file order, e.g.
        ``{"loss": [(0, 0.0, 1715839693), (1, 0.02, 1715839711)]}``.
        ``value`` is the scalar value for ``"scalar"`` tags, a PIL image for
        ``"image"`` tags, ``[samples, sample_rate]`` for ``"audio"`` tags and a
        ``str`` for ``"text"`` tags. ``wall_time`` is the event wall time
        truncated to an ``int``.
    """
    # One (initially empty) record list per requested tag.
    tag_data = {tag: [] for tag in tags}

    # Second pass over the file: this time the actual payloads are extracted.
    for event in tf.compat.v1.train.summary_iterator(tf_event_path):
        wall_time = int(event.wall_time)
        for value in event.summary.value:
            kind = tags.get(value.tag)
            if kind is None:
                # Tag was not requested.
                continue

            records = tag_data[value.tag]
            if kind == "scalar":
                records.append((event.step, value.simple_value, wall_time))
            elif kind == "image" and value.HasField("image"):
                image = Image.open(io.BytesIO(value.image.encoded_image_string))
                records.append((event.step, image, wall_time))
            elif kind == "audio" and value.HasField("audio"):
                audio = value.audio
                # The payload is an encoded container, so decode it rather
                # than treating the header bytes as samples.
                samples = _decode_audio_samples(audio.encoded_audio_string)
                records.append((event.step, [samples, int(audio.sample_rate)], wall_time))
            elif kind == "text" and value.HasField("tensor"):
                text = tf.make_ndarray(value.tensor).item().decode("utf-8")
                records.append((event.step, text, wall_time))

    return tag_data
|
||
|
||
def find_tfevents(logdir: str, depth: int = 3) -> dict:
    """Find all tfevents files under a directory, grouped by sub-directory.

    Args:
        logdir: Root log directory to search.
        depth: Maximum directory depth to search, where ``logdir`` itself is
            depth 0 and its direct children are depth 1. Defaults to 3.

    Returns:
        A dict whose keys are directory paths relative to ``logdir`` (``"."``
        for ``logdir`` itself) and whose values are lists of the tfevents file
        paths found directly inside that directory. Directories without any
        tfevents file do not appear.
    """
    tfevents_dict = {}

    for root, dirs, files in os.walk(logdir):
        rel_root = os.path.relpath(root, logdir)

        # Measure depth on the relative path so that a trailing separator or
        # other spelling differences in `logdir` cannot shift the limit.
        current_depth = 0 if rel_root == os.curdir else rel_root.count(os.sep) + 1

        if current_depth >= depth:
            # Prune in place so os.walk does not descend past the limit.
            del dirs[:]
        if current_depth > depth:
            # Only reachable for the root itself when `depth` is negative.
            continue

        matches = [os.path.join(root, name) for name in files if "tfevents" in name]
        if matches:
            tfevents_dict.setdefault(rel_root, []).extend(matches)

    return tfevents_dict
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
""" | ||
ISSUE: https://github.com/SwanHubX/SwanLab/issues/437 | ||
""" | ||
|
||
import os | ||
import swanlab | ||
from datetime import datetime | ||
from ._utils import find_tfevents, get_tf_events_tags_type, get_tf_events_tags_data | ||
from swanlab.log import swanlog as swl | ||
|
||
|
||
class TFBConverter:
    """Convert TensorBoard TFEvent logs into SwanLab experiments.

    Every TFEvent file found under ``convert_dir`` that contains at least one
    supported tag is replayed into its own SwanLab experiment.
    """

    def __init__(
        self,
        convert_dir: str,
        project: str = None,
        workspace: str = None,
        config: dict = None,
        cloud: bool = True,
        logdir: str = None,
        **kwargs,
    ):
        """Store the conversion settings.

        Args:
            convert_dir: Directory that is searched for TFEvent files.
            project: Project name passed to ``swanlab.init``. When ``None``, a
                timestamped ``Tensorboard-Conversion-...`` name is used.
            workspace: Workspace passed to ``swanlab.init``.
            config: Extra config entries merged into every created run.
            cloud: Passed through to ``swanlab.init`` as ``cloud``.
            logdir: Passed through to ``swanlab.init`` as ``logdir``.
            **kwargs: Accepted for call compatibility; currently unused.
        """
        self.convert_dir = convert_dir
        self.project = project
        self.workspace = workspace
        self.cloud = cloud
        self.config = config
        self.logdir = logdir

    def run(self, depth=3):
        """Convert every TFEvent file under ``convert_dir``.

        Args:
            depth: Maximum directory depth searched below ``convert_dir``.

        Returns:
            None. An error is logged and the method returns early when no
            TFEvent file is found.
        """
        swl.info("Start converting TFEvent files to SwanLab format...")
        timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

        # Locate all TFEvent files as {relative directory: [file paths]}.
        path_dict = find_tfevents(self.convert_dir, depth=depth)
        if path_dict:
            swl.info("Found TFEvent file path dictionary.")
        else:
            swl.error(f"No TFEvent file found in {self.convert_dir}, please check the path.")
            return

        # The project name is the same for every experiment of this conversion.
        project = f"Tensorboard-Conversion-{timestamp}" if self.project is None else self.project

        for directory, paths in path_dict.items():
            for path in paths:
                # Tag -> type, e.g.
                # {'training/loss': 'scalar', 'fake image': 'image',
                #  'fake audio': 'audio', 'fake text/text_summary': 'text'}
                type_by_tags = get_tf_events_tags_type(path)

                # No tags means the file recorded no supported metric; skip it.
                if not type_by_tags:
                    continue

                filename = os.path.basename(path)

                # One SwanLab experiment per TFEvent file.
                run = swanlab.init(
                    project=project,
                    experiment_name=f"{directory}/{filename}",
                    workspace=self.workspace,
                    config={"tfevent_path": path},
                    cloud=self.cloud,
                    logdir=self.logdir,
                )

                if self.config:
                    run.config.update(self.config)

                # Tag -> [(step, value, wall_time), ...], e.g.
                # {'training_loss': [(0, 0.0, 1715839693),
                #                    (1, 0.019999999552965164, 1715839711)]}
                data_by_tags = get_tf_events_tags_data(path, type_by_tags)

                times = []
                for tag, data in data_by_tags.items():
                    tag_type = type_by_tags[tag]
                    for step, value, wall_time in data:
                        times.append(wall_time)
                        if tag_type == "scalar":
                            swanlab.log({tag: value}, step=step)
                        elif tag_type == "image":
                            swanlab.log({tag: swanlab.Image(value)}, step=step)
                        elif tag_type == "audio":
                            swanlab.log({tag: swanlab.Audio(value[0], sample_rate=value[1])}, step=step)
                        elif tag_type == "text":
                            swanlab.log({tag: swanlab.Text(value)}, step=step)
                        # TODO: support more types as SwanLab grows.

                # Total run time is the span between the first and last
                # record. Guarded because max()/min() raise ValueError on an
                # empty sequence.
                if times:
                    runtime = max(times) - min(times)
                    swanlab.config.update({"RunTime": runtime})

                # Close the current experiment before starting the next one.
                run.finish()
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.