Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 17 additions & 7 deletions swanlab/package.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,13 @@ def get_experiment_url(username: str, projname: str, expid: str) -> str:
# ---------------------------------- 登录相关 ----------------------------------


def get_nrc_path() -> str:
"""
获取netrc文件路径
"""
return os.path.join(get_save_dir(), ".netrc")


def get_key():
"""使用标准netrc库解析token文件,获取token
:raise KeyFileError: 文件不存在或者host不存在
Expand All @@ -105,8 +112,7 @@ def get_key():
env_key = os.getenv(SwanLabEnv.API_KEY.value)
if env_key is not None:
return env_key
path = os.path.join(get_save_dir(), ".netrc")
host = get_host_api()
path, host = get_nrc_path(), get_host_api()
if not os.path.exists(path):
raise KeyFileError("The file does not exist")
nrc = netrc.netrc(path)
Expand All @@ -126,14 +132,19 @@ def save_key(username: str, password: str, host: str = None):
"""
if host is None:
host = get_host_api()
path = os.path.join(get_save_dir(), ".netrc")
path = get_nrc_path()
if not os.path.exists(path):
with open(path, "w") as f:
f.write("")
nrc = netrc.netrc(path)
nrc.hosts[host] = (username, None, password)
with open(path, "w") as f:
f.write(nrc.__repr__())
new_info = (username, "", password)
# 避免重复的写
info = nrc.authenticators(host)
if info != new_info:
# 同时只允许存在一个host: https://github.com/SwanHubX/SwanLab/issues/797
nrc.hosts = {host: new_info}
with open(path, "w") as f:
f.write(nrc.__repr__())


class LoginCheckContext:
Expand Down Expand Up @@ -173,7 +184,6 @@ def is_login() -> bool:
"""判断是否已经登录,与当前的host相关
如果环境变量中有api key,则认为已经登录
但不会检查key的有效性
FIXME 目前存在一些bug,此函数只能在未登录前判断,登录后判断会有一些bug
:return: 是否已经登录
"""
with LoginCheckContext() as checker:
Expand Down
41 changes: 41 additions & 0 deletions test/unit/test_package.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
import netrc
import os
import time

import nanoid
import pytest
Expand Down Expand Up @@ -148,6 +149,46 @@ def test_ok(self):
host = P.get_host_api()
P.save_key("user", password, host=host)
assert self.get_key(path, host) == password
# 在保存一次,保证只存在一个host
new_host = nanoid.generate()
P.save_key("user", password, host=new_host)
nrc = netrc.netrc(path)
assert len(nrc.hosts) == 1
assert nrc.authenticators(new_host) is not None

def test_duplicate(self):
"""
测试重复保存,此时会略过
"""
path = os.path.join(get_save_dir(), ".netrc")
password = nanoid.generate()
host = P.get_host_api()
P.save_key("user", password, host=host)
change_time = os.path.getmtime(path)
assert self.get_key(path, host) == password
P.save_key("user", password, host=host)
assert self.get_key(path, host) == password
assert os.path.getmtime(path) == change_time
time.sleep(0.1)
# 再次保存,但是账号不同
new_password = nanoid.generate()
P.save_key("user2", new_password, host=host)
assert self.get_key(path, host) == new_password
assert os.path.getmtime(path) != change_time
time.sleep(0.1)
# 再次保存,但是host不同
new_host = nanoid.generate()
P.save_key("user", new_password, host=new_host)
nrc = netrc.netrc(path)
assert len(nrc.hosts) == 1
assert nrc.authenticators(new_host) is not None
time.sleep(0.1)
# 再次保存,但是密码不同
new_password = nanoid.generate()
P.save_key("user", new_password, host=new_host)
nrc = netrc.netrc(path)
assert len(nrc.hosts) == 1
assert nrc.authenticators(new_host)[2] == new_password


class TestIsLogin:
Expand Down
Loading