Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 22 additions & 9 deletions swanlab/data/modules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,32 +3,45 @@
from swankit.core.data import BaseType, MediaBuffer, MediaType

from .audio import Audio
from .echarts import echarts, Echarts, PyEchartsBase
from .custom_charts import echarts, Echarts, PyEchartsBase, PyEchartsTable
from .image import Image
from .line import FloatConvertible, Line
from .object3d import Model3D, Object3D, PointCloud, Molecule
from .object3d import Object3D, Molecule
from .text import Text
from .wrapper import DataWrapper

DataType = Union[int, float, FloatConvertible, BaseType, List[BaseType]]
DataType = Union[
int,
float,
FloatConvertible,
BaseType,
List[BaseType],
PyEchartsBase,
PyEchartsTable,
List[PyEchartsTable],
List[PyEchartsBase],
]

ChartType = BaseType.Chart

__all__ = [
# 数据类型
"FloatConvertible",
"DataWrapper",
"MediaType",
"PyEchartsBase",
"PyEchartsTable",
"DataType",
"ChartType",
"MediaBuffer",
# 支持的图表类
"Image",
"Audio",
"Text",
"Line",
"DataType",
"ChartType",
"MediaBuffer",
"Object3D",
"PointCloud",
"Model3D",
"Molecule",
"Echarts",
# 模块
"echarts",
"PyEchartsBase",
]
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,26 @@
@author: ComPleHN
@file: __init__.py
@time: 2025/5/19 14:01
@desc: 集成 pyecharts
@desc: 自定义图表,目前集成了 echarts
"""

import pyecharts
from pyecharts.charts.base import Base
from swankit.core import MediaBuffer, DataSuite as D, MediaType

echarts = pyecharts.charts
from . import echarts
from .table import Table

PyEchartsBase = pyecharts.charts.base.Base
"""
pyecharts.charts.base.Base 的别名
pyecharts.charts.base.Base
"""
PyEchartsTable = Table
"""
custom Table inherited from pyecharts.components.table.Table
"""

__all__ = ["echarts", 'Echarts', 'PyEchartsBase']
__all__ = ["echarts", 'Echarts', 'PyEchartsTable', 'PyEchartsBase']


class Echarts(MediaType):
Expand Down
43 changes: 43 additions & 0 deletions swanlab/data/modules/custom_charts/echarts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
"""
@author: cunyue
@file: echarts.py
@time: 2025/6/11 11:37
@description: 由于前端支持有限,这里仅导出部分 echarts 类型
"""

from pyecharts import options
from pyecharts.charts import *

from .table import Table

__all__ = [
"options",
"Bar3D",
"Bar",
"Boxplot",
"Calendar",
"Kline",
"Grid",
"Scatter",
"EffectScatter",
"Funnel",
"Gauge",
"Graph",
"HeatMap",
"Line",
"Line3D",
"Liquid",
"Parallel",
"PictorialBar",
"Pie",
"Polar",
"Radar",
"Sankey",
"Scatter3D",
"Sunburst",
"Surface3D",
"ThemeRiver",
"Tree",
"TreeMap",
"Table",
]
72 changes: 72 additions & 0 deletions swanlab/data/modules/custom_charts/table.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
"""
@author: cunyue
@file: table.py
@time: 2025/6/11 11:43
@description: 额外适配 pyecharts 的 table 类型,前端重写,自定义数据结构
"""

import datetime

import simplejson as json
from pyecharts.commons import utils
from pyecharts.components import Table as T
from pyecharts.options import ComponentTitleOpts
from pyecharts.options.series_options import BasicOpts
from pyecharts.types import Sequence, Union, Optional


class Table(T):
def __init__(self):
super().__init__()
self.options: dict = {'_swanLab': "table", 'headers': [], 'rows': []}

def add(self, headers: Sequence, rows: Sequence, attributes: Optional[dict] = None):
super().add(headers, rows, attributes)
self.options.update({'headers': headers, 'rows': rows})
return self

def get_table_format(self) -> dict:
"""获取转换后的表格格式(包含rowData和colDefs)"""
original_data = utils.remove_key_with_none_value(self.options)

# 创建新字典,保留所有原始字段
converted_data = original_data.copy()

# 转换表格数据部分
if "headers" in original_data and "rows" in original_data:
# 转换列定义
converted_data["colDefs"] = [{"field": header} for header in original_data["headers"]]

# 转换行数据
converted_data["rowData"] = [dict(zip(original_data["headers"], row)) for row in original_data["rows"]]

# 移除原始的headers和rows字段
converted_data.pop("headers", None)
converted_data.pop("rows", None)

return converted_data

def set_global_opts(self, title_opts: Union[ComponentTitleOpts, dict, None] = None):
raise NotImplementedError("set_global_opts is not supported in swanlab.echarts.Table")

@staticmethod
def _default_parse(o):
"""
默认的序列化方法,处理日期、JsCode和BasicOpts等特殊类型
"""
if isinstance(o, (datetime.date, datetime.datetime)):
return o.isoformat()
if isinstance(o, utils.JsCode):
return o.replace("\\n|\\t", "").replace(r"\\n", "\n").replace(r"\\t", "\t").js_code
if isinstance(o, BasicOpts):
if isinstance(o.opts, Sequence):
return [utils.remove_key_with_none_value(item) for item in o.opts]
else:
return utils.remove_key_with_none_value(o.opts)
return str(o) # 对于其他类型,转换为字符串

def dump_options(self) -> str:
"""序列化原始格式的options"""
return utils.replace_placeholder(
json.dumps(self.get_table_format(), default=self._default_parse, ignore_nan=True)
)
2 changes: 1 addition & 1 deletion swanlab/data/modules/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from swankit.core import ParseResult, ParseErrorInfo, MediaType, ChartReference

from swanlab.error import DataTypeError
from .echarts import Echarts
from .custom_charts import Echarts
from .line import Line


Expand Down
8 changes: 4 additions & 4 deletions swanlab/data/run/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from swankit.core.data import MediaType

from swanlab.data import namer as N
from swanlab.data.modules import DataWrapper, FloatConvertible, Line, Echarts, PyEchartsBase
from swanlab.data.modules import DataWrapper, FloatConvertible, Line, Echarts, PyEchartsBase, PyEchartsTable
from swanlab.env import get_mode, get_swanlog_dir
from swanlab.log import swanlog
from swanlab.package import get_package_version
Expand Down Expand Up @@ -425,7 +425,7 @@ def log(self, data: dict, step: int = None):
# 输入为可转换为float的数据类型
if isinstance(v, (int, float, FloatConvertible)):
v = DataWrapper(k, [Line(v)])
elif isinstance(v, PyEchartsBase):
elif isinstance(v, (PyEchartsBase, PyEchartsTable)):
v = DataWrapper(k, [Echarts(v)])
# 为Line类型或者MediaType类型
elif isinstance(v, (Line, MediaType)):
Expand All @@ -434,14 +434,14 @@ def log(self, data: dict, step: int = None):
elif (
isinstance(v, list)
and len(v) > 0
and all([isinstance(i, (Line, MediaType, PyEchartsBase)) for i in v])
and all([isinstance(i, (Line, MediaType, PyEchartsBase, PyEchartsTable)) for i in v])
and all([i.__class__ == v[0].__class__ for i in v])
):
if len(v) > MAX_LIST_LENGTH:
swanlog.warning(f"List length '{k}' is too long, cut to {MAX_LIST_LENGTH}.")
v = v[:MAX_LIST_LENGTH]
# echarts 类型需要转换
if isinstance(v[0], PyEchartsBase):
if isinstance(v[0], (PyEchartsBase, PyEchartsTable)):
v = DataWrapper(k, [Echarts(i) for i in v])
else:
v = DataWrapper(k, v)
Expand Down
4 changes: 2 additions & 2 deletions swanlab/data/sdk.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from swanlab.swanlab_settings import Settings, get_settings, set_settings
from .callbacker.cloud import CloudRunCallback
from .formatter import check_load_json_yaml, check_callback_format
from .modules import DataType, PyEchartsBase
from .modules import DataType
from .run import (
SwanLabRunState,
SwanLabRun,
Expand Down Expand Up @@ -249,7 +249,7 @@ def init(

@should_call_after_init("You must call swanlab.init() before using log()")
def log(
data: Dict[str, Union[DataType, PyEchartsBase, List[PyEchartsBase]]],
data: Dict[str, DataType],
step: int = None,
print_to_console: bool = False,
):
Expand Down
9 changes: 6 additions & 3 deletions test/metrics/echarts/README.md
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
# echarts测试需知
# echarts测试

# NOTE 你可能需要下载对应的文件才可以运行测试
这里是根据 pyecharts-gallery 中的 demo, 对 pyecharts 模块的遍历测试,不保证 swanlab 能够完全渲染所有的图表。

> 如果你有相关图表的需求,可以向我们提交 [issue](https://github.com/SwanHubX/SwanLab/issues)

---
## 需要下载额外文件的测试

> NOTE 你可能需要下载对应的文件才可以运行测试

- `dataset.py`测试:需要到[这里](https://github.com/pyecharts/pyecharts-gallery/blob/master/Dataset/life-expectancy-table.json),下载文件`life-expectancy-table.json`到本目录下的`assets/echarts/dataset`目录下,才能运行`dataset.py`测试
- `graph.py`测试:需要到[这里](https://github.com/pyecharts/pyecharts-gallery/blob/master/Graph/les-miserables.json),[这里](https://github.com/pyecharts/pyecharts-gallery/blob/master/Graph/weibo.json),下载两个文件`les-miserables.json`,`weibo.json`到本目录下的`assets/echarts/graph`目录下,才能运行`graph.py`测试
- `pictorialbar.py`测试:需要到[这里](https://github.com/pyecharts/pyecharts-gallery/blob/master/PictorialBar/symbol.json),下载文件`symbol.json`,到本目录下的`assets/echarts/pictorialbar`目录下,才能运行`pictorialbar.py`测试
Expand Down
29 changes: 10 additions & 19 deletions test/metrics/echarts/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,27 +4,22 @@
// @time: 2025/5/29 10:46
// @description: 本文件是对于echarts的 表格组件 图表的测试
"""

# ---------------------------------------------- Table - Table_base ----------------------------------------------
from pyecharts.components import Table
from pyecharts.options import ComponentTitleOpts
from swanlab import echarts


table = Table()
table = echarts.Table()

headers = ["City name", "Area", "Population", "Annual Rainfall"]
headers = ["NO", "Product", "Count"]
rows = [
["Brisbane", 5905, 1857594, 1146.4],
["Adelaide", 1295, 1158259, 600.5],
["Darwin", 112, 120900, 1714.7],
["Hobart", 1357, 205556, 619.5],
["Sydney", 2058, 4336374, 1214.8],
["Melbourne", 1566, 3806092, 646.9],
["Perth", 5386, 1554769, 869.4],
[2, "A", 259],
[3, "B", 123],
[4, "C", 300],
[5, "D", 290],
[6, "E", 1145],
]
table.add(headers, rows)
table.set_global_opts(
title_opts=ComponentTitleOpts(title="Table-基本示例", subtitle="我是副标题支持换行哦")
)

c1 = table

Expand All @@ -36,8 +31,4 @@
public=True,
)

swanlab.log(
{
"Table - Table_base": c1
}
)
swanlab.log({"Table - Table_base": c1})
47 changes: 47 additions & 0 deletions test/unit/data/modules/echarts/test_table.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import json

import pytest

from swanlab.data.modules.custom_charts.table import Table


def test_table():
"""测试表格功能"""
# 创建表格对象
table = Table()
assert isinstance(table, Table)

# 添加数据
headers = ["姓名", "年龄", "城市"]
rows = [["张三", 25, "北京"], ["李四", 30, "上海"]]
table.add(headers, rows)

# 不允许设置标题
with pytest.raises(NotImplementedError):
table.set_global_opts({"title": "测试表格"})

# 测试格式转换
formatted_data = table.get_table_format()
assert isinstance(formatted_data, dict)
assert "_swanLab" in formatted_data
assert formatted_data["_swanLab"] == "table"
assert "colDefs" in formatted_data
assert "rowData" in formatted_data

# 使用标准json序列化测试数据结构
json_str = json.dumps(formatted_data)
assert isinstance(json_str, str)

# 可解析为字典
table_data = json.loads(json_str)
assert isinstance(table_data, dict)
assert table_data["_swanLab"] == "table"

# 验证列定义
assert len(table_data["colDefs"]) == 3
assert table_data["colDefs"][0]["field"] == "姓名"

# 验证行数据
assert len(table_data["rowData"]) == 2
assert table_data["rowData"][0]["姓名"] == "张三"
assert table_data["rowData"][0]["年龄"] == 25
3 changes: 2 additions & 1 deletion test/unit/data/test_sdk.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,8 @@ def test_init_logdir_settings_backup(self):
run = S.init(mode="cloud", settings=Settings(backup=False))
assert run.public.swanlog_dir != logdir
assert run.public.swanlog_dir == os.path.join(os.getcwd(), "swanlog")
assert os.path.exists(run.public.swanlog_dir)
# 此时文件夹不存在,因为关闭了backup功能
assert not os.path.exists(run.public.swanlog_dir)


@pytest.mark.skipif(T.is_skip_cloud_test, reason="skip cloud test")
Expand Down
Loading