Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions swanlab/data/run/key.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,15 @@ def mock_from_remote(
f"Maybe you need to update swanlab: pip install -U swanlab"
)
error = ParseErrorInfo(expected=expected, got=got, chart=chart)

# 3. 根据 column_class 和 chart 适配 section_type
if column_class == "SYSTEM":
section_type: SectionType = "SYSTEM"
else:
if chart is ChartType.ECHARTS:
section_type = "CUSTOM"
else:
section_type = "PUBLIC"
# 4. 创建 ColumnInfo 对象
column_info = ColumnInfo(
key,
str(kid),
Expand All @@ -291,10 +299,10 @@ def mock_from_remote(
chart_reference="STEP",
error=error,
section_name=None,
section_type="PUBLIC",
section_type=section_type,
)
key_obj.column_info = column_info
# 3. 设置当前步数,resume 后不允许设置历史步数,所以需要覆盖
# 5. 设置当前步数,resume 后不允许设置历史步数,所以需要覆盖
if step is not None:
for i in range(step + 1):
key_obj.steps.add(i)
Expand Down
16 changes: 11 additions & 5 deletions test/resume/must.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,6 @@
run = swanlab.init()
swanlab.log({"loss": 0.1, "accuracy": 0.9}, step=1)

# 2. 继续一个不存在的实验
try:
run = swanlab.init(id="".join(random.choices("abcdefghijklmnopqrstuvwxyz0123456789", k=21)), resume='must')
except RuntimeError:
pass
import time

time.sleep(5)
Expand All @@ -33,3 +28,14 @@
assert not ll['accuracy'].is_error, "Expected accuracy metric to be logged successfully after reinit"
assert ll['loss'].data == 0.5
assert ll['accuracy'].data == 0.8

# 2. 继续一个不存在的实验
try:
run = swanlab.init(
id="".join(random.choices("abcdefghijklmnopqrstuvwxyz0123456789", k=21)),
resume='must',
reinit=True,
)
raise Exception("Should have raised exception")
except RuntimeError:
pass
53 changes: 53 additions & 0 deletions test/unit/data/run/test_key.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,59 @@ class TestMockKey:
测试模拟一个 SwanLabKey 对象
"""

@pytest.mark.parametrize(
"column_class, expected_section_type",
[
["CUSTOM", "PUBLIC"],
["SYSTEM", "SYSTEM"],
],
)
def test_column_class_ok(self, column_class, expected_section_type):
"""
测试不同的 column_class 是否能正确映射到预期的 SectionType
"""
with UseMockRunState() as run_state:
key_obj, column_info = SwanLabKey.mock_from_remote(
key="test",
column_type="FLOAT",
column_class=column_class,
error=None,
media_dir=run_state.store.media_dir,
log_dir=run_state.store.log_dir,
kid=0,
step=None,
)
assert column_info.section_type == expected_section_type

@pytest.mark.parametrize(
"column_type, expected_section_type",
[
["FLOAT", "PUBLIC"],
["IMAGE", "PUBLIC"],
["AUDIO", "PUBLIC"],
["TEXT", "PUBLIC"],
["OBJECT3D", "PUBLIC"],
["MOLECULE", "PUBLIC"],
["ECHARTS", "CUSTOM"],
],
)
def test_column_type_section_type(self, column_type, expected_section_type):
"""
测试不同的 column_type 是否能正确映射到预期的 SectionType
"""
with UseMockRunState() as run_state:
key_obj, column_info = SwanLabKey.mock_from_remote(
key="test",
column_type=column_type,
column_class="CUSTOM",
error=None,
media_dir=run_state.store.media_dir,
log_dir=run_state.store.log_dir,
kid=0,
step=None,
)
assert column_info.section_type == expected_section_type

@pytest.mark.parametrize(
"column_type, expected_chart_type",
[
Expand Down