Skip to content
Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
a896ceb
Return response object in Client HTTP methods
SAKURA-CAT Jun 26, 2025
76b778a
Enhance experiment management in Client and add tests
SAKURA-CAT Jun 26, 2025
c9523c6
tmp
SAKURA-CAT Jun 26, 2025
45ed469
Refactor metrics upload to handle pending state
SAKURA-CAT Jun 26, 2025
cc98a8e
Fix login info handling and add re-init tests
SAKURA-CAT Jun 26, 2025
d339a6a
Refactor CloudPyCallback client creation and add tests
SAKURA-CAT Jun 26, 2025
1e9312a
Add run_id generation and validation for runs
SAKURA-CAT Jun 26, 2025
7177b07
Refactor experiment resume logic and add must_exist checks
SAKURA-CAT Jun 26, 2025
eb2338a
Add run_id property to SwanLabRun and related tests
SAKURA-CAT Jun 27, 2025
1d25951
tmp stash
SAKURA-CAT Jun 29, 2025
3a4ab9a
Refactor upload functions and add flagId to session headers
SAKURA-CAT Jul 1, 2025
ec98a9e
Refactor SwanLabKey to separate module and update tests
SAKURA-CAT Jul 2, 2025
755f150
Add mock_from_remote to SwanLabKey and unit tests
SAKURA-CAT Jul 3, 2025
9a3c8cc
Add experiment new/existing status tracking and assertions
SAKURA-CAT Jul 3, 2025
4ca178c
Export is_system_key and add to __all__ in hardware module
SAKURA-CAT Jul 4, 2025
70a1d12
Support experiment resume with remote metric and log sync
SAKURA-CAT Jul 4, 2025
e7c3441
Refactor cloud callback and update test cases
SAKURA-CAT Jul 4, 2025
c638a31
Rename run_id parameter to id in SwanLabInitializer
SAKURA-CAT Jul 4, 2025
2d8bf09
Refactor run_id to id and validate run id format
SAKURA-CAT Jul 4, 2025
26e52a7
Update error message regex in run ID format tests
SAKURA-CAT Jul 4, 2025
a728546
Update usage of run_id to run.id in tests
SAKURA-CAT Jul 4, 2025
a5d7882
Add support for RUN_ID and RESUME env variables
SAKURA-CAT Jul 4, 2025
b2d81f4
Fix project ID assignment in create_data function
SAKURA-CAT Jul 5, 2025
333ebe4
Add tests for 'allow' and 'must' resume modes
SAKURA-CAT Jul 5, 2025
3332e35
Add and expand tests for resume functionality
SAKURA-CAT Jul 5, 2025
0a12b8d
Improve error messages for experiment resume failures
SAKURA-CAT Jul 5, 2025
7eb8216
Combine ValueError message for cloned experiment
SAKURA-CAT Jul 5, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 102 additions & 30 deletions swanlab/core_python/client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,19 @@ def __init__(self, login_info: auth.LoginInfo):
self.__version = get_package_version()
# 创建会话
self.__create_session()
# 标识当前实验会话(flagId)是否被其他进程顶掉
self.pending = False

# ---------------------------------- 一些辅助属性 ----------------------------------
@property
def exp(self) -> ExperimentInfo:
    """
    Experiment info stored on this client by mount_exp().

    :raises AssertionError: if no experiment has been mounted yet
    :return: the mounted ExperimentInfo
    """
    # Raise explicitly rather than via `assert`: assert statements are removed under
    # `python -O`, which would let this property hand back None without any error.
    if self.__exp is None:
        raise AssertionError("Experiment not mounted, please call mount_exp() first")
    return self.__exp

@property
def proj(self) -> ProjectInfo:
    """
    Project info stored on this client by mount_project().

    :raises AssertionError: if no project has been mounted yet
    :return: the mounted ProjectInfo
    """
    # Raise explicitly rather than via `assert`: assert statements are removed under
    # `python -O`, which would let this property hand back None without any error.
    if self.__proj is None:
        raise AssertionError("Project not mounted, please call mount_project() first")
    return self.__proj

@property
def base_url(self):
Expand Down Expand Up @@ -100,10 +111,6 @@ def username(self):
def cos(self):
return self.__cos

@property
def proj_id(self):
    """
    The `cuid` attribute of the mounted project object.
    """
    project = self.__proj
    return project.cuid

@property
def projname(self):
    """
    The `name` attribute of the mounted project object.
    """
    project = self.__proj
    return project.name
Expand Down Expand Up @@ -148,6 +155,10 @@ def __before_request(self):
self.__login_info = auth.login_by_key(self.__login_info.api_key, save=False)
self.__session.headers["cookie"] = f"sid={self.__login_info.sid}"

# 携带实验会话Id
if self.__exp is not None:
self.__session.headers["flagId"] = self.__exp.flag_id

def __create_session(self):
"""
创建会话,这将在HTTP类实例化时调用
Expand Down Expand Up @@ -188,7 +199,7 @@ def post(self, url: str, data: Union[dict, list] = None):
url = self.base_url + url
self.__before_request()
resp = self.__session.post(url, json=data)
return decode_response(resp)
return decode_response(resp), resp

def put(self, url: str, data: dict = None):
"""
Expand All @@ -197,7 +208,7 @@ def put(self, url: str, data: dict = None):
url = self.base_url + url
self.__before_request()
resp = self.__session.put(url, json=data)
return decode_response(resp)
return decode_response(resp), resp

def get(self, url: str, params: dict = None):
"""
Expand All @@ -206,7 +217,7 @@ def get(self, url: str, params: dict = None):
url = self.base_url + url
self.__before_request()
resp = self.__session.get(url, params=params)
return decode_response(resp)
return decode_response(resp), resp

def patch(self, url: str, data: dict = None):
"""
Expand All @@ -215,13 +226,13 @@ def patch(self, url: str, data: dict = None):
url = self.base_url + url
self.__before_request()
resp = self.__session.patch(url, json=data)
return decode_response(resp)
return decode_response(resp), resp

# ---------------------------------- 对象存储方法 ----------------------------------
# ---------------------------------- 对象存储相关方法 ----------------------------------

def __get_cos(self):
    """
    Request the `/sts` endpoint of the current experiment and store a CosClient
    built from the decoded payload on this client.
    """
    # get() returns a (decoded data, raw response) pair; CosClient only takes the decoded data.
    sts_data = self.get(f"/project/{self.groupname}/{self.projname}/runs/{self.exp_id}/sts")[0]
    self.__cos = CosClient(data=sts_data)

def upload(self, buffer: MediaBuffer):
Expand Down Expand Up @@ -260,21 +271,23 @@ def mount_project(self, name: str, username: str = None, public: bool = None):
data["username"] = username
if public is not None:
data["visibility"] = "PUBLIC" if public else "PRIVATE"
resp = self.post(f"/project", data=data)
resp_data, _ = self.post(f"/project", data=data)
except ApiError as e:
if e.resp.status_code == 409:
# 项目已经存在,从对象中解析信息
resp = decode_response(e.resp)
resp_data = decode_response(e.resp)
elif e.resp.status_code == 404 and e.resp.reason == "Not Found":
# WARNING: 早期 (私有化) swanlab 后端没有 /project 接口,需要使用 /project/{username} 接口,此时没有默认空间的特性
self.__groupname = self.__groupname if username is None else username
try:
visibility = "PUBLIC" if public else "PRIVATE"
resp = self.post(f"/project/{self.groupname}", data={"name": name, "visibility": visibility})
resp_data, _ = self.post(
f"/project/{self.groupname}", data={"name": name, "visibility": visibility}
)
except ApiError as e:
# 如果为409,表示已经存在,获取项目信息
if e.resp.status_code == 409:
resp = self.get(f"/project/{self.groupname}/{name}")
resp_data, _ = self.get(f"/project/{self.groupname}/{name}")
elif e.resp.status_code == 404:
# 组织/用户不存在
raise ValueError(f"Space `{self.groupname}` not found")
Expand All @@ -283,47 +296,106 @@ def mount_project(self, name: str, username: str = None, public: bool = None):
raise ValueError(f"Space permission denied: " + self.groupname)
else:
raise e
return ProjectInfo(resp)
else:
# 此接口为后端处理,sdk 在理论上不会出现其他错误,因此不需要处理其他错误
raise e
# 设置当前项目所属的用户名
self.__groupname = resp['username']
self.__groupname = resp_data['username']
# 获取详细信息
resp = self.get(f"/project/{self.groupname}/{name}")
self.__proj = ProjectInfo(resp)

def mount_exp(
    self,
    exp_name,
    colors: Tuple[str, str],
    description: str = None,
    tags: List[str] = None,
    created_at: str = None,
    cuid: str = None,
    must_exist: bool = False,
) -> bool:
    """
    Create or resume an experiment in the mounted project and store its info on this client.

    :param exp_name: name of the experiment
    :param colors: experiment colors, a pair of two colors
    :param description: experiment description
    :param tags: experiment tags
    :param created_at: experiment creation time, ISO 8601 format
    :param cuid: unique identifier of the experiment; generated by the backend when omitted
    :param must_exist: when cuid is passed, require that the experiment already exists

    :raises NotImplementedError: if no project has been mounted
    :raises AssertionError: if must_exist is True but cuid is None
    :raises RuntimeError: if must_exist is True and the experiment does not exist
    :raises ValueError: if the backend rejects the request (cloned experiment, permission
        denied, project not found, experiment deleted, experiment belongs to another project)

    :return: True when the experiment was newly created, False when an existing one is updated
    """
    if self.__proj is None:
        raise NotImplementedError("Project not mounted, please call mount_project() first")

    if must_exist:
        # Explicit raise instead of `assert`, so the check is not stripped under `python -O`.
        if cuid is None:
            raise AssertionError("cuid must be provided when must_exist is True")
        try:
            self.get(f"/project/{self.groupname}/{self.__proj.name}/runs/{cuid}")
        except ApiError as e:
            if e.resp.status_code == 404 and e.resp.reason == "Not Found":
                raise RuntimeError(f"Experiment {cuid} does not exist in project {self.projname}") from e
            # NOTE(review): any other error from this pre-check is not raised here (unchanged
            # behavior); the request below maps backend errors to ValueError. Confirm intended.

    # An empty or missing tag list produces no "labels" key at all.
    labels = [{"name": tag} for tag in tags] if tags else None
    post_data = {
        "name": exp_name,
        "description": description,
        "createdAt": created_at,
        "colors": list(colors),
        "labels": labels,
        "cuid": cuid,
    }
    post_data = {k: v for k, v in post_data.items() if v is not None}  # drop keys whose value is None

    # Errors raised from here are not handled by upper layers; they propagate directly.
    try:
        data, resp = self.post(f"/project/{self.groupname}/{self.__proj.name}/experiment", post_data)
    except ApiError as e:
        known_errors = {
            # the experiment identified by cuid is a cloned experiment
            (400, "Bad Request"): f"Experiment with cuid {cuid} is a cloned experiment and cannot be resumed",
            # insufficient permission
            (403, "Forbidden"): f"Project permission denied: {self.projname}",
            # the given project does not exist
            (404, "Not Found"): f"Project {self.projname} not found",
            # the given experiment has been deleted
            (404, "Disabled Resource"): f"Experiment {cuid} has been deleted",
            # cuid was passed but the experiment does not belong to the current project
            (409, "Conflict"): f"Experiment with cuid {cuid} does not belong to project {self.projname}",
        }
        message = known_errors.get((e.resp.status_code, e.resp.reason))
        if message is not None:
            raise ValueError(message) from e
        raise

    # 200 means the experiment already exists (update mode); 201 means it was newly created.
    new = resp.status_code == 201
    # This information is not used for now.
    self.__exp = ExperimentInfo(data)
    # Fetch object-storage (cos) information.
    self.__get_cos()
    # Reset the pending state.
    self.pending = False
    return new

def update_state(self, success: bool, finished_at: str = None):
    """
    Send the final state of the current experiment to the backend with a single PUT request.

    :param success: whether the experiment succeeded ("FINISHED") or not ("CRASHED")
    :param finished_at: experiment finish time, ISO 8601 format; when None the key is left out
        of the request (the original comment says the current time is used in that case —
        presumably by the backend)
    """
    put_data = {
        "state": "FINISHED" if success else "CRASHED",
        "finishedAt": finished_at,
        "from": "sdk",
    }
    put_data = {k: v for k, v in put_data.items() if v is not None}  # drop keys whose value is None
    self.put(f"/project/{self.groupname}/{self.projname}/runs/{self.exp_id}/state", put_data)
    # Mark the client as pending; mount_exp() sets this flag back to False.
    self.pending = True


# Module-level Client reference; starts as None (where it is assigned is not visible in this excerpt).
client: Optional["Client"] = None
Expand Down
30 changes: 30 additions & 0 deletions swanlab/core_python/client/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
@description: 实验、项目元信息
"""

from typing import Optional


class ProjectInfo:
def __init__(self, data: dict):
Expand All @@ -28,10 +30,38 @@ class ExperimentInfo:
def __init__(self, data: dict):
    """
    Keep the raw experiment dict; the properties of this class read their values from it.

    :param data: experiment data as a dict
    """
    self.__data: dict = data

@property
def flag_id(self):
    """
    Flag ID of this experiment; it identifies the experiment session used when uploading.
    """
    payload = self.__data
    return payload["flagId"]

@property
def cuid(self):
    """
    Value stored under the "cuid" key of the experiment data.
    """
    payload = self.__data
    return payload["cuid"]

@property
def name(self):
    """
    Value stored under the "name" key of the experiment data.
    """
    payload = self.__data
    return payload["name"]

@property
def config(self) -> dict:
    """
    Config of this experiment, used to sync config when resuming.

    :return: the dict stored under profile -> config, or an empty dict when either level
        is missing or None
    """
    # `.get(key, {})` only falls back when the key is absent; a key that is present with a
    # None value would make the chained `.get` fail, so normalize both levels with `or {}`.
    profile = self.__data.get("profile") or {}
    return profile.get("config") or {}

@property
def root_proj_cuid(self) -> Optional[str]:
    """
    cuid of the root project (needed when uploading an experiment); None when absent.
    """
    payload = self.__data
    return payload.get("rootProId")

@property
def root_exp_cuid(self) -> Optional[str]:
    """
    cuid of the root experiment (needed when uploading an experiment); None when absent.
    """
    payload = self.__data
    return payload.get("rootExpId")
Loading
Loading