Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 59 additions & 2 deletions swanlab/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,5 +144,62 @@ def login(api_key: str, relogin: bool, **kwargs):
login_info = terminal_login(api_key)
print(FONT.swanlab("Login successfully. Hi, " + FONT.bold(FONT.default(login_info.username))) + "!")

if __name__ == "__main__":
cli()

# ---------------------------------- Convert command: import logs from other experiment tracking tools ----------------------------------
@cli.command()
@click.argument(
    "convert_dir",
    type=str,
)
@click.option(
    "--type",
    "-t",
    default="tensorboard",
    type=click.Choice(["tensorboard"]),
    help="The type of the experiment tracking tool you want to convert from.",
)
@click.option(
    "--project",
    "-p",
    default=None,
    type=str,
    help="SwanLab project name.",
)
@click.option(
    "--workspace",
    "-w",
    default=None,
    type=str,
    help="swanlab.init workspace parameter.",
)
@click.option(
    "--cloud",
    default=True,
    type=bool,
    help="swanlab.init cloud parameter.",
)
@click.option(
    "--logdir",
    "-l",
    type=str,
    help="The directory where the swanlab log files are stored.",
)
def convert(convert_dir: str, type: str, project: str, cloud: bool, workspace: str, logdir: str, **kwargs):
    """Convert the log files of other experiment tracking tools to SwanLab."""
    # NOTE: the docstring above is left untouched on purpose: click displays it as the
    # command's --help text, so it is user-facing output rather than a plain comment.
    #
    # Parameters (all supplied by the click decorators above):
    #   convert_dir: directory holding the log files to convert.
    #   type:        source tracking tool; click.Choice currently only admits "tensorboard".
    #                The name shadows the builtin, but it is bound to the "--type" option.
    #   project / workspace / cloud / logdir: passed through unchanged to the converter.
    #
    # Raises ValueError for an unsupported `type`. click.Choice already rejects such values
    # on the command line, so this only guards direct calls and future Choice additions.
    if type != "tensorboard":
        raise ValueError("The type of the experiment tracking tool you want to convert from is not supported.")

    # Imported inside the command so that the converter package is only loaded when a
    # conversion is actually requested.
    from swanlab.converter import TFBConverter

    tfb_converter = TFBConverter(
        convert_dir=convert_dir,
        project=project,
        workspace=workspace,
        cloud=cloud,
        logdir=logdir,
    )
    tfb_converter.run()


if __name__ == "__main__":
cli()
1 change: 1 addition & 0 deletions swanlab/converter/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .tfb import TFBConverter
1 change: 1 addition & 0 deletions swanlab/converter/tfb/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .tfb_converter import TFBConverter
116 changes: 116 additions & 0 deletions swanlab/converter/tfb/_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
import os
from PIL import Image
import io
import numpy as np

try:
import tensorflow as tf
except ImportError as e:
raise TypeError(
"Tensorboard Converter requires tensorflow when process tfevents file. Install with 'pip install tensorflow'."
)


def get_tf_events_tags_type(tf_event_path: str):
    """Map every recognised tag in a TFEvent file to the kind of data it carries.

    Example result:
        {"tag1": "scalar", "tag2": "image", "tag3": "audio", "tag4": "text"}

    Args:
        tf_event_path (str): Path to a single tfevents file.

    Returns:
        dict: Tag name -> "scalar", "image", "audio" or "text". A tag is classified by the
        first summary value that matches one of those kinds; tags that never match are
        left out of the result.
    """
    # The path must exist before it is handed to the event reader
    assert os.path.exists(tf_event_path), "TFEvent file does not exist"

    kind_by_tag = {}

    for record in tf.compat.v1.train.summary_iterator(tf_event_path):
        for entry in record.summary.value:
            name = entry.tag
            # A tag keeps the kind assigned on its first recognised occurrence
            if name in kind_by_tag:
                continue

            if entry.HasField("simple_value"):
                kind = "scalar"
            elif entry.HasField("image"):
                kind = "image"
            elif entry.HasField("audio"):
                kind = "audio"
            elif entry.HasField("tensor") and entry.tensor.dtype == tf.string:
                # Text summaries are stored as string tensors
                kind = "text"
            else:
                # Unsupported payload: leave the tag unclassified for now
                continue

            kind_by_tag[name] = kind

    return kind_by_tag


def get_tf_events_tags_data(tf_event_path: str, tags: dict):
    """Collect the recorded data points of the given tags from a TFEvent file.

    Args:
        tf_event_path (str): Path to the tfevents file.
        tags (dict): Tag name -> tag kind, e.g. {"loss": "scalar", "np_image": "image"}.

    Returns:
        dict: Tag name -> list of (step, value, wall_time) tuples in file order, e.g.
        {"loss": [(step1, value1, time1), (step2, value2, time2)]}. ``wall_time`` is the
        event's wall time truncated to an int. ``value`` depends on the tag kind:
        the simple value for "scalar", a PIL image for "image",
        ``[int16 numpy array, sample_rate]`` for "audio" and a decoded str for "text".
    """
    # Every requested tag gets an entry, even when the file holds no data for it
    collected = {name: [] for name in tags}

    # Second pass over the file: this time the payloads themselves are extracted
    for record in tf.compat.v1.train.summary_iterator(tf_event_path):
        stamp = int(record.wall_time)
        for entry in record.summary.value:
            name = entry.tag
            if name not in collected:
                continue

            kind = tags[name]
            if kind == "scalar":
                payload = entry.simple_value
            elif kind == "image" and entry.HasField("image"):
                payload = Image.open(io.BytesIO(entry.image.encoded_image_string))
            elif kind == "audio" and entry.HasField("audio"):
                clip = entry.audio
                samples = np.frombuffer(clip.encoded_audio_string, dtype=np.int16)
                payload = [samples, int(clip.sample_rate)]
            elif kind == "text" and entry.HasField("tensor"):
                payload = tf.make_ndarray(entry.tensor).item().decode("utf-8")
            else:
                # The value does not carry the field expected for this tag's kind
                continue

            collected[name].append((record.step, payload, stamp))

    return collected


def find_tfevents(logdir: str, depth: int = 3):
    """Find all tfevents files under a directory, up to a maximum nesting depth.

    Args:
        logdir (str): Root log directory to search.
        depth (int, optional): Maximum directory depth to search, where ``logdir`` itself
            is depth 0. Defaults to 3.

    Returns:
        dict: Directory path relative to ``logdir`` ("." for ``logdir`` itself) -> list of
        paths of the tfevents files located directly in that directory. Directories
        without tfevents files are not included; a missing ``logdir`` yields ``{}``.
    """
    tfevents_dict = {}

    for root, dirs, files in os.walk(logdir):
        relative_dir = os.path.relpath(root, logdir)
        # Measure depth on the path relative to logdir, so the spelling of logdir itself
        # (e.g. a trailing separator) cannot shift the count.
        if relative_dir == os.curdir:
            current_depth = 0
        else:
            current_depth = relative_dir.count(os.sep) + 1

        if current_depth >= depth:
            # Children would exceed the limit: prune in place so os.walk skips them
            dirs[:] = []
        if current_depth > depth:
            continue

        matches = [os.path.join(root, name) for name in files if "tfevents" in name]
        if matches:
            # The relative directory is used as the dictionary key
            tfevents_dict.setdefault(relative_dir, []).extend(matches)

    return tfevents_dict
106 changes: 106 additions & 0 deletions swanlab/converter/tfb/tfb_converter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
"""
ISSUE: https://github.com/SwanHubX/SwanLab/issues/437
"""

import os
import swanlab
from datetime import datetime
from ._utils import find_tfevents, get_tf_events_tags_type, get_tf_events_tags_data
from swanlab.log import swanlog as swl


class TFBConverter:
def __init__(
self,
convert_dir: str,
project: str = None,
workspace: str = None,
config: dict = None,
cloud: bool = True,
logdir: str = None,
**kwargs,
):
self.convert_dir = convert_dir
self.project = project
self.workspace = workspace
self.cloud = cloud
self.config = config
self.logdir = logdir

def run(self, depth=3):
swl.info("Start converting TFEvent files to SwanLab format...")
timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

# 找到所有TFEvent文件, 生成一个路径字典
path_dict = find_tfevents(self.convert_dir, depth=depth)
if path_dict:
swl.info("Found TFEvent file path dictionary.")
else:
swl.error(f"No TFEvent file found in {self.convert_dir}, please check the path.")
return

for dir, paths in path_dict.items():
for path in paths:
filename = os.path.basename(path)

"""
获取所有的tag与其对应的类型, example:
type_by_tags = {'training/loss': 'scalar', 'fake image': 'image', 'fake audio': 'audio', 'fake text/text_summary': 'text'}
"""
type_by_tags = get_tf_events_tags_type(path)

# 如果有tag(即该日志文件有记录指标,而非空文件)
if type_by_tags:
# 初始化一个SwanLab实验
run = swanlab.init(
project=(f"Tensorboard-Conversion-{timestamp}" if self.project is None else self.project),
experiment_name=f"{dir}/{filename}",
workspace=self.workspace,
config={"tfevent_path": path},
cloud=self.cloud,
logdir=self.logdir,
)

if self.config:
run.config.update(self.config)

"""
根据tag提取数据, 格式为{tag: [(step, value, wall_time), ...]}, example:
data_by_tags = {
'training_loss': [
(0, 0.0, 1715839693),
(1, 0.019999999552965164, 1715839711),
(2, 0.03999999910593033, 1715839717)
],
...
}
"""
data_by_tags = get_tf_events_tags_data(path, type_by_tags)

times = []
# 遍历数据
if data_by_tags:
# 打印并转换数据到SwanLab
for tag, data in data_by_tags.items():
for step, value, time in data:
times.append(time)
# 如果是标量
if type_by_tags[tag] == "scalar":
swanlab.log({tag: value}, step=step)
# 如果是图片
elif type_by_tags[tag] == "image":
swanlab.log({tag: swanlab.Image(value)}, step=step)
# 如果是音频
elif type_by_tags[tag] == "audio":
swanlab.log({tag: swanlab.Audio(value[0], sample_rate=value[1])}, step=step)
# 如果是文本
elif type_by_tags[tag] == "text":
swanlab.log({tag: swanlab.Text(value)}, step=step)
# TODO: 随着SwanLab的发展,支持转换更多类型

# 计算完整的运行时间
runtime = max(times) - min(times)
swanlab.config.update({"RunTime": runtime})

# 结束当前实验
run.finish()