Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 17 additions & 2 deletions swanlab/converter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,20 @@
@Description:
转换部分,兼容其他可视化工具,转换为swanlab格式
"""
from .tfb import TFBConverter
from .wb import WandbConverter


# 使用延迟导入的方式
def __getattr__(name):
if name == "TFBConverter":
from .tfb import TFBConverter

return TFBConverter
if name == "WandbConverter":
from .wb import WandbConverter

return WandbConverter

raise AttributeError(f"module 'convert' has no attribute '{name}'")


__all__ = ["TFBConverter", "WandbConverter"]
25 changes: 16 additions & 9 deletions swanlab/converter/tfb/_utils.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,15 @@
import os
from PIL import Image
import io
import numpy as np

try:
import tensorflow as tf
except ImportError as e:
raise TypeError(
"Tensorboard Converter requires tensorflow when process tfevents file. Install with 'pip install tensorflow'."
)

def get_tf_events_tags_type(tf_event_path: str):
try:
import tensorflow as tf
except ImportError as e:
raise TypeError(
"Tensorboard Converter requires tensorflow when process tfevents file. Install with 'pip install tensorflow'."
)

def get_tf_events_tags_type(tf_event_path: str):
"""获取TFEvent文件中所有tag的类型,并返回一个字典
比如{"tag1": "scalar", "tag2": "image", "tag3": "audio", "tag4": "text"}

Expand Down Expand Up @@ -44,6 +42,15 @@ def get_tf_events_tags_type(tf_event_path: str):


def get_tf_events_tags_data(tf_event_path: str, tags: dict):
try:
import tensorflow as tf
except ImportError as e:
raise TypeError(
"Tensorboard Converter requires tensorflow when process tensorboard tfevents file. Install with 'pip install tensorflow'."
)
import numpy as np
from PIL import Image

"""获取TFEvent文件中所有tag的数据,并返回一个字典
比如{"tag1": [(step1, value1), (step2, value2)], "tag2": [(step1, value1), (step2, value2)]}

Expand Down
7 changes: 6 additions & 1 deletion swanlab/converter/wb/wb_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,12 @@ def parse_wandb_logs(self, wb_project: str, wb_entity: str, wb_run_id: str = Non
swanlab_run.config.update(wb_config)
swanlab_run.config.update(wb_run.config)

keys = [key for key in wb_run.history(stream="default").keys() if not key.startswith("_")]
# Get the first history record to extract available keys
history = wb_run.history(stream="default")
if len(history) > 0:
keys = [key for key in history[0].keys() if not key.startswith("_")]
else:
keys = []

# 记录标量指标
for record in wb_run.scan_history():
Expand Down