Skip to content
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .prettierignore
Original file line number Diff line number Diff line change
Expand Up @@ -59,13 +59,13 @@ components/style/color/*.less

.huskyrc

!README.md
!README_zh-hans.md

doc/**

**.md

!README.md
!README_zh-hans.md

.github/**

swanlab/
Expand Down
11 changes: 11 additions & 0 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,11 @@
"editor.formatOnSave": true
},
/** TODO tree 配置 */
"todo-tree.general.tags": [
"TODO", // 待办
"FIXME", // 待修复
"COMPAT" // 兼容性问题
],
"todo-tree.highlights.customHighlight": {
"TODO": {
"icon": "check",
Expand All @@ -59,6 +64,12 @@
"type": "tag",
"foreground": "#ff0000",
"iconColour": "#ff0000"
},
"COMPAT": {
"icon": "flame",
"type": "tag",
"foreground": "#00ff00",
"iconColour": "#ffff"
}
},
/** i18n */
Expand Down
8 changes: 6 additions & 2 deletions swanlab/database/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def init(experiment_name: str = None, description: str = "", config: dict = {},
swl.info("Run `swanlab watch` to view SwanLab Experiment Dashboard")


def log(data: dict):
def log(data: dict, step: int = None):
"""以字典的形式记录数据,字典的key将作为列名,value将作为记录的值
例如:
```python
Expand All @@ -94,15 +94,19 @@ def log(data: dict):
----------
data : dict
此处填写需要记录的数据
step: int
当前记录的步数,如果不传则默认当前步数为'已添加数据数量+1'
"""
if sd is None:
raise RuntimeError("swanlab is not initialized")
if not isinstance(data, dict):
raise TypeError("log data must be a dict")
if step is not None and (not isinstance(step, int) or step < 0):
raise TypeError("'step' must be an integer not less than zero.")
for key in data:
# 遍历字典的key,记录到本地文件中
# TODO 检查数据类型
sd.add(key, data[key])
sd.add(key, data[key], step=step)


__all__ = ["log", "init"]
58 changes: 46 additions & 12 deletions swanlab/database/chart.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@
from ..env import swc
from ..utils import create_time, get_a_lock
from typing import List, Union
from ..log import swanlog as swl


class ChartTable(ProjectTablePoxy):
"""图表管理类,用于管理图表,包括创建图表,删除图表,修改图表配置等操作"""

default_data = {"_sum": 0, "charts": []}
default_data = {"_sum": 0, "charts": [], "namespaces": []}

def __init__(self, experiment_id: int):
"""初始化图表管理类"""
Expand Down Expand Up @@ -41,6 +42,38 @@ def new_chart(self, chart_id: int, namespace: str, reference: str, chart_type: s
"create_time": create_time(),
}

def add_chart(self, data, chart):
"""添加图表到配置,同时更新组

Parameters
----------
data : dict
配置
chart : dict
添加的图表
"""
namespace: str = chart["namespace"]
namespaces: list = data["namespaces"]
data["charts"].append(chart)
# 遍历data["namespaces"]
ns: dict = None
for ns in namespaces:
if ns["namespace"] == namespace:
break
# 如果命名空间不存在,添加
if ns is None:
swl.debug(f"Namespace {namespace} not found, add.")
ns = {"namespace": "default", "charts": []}
if ns["namespace"] == "default":
swl.debug(f"Namespace {namespace} Add to the beginning")
namespaces.insert(0, ns)
else:
swl.debug(f"Namespace {namespace} Add to the end.")
namespaces.append(ns)
# 添加当前的chart_id到结尾
ns["charts"].append(chart["chart_id"])
swl.debug(f"Chart {chart['chart_id']} add, now charts: " + str(ns["charts"]))

def add(
self,
tag: Union[str, List[str]],
Expand All @@ -62,14 +95,15 @@ def add(
chart_type : str, optional
图表类型,用于区分不同的图表的显示方式,如折线图,柱状图等
"""
f = get_a_lock(self.path)
data = ujson.load(f)
# 记录数据
data["_sum"] += 1
chart = self.new_chart(data["_sum"], namespace, reference, chart_type)
chart["source"].append(tag)
data["charts"].append(chart)
f.truncate(0)
f.seek(0)
ujson.dump(data, f, ensure_ascii=False)
f.close()
with get_a_lock(self.path) as f:
data = ujson.load(f)
# 记录数据
data["_sum"] += 1
chart = self.new_chart(data["_sum"], namespace, reference, chart_type)
chart["source"].append(tag)
# 添加图表
self.add_chart(data, chart)
f.truncate(0)
f.seek(0)
ujson.dump(data, f, ensure_ascii=False)
f.close()
5 changes: 3 additions & 2 deletions swanlab/database/expriment.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,10 +99,11 @@ def add(self, tag: str, data: Union[float, str], step: int = None):
# 在实验中记录此tag
self.tags.append({"tag": tag, "num": 0})
# 更新tag的数量,并拿到tag的索引
tag_index = self.update_tag_num(tag)
tag_num = self.update_tag_num(tag)
index = tag_num if step is None else step
# 往本地添加新的数据
# TODO 指定step支持 #9
self.save_tag(tag, data, self.experiment_id, tag_index, step=step)
self.save_tag(tag, data, self.experiment_id, index, tag_num)

def update_tag_num(self, tag: str) -> int:
for index, item in enumerate(self.tags):
Expand Down
7 changes: 4 additions & 3 deletions swanlab/database/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def experiment(self) -> ExperimentTable:
"""获取当前实验对象"""
return self.__project.experiment

def add(self, tag: str, data: Union[Image, float]):
def add(self, tag: str, data: Union[Image, float], step: int = None):
"""添加数据到数据库,保存数据,完成几件事情:
1. 如果{experiment_name}_{tag}表单不存在,则创建
2. 添加记录到{experiment_name}_{tag}表单中,包括create_time等
Expand All @@ -94,15 +94,16 @@ def add(self, tag: str, data: Union[Image, float]):
数据标签,用于区分同一资源下不同的数据
data : Union[str, float]
定位到的数据,暂时只支持str和float类型(事实上目前只支持float类型)
step: int
当前的步数,用于记录数据的顺序, 必须在传入之前进行类型检查,必须是一个不小于0的整数
"""
if self.__project is None:
raise RuntimeError("swanlab has not been initialized")

if isinstance(data, float):
# 如果是float类型,保留六位小数
data = round(data, 6)
# TODO 如果是Image类型,执行其他逻辑
self.__project.experiment.add(tag, data)
self.__project.experiment.add(tag, data, step=step)

def success(self):
"""标记实验成功"""
Expand Down
12 changes: 7 additions & 5 deletions swanlab/database/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def __init__(self, path: str):
def new_tag_data(self, index) -> dict:
"""创建一个新的data数据,实际上是一个字典,包含一些默认信息"""
return {
"index": index,
"index": str(index),
"create_time": create_time(),
}

Expand All @@ -135,7 +135,7 @@ def new_tag(self) -> dict:
"data": [],
}

def save_tag(self, tag: str, data: Any, experiment_id: int, index: int, **kwargs):
def save_tag(self, tag: str, data: Any, experiment_id: int, index: int, sum: int, **kwargs):
"""保存一个tag的数据

Parameters
Expand All @@ -148,6 +148,8 @@ def save_tag(self, tag: str, data: Any, experiment_id: int, index: int, **kwargs
实验id
index : int
tag索引
sum : int
当前tag总数
"""
# 创建一个新的tag数据
new_tag_data = self.new_tag_data(index)
Expand All @@ -162,9 +164,9 @@ def save_tag(self, tag: str, data: Any, experiment_id: int, index: int, **kwargs
if not os.path.exists(save_folder):
os.mkdir(save_folder)

# 优化文件分片,每__slice_size个tag数据保存为一个文件,通过index来判断
need_slice = (index - 1) % self.__slice_size == 0 or index == 1
mu = math.ceil(index / self.__slice_size)
# 优化文件分片,每__slice_size个tag数据保存为一个文件,通过sum来判断
need_slice = (sum - 1) % self.__slice_size == 0 or sum == 1
mu = math.ceil(sum / self.__slice_size)
# 存储路径
file_path = os.path.join(save_folder, str(mu * self.__slice_size) + ".json")
# 如果需要新增分片存储
Expand Down
24 changes: 24 additions & 0 deletions swanlab/server/api/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from urllib.parse import unquote # 转码路径参数
from typing import List, Dict
from ...utils import get_a_lock
from ...log import swanlog as swl

router = APIRouter()

Expand Down Expand Up @@ -277,3 +278,26 @@ async def get_recent_experiment_log(experiment_id: int):
data["error"] = error
# 返回最新的 MAX_NUM 条记录
return SUCCESS_200(data)


@router.get("/{experiment_id}/chart")
async def get_experimet_charts(experiment_id: int):
chart_path: str = os.path.join(swc.root, __find_experiment(experiment_id)["name"], "chart.json")
with get_a_lock(chart_path, "r+") as f:
chart: dict = ujson.load(f)
# COMPAT 如果chart不存在namespaces且charts有东西,生成它
compat = not chart.get("namespaces") and len(chart["charts"])
if compat:
# 提示用户,配置将更新
swl.warning(
"The configuration of the chart is somewhat outdated. SwanLab will automatically make some updates to this configuration."
)
# 遍历chart[charts],写入chart_id
charts = [c["chart_id"] for c in chart["charts"]]
ns = {"namespace": "default", "charts": charts}
chart["namespaces"] = [ns]
# 写入文件
f.truncate(0)
f.seek(0)
ujson.dump(chart, f, ensure_ascii=False)
return SUCCESS_200(chart)
9 changes: 6 additions & 3 deletions test/create_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import time

# 迭代次数
epochs = 500
epochs = 100
# 学习率
lr = 0.01
# 随机偏移量
Expand All @@ -36,8 +36,11 @@
acc = 1 - 2**-epoch - random.random() / epoch - offset
loss = 2**-epoch + random.random() / epoch + offset
print(f"epoch={epoch}, accuracy={acc}, loss={loss}")
sw.log({"loss": loss, "accuracy": acc, "loss2": loss, "accuracy2": acc, "loss3": loss, "accuracy3": acc})
time.sleep(0.1)
sw.log({"loss": loss, "accuracy": acc})
sw.log({"loss2": loss, "accuracy2": acc}, step=epochs - epoch)
sw.log({"loss3": loss, "accuracy3": acc}, step=epoch * 2)

time.sleep(0.5)
# if epoch % 40 == 0:
# epoch / 0

Expand Down
Loading