Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 36 additions & 7 deletions swanlab/api/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,18 +42,39 @@ def __init__(self, login_info: LoginInfo):
# 当前实验信息
self.__exp: Optional["ExperimentInfo"] = None

# 当前项目所属的username
self.__username = login_info.username

@property
def groupname(self):
"""
当前项目所属组名
"""
return self.__username

@property
def username(self):
"""
当前登录的用户名
"""
return self.__login_info.username

@property
def proj_id(self):
return self.__proj.cuid

@property
def projname(self):
return self.__proj.name

@property
def exp_id(self):
return self.__exp.cuid

@property
def expname(self):
return self.__exp.name

@property
def sid_expired_at(self):
"""
Expand Down Expand Up @@ -86,7 +107,7 @@ async def response_interceptor(response: httpx.Response):
"""
# 如果是
if response.status_code // 100 != 2:
raise ApiError(response)
raise ApiError(response, response.status_code, response.reason_phrase)

session.event_hooks['response'].append(response_interceptor)

Expand All @@ -109,7 +130,7 @@ async def get(self, url: str, params: dict = None) -> dict:
return resp.json()

async def __get_cos(self):
cos = await self.get(f'/project/{self.__login_info.username}/{self.__proj.name}/runs/{self.__exp.cuid}/sts')
cos = await self.get(f'/project/{self.groupname}/{self.projname}/runs/{self.exp_id}/sts')
self.__cos = CosClient(cos)

async def upload(self, key: str, local_path):
Expand Down Expand Up @@ -137,14 +158,22 @@ async def upload_files(self, keys: list, local_paths: list) -> Dict[str, Union[b
keys = [key[1:] if key.startswith('/') else key for key in keys]
return self.__cos.upload_files(keys, local_paths)

def mount_project(self, name: str):
def mount_project(self, name: str, username: str = None):
self.__username = self.__username if username is None else username

async def _():
try:
resp = await http.post(f'/project/{http.username}', data={'name': name})
resp = await http.post(f'/project/{self.groupname}', data={'name': name})
except ApiError as e:
# 如果为409,表示已经存在,获取项目信息
if e.resp.status_code == 409:
resp = await http.get(f'/project/{http.username}/{name}')
resp = await http.get(f'/project/{http.groupname}/{name}')
elif e.resp.status_code == 404:
# 组织/用户不存在
raise ValueError(f"Entity `{http.groupname}` not found")
elif e.resp.status_code == 403:
# 权限不足
raise ValueError(f"Entity permission denied: " + http.groupname)
else:
raise e
return ProjectInfo(resp)
Expand All @@ -162,11 +191,11 @@ def mount_exp(self, exp_name, colors: Tuple[str, str], description: str = None):

async def _():
"""
创建实验,生成
先创建实验,后生成cos凭证
:return:
"""
data = await self.post(
f'/project/{self.__login_info.username}/{self.__proj.name}/runs',
f'/project/{self.groupname}/{self.__proj.name}/runs',
{"name": exp_name, "colors": list(colors), "description": description}
)
self.__exp = ExperimentInfo(data)
Expand Down
30 changes: 15 additions & 15 deletions swanlab/data/sdk.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def init(
suffix: str = "default",
cloud: bool = True,
project: str = None,
organization: str = None,
entity: str = None,
load: str = None,
**kwargs,
) -> SwanLabRun:
Expand Down Expand Up @@ -126,14 +126,14 @@ def init(
cloud : bool, optional
Whether to use the cloud mode, the default is True.
If you use the cloud mode, the log file will be stored in the cloud, which will still be saved locally.
If you are not using cloud mode, the `project` and `organization` fields are invalid.
If you are not using cloud mode, the `project` and `entity` fields are invalid.
project : str, optional
The project name of the current experiment, the default is None,
which means the current project name is the same as the current working directory.
If you are using cloud mode, you must provide the project name.
organization : str, optional
The organization name of the current experiment, the default is None,
which means the log file will be stored in your personal space.
entity : str, optional
Where the current project is located, it can be an organization or a user (currently only supports yourself).
The default is None, which means the current entity is the same as the current user.
load : str, optional
If you pass this parameter,SwanLab will search for the configuration file you specified
(which must be in JSON or YAML format)
Expand Down Expand Up @@ -162,7 +162,7 @@ def init(
suffix = _load_data(load_data, "suffix", suffix)
cloud = _load_data(load_data, "cloud", cloud)
project = _load_data(load_data, "project", project)
organization = _load_data(load_data, "organization", organization)
entity = _load_data(load_data, "entity", entity)
# 初始化logdir参数,接下来logdir被设置为绝对路径且当前程序有写权限
logdir = _init_logdir(logdir)
# 初始化confi参数
Expand All @@ -175,13 +175,14 @@ def init(
exp_num = None
# ---------------------------------- 用户登录、格式、权限校验 ----------------------------------
global login_info
http = None
if login_info is None and cloud:
# 用户登录
login_info = _login_in_init()
# 初始化会话信息
http = create_http(login_info)
# 获取当前项目信息
http.mount_project(project)
http.mount_project(project, entity)

# 连接本地数据库,要求路径必须存在,但是如果数据库文件不存在,会自动创建
connect(autocreate=True)
Expand Down Expand Up @@ -214,24 +215,23 @@ def init(
# 注册清理函数
atexit.register(_clean_handler)
# ---------------------------------- 终端输出 ----------------------------------
if not cloud and not (project is None and organization is None):
swanlog.warning("The project or organization parameters are invalid in non-cloud mode")
if not cloud and not (project is None and entity is None):
swanlog.warning("The `project` or `entity` parameters are invalid in non-cloud mode")
swanlog.debug("SwanLab Runtime has initialized")
swanlog.debug("SwanLab will take over all the print information of the terminal from now on")
swanlog.info("Tracking run with swanlab version " + get_package_version())
swanlog.info("Run data will be saved locally in " + FONT.magenta(FONT.bold(formate_abs_path(run.settings.run_dir))))
# not cloud and swanlog.info("Experiment_name: " + FONT.yellow(run.settings.exp_name))
swanlog.info("Experiment_name: " + FONT.yellow(run.settings.exp_name))
not cloud and swanlog.info("Experiment_name: " + FONT.yellow(run.settings.exp_name))
# 云端版本有一些额外的信息展示
cloud and swanlog.info("👋 Hi " + FONT.bold(FONT.default(login_info.username)) + ", welcome to swanlab!")
# cloud and swanlog.info("Syncing run " + FONT.yellow(run.settings.exp_name) + " to the cloud")
cloud and swanlog.info("Syncing run " + FONT.yellow(run.settings.exp_name) + " to the cloud")
swanlog.info(
"🌟 Run `"
+ FONT.bold("swanlab watch -l {}".format(formate_abs_path(run.settings.swanlog_dir)))
+ "` to view SwanLab Experiment Dashboard locally"
)
project_url = get_host_web() + "/" + "{project_name}"
experiment_url = project_url + "/" + "123456"
project_url = get_host_web() + f"/@{http.groupname}/{http.projname}"
experiment_url = project_url + f"/runs/{http.exp_id}"
cloud and swanlog.info("🏠 View project at " + FONT.blue(FONT.underline(project_url)))
cloud and swanlog.info("🚀 View run at " + FONT.blue(FONT.underline(experiment_url)))
inited = True
Expand Down Expand Up @@ -365,7 +365,7 @@ async def _():
# 上传错误日志
if error is not None:
await upload_logs([e + '\n' for e in error.split('\n')], level='ERROR')
await asyncio.sleep(5)
await asyncio.sleep(1)

asyncio.run(FONT.loading("Waiting for uploading complete", _(), interval=0.5))
return
Expand Down