Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions swanlab/data/run/exp.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from ..modules import BaseType, DataType
from ...log import swanlog
from typing import Dict
from .utils import create_time, check_key_format, get_a_lock
from .utils import create_time, check_tag_format, get_a_lock
from urllib.parse import quote
import ujson
import os
Expand Down Expand Up @@ -42,7 +42,12 @@ def add(self, tag: str, data: DataType, step: int = None):
步数,如果不传则默认当前步数为'已添加数据数量+1'
在log函数中已经做了处理,此处不需要考虑数值类型等情况
"""
check_key_format(tag)
key = tag
tag = check_tag_format(key, auto_cut=True)
if key != tag:
# 超过255字符,截断
swanlog.warning(f"Tag {key} is too long, cut to 255 characters.")

if isinstance(data, BaseType):
# 注入一些内容
data.settings = self.settings
Expand Down
2 changes: 1 addition & 1 deletion swanlab/data/run/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,5 @@
@Description:
运行时工具函数
"""
from ...utils.file import check_key_format, get_a_lock, check_exp_name_format, check_desc_format
from ...utils.file import check_tag_format, get_a_lock, check_exp_name_format, check_desc_format
from ...utils import get_package_version, create_time, generate_color
18 changes: 13 additions & 5 deletions swanlab/utils/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,24 +57,32 @@ def get_a_lock(file_path: str, mode: str = "r+", encoding="utf-8") -> TextIOWrap
return f


def check_key_format(key: str) -> str:
"""检查key字符串格式,必须是0-9a-zA-Z _-和/组成的字符串,并且开头必须是0-9a-zA-Z
def check_tag_format(key: str, auto_cut=True) -> str:
"""检查tag字符串格式,必须是0-9a-zA-Z _-和/组成的字符串(包含空格),并且开头必须是0-9a-zA-Z
最大长度为255字符

Parameters
----------
key : str
待检查的字符串
"""
max_len = 255
if not isinstance(key, str):
raise TypeError(f"key: {key} is not a string")
raise TypeError(f"tag: {key} is not a string")
# 定义正则表达式
pattern = re.compile("^[0-9a-zA-Z][0-9a-zA-Z_/-]*$")
pattern = re.compile("^[0-9a-zA-Z][0-9a-zA-Z _/-]*$")

# 检查 key 是否符合规定格式
if not pattern.match(key):
raise ValueError(
f"key: {key} is not a valid string, which must be composed of 0-9a-zA-Z _- and /, and the first character must be 0-9a-zA-Z"
f"tag: {key} is not a valid string, which must be composed of 0-9a-zA-Z _- and /, and the first character must be 0-9a-zA-Z"
)

# 检查长度
if auto_cut and len(key) > max_len:
key = key[:max_len]
elif not auto_cut and len(key) > max_len:
raise IndexError(f"tag: {key} is too long, which must be less than {max_len} characters")
return key


Expand Down