Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 2 additions & 17 deletions swanlab/api/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,10 @@
from typing import Any, Callable, List, Optional, Union

import requests
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry

from swanlab.api.types import ApiResponse
from swanlab.core_python import auth
from swanlab.core_python import auth, create_session
from swanlab.log.log import SwanLog
from swanlab.package import get_package_version

_logger: Optional[SwanLog] = None

Expand Down Expand Up @@ -94,19 +91,7 @@ def sid_expired_at(self):
return datetime.strptime(self.__login_info.expired_at or "", "%Y-%m-%dT%H:%M:%S.%fZ")

def __init_session(self) -> requests.Session:
session = requests.Session()
session.mount(
prefix="https://",
adapter=HTTPAdapter(
max_retries=Retry(
total=3,
backoff_factor=0.1,
status_forcelist=[500, 502, 503, 504],
allowed_methods=(["GET", "POST", "PUT", "DELETE", "PATCH"])
)
)
)
session.headers["swanlab-sdk"] = get_package_version()
session = create_session()
session.cookies.update({"sid": self.__login_info.sid or ""})
return session

Expand Down
1 change: 1 addition & 0 deletions swanlab/core_python/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@
"""

from .client import *
from .session import create_session
7 changes: 4 additions & 3 deletions swanlab/core_python/auth/providers/api_key.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@
from rich.status import Status
from rich.text import Text

from swanlab.core_python import auth
from swanlab.env import is_windows, is_interactive
from swanlab.error import ValidationError, APIKeyFormatError, KeyFileError
from swanlab.log import swanlog
from swanlab.package import get_setting_url, get_host_api, get_host_web, fmt_web_host, save_key as sk, get_key
from ...session import create_session


class LoginInfo:
Expand Down Expand Up @@ -102,7 +102,8 @@ def save(self):

def login_request(api_key: str, api_host: str, timeout: int = 20) -> requests.Response:
"""用户登录,请求后端接口完成验证"""
resp = requests.post(url=f"{api_host}/login/api_key", headers={"authorization": api_key}, timeout=timeout)
session = create_session()
resp = session.post(url=f"{api_host}/login/api_key", headers={"authorization": api_key}, timeout=timeout)
return resp


Expand Down Expand Up @@ -220,7 +221,7 @@ def create_login_info(save: bool = True):
raise KeyFileError(
"api key not configured (no-tty), call `swanlab.login(api_key=[your_api_key])` or set `swanlab.init(mode=\"local\")`."
)
return auth.terminal_login(key, save)
return terminal_login(key, save)


def _abort_tip(tp, _, __):
Expand Down
15 changes: 3 additions & 12 deletions swanlab/core_python/client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@
from typing import Optional, Tuple, Dict, Union, List, AnyStr

import requests
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry

from swanlab.error import NetworkError, ApiError
from swanlab.log import swanlog
Expand All @@ -20,6 +18,7 @@
from .cos import CosClient
from .model import ProjectInfo, ExperimentInfo
from .. import auth
from ..session import create_session


def decode_response(resp: requests.Response) -> Union[Dict, AnyStr, List]:
Expand Down Expand Up @@ -164,16 +163,7 @@ def __create_session(self):
创建会话,这将在HTTP类实例化时调用
添加了重试策略
"""
session = requests.Session()
retry = Retry(
total=3,
backoff_factor=0.1,
status_forcelist=[429, 500, 502, 503, 504],
allowed_methods=frozenset(["GET", "POST", "PUT", "DELETE", "PATCH"]),
)
adapter = HTTPAdapter(max_retries=retry)
session.mount("https://", adapter)

session = create_session()
session.headers["swanlab-sdk"] = self.__version
session.cookies.update({"sid": self.__login_info.sid})

Expand Down Expand Up @@ -457,6 +447,7 @@ def wrapper(*args, **kwargs) -> Tuple[Optional[Union[dict, str]], Optional[Excep
__all__ = [
"get_client",
"reset_client",
"create_session",
"create_client",
"sync_error_handler",
"decode_response",
Expand Down
31 changes: 31 additions & 0 deletions swanlab/core_python/session.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
"""
@author: cunyue
@file: session.py
@time: 2025/9/9 15:10
@description: 创建会话
"""

import requests
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry

from swanlab.package import get_package_version


def create_session() -> requests.Session:
"""
创建一个带重试机制的会话
:return: requests.Session
"""
session = requests.Session()
retry = Retry(
total=5,
backoff_factor=0.5,
status_forcelist=[429, 500, 502, 503, 504],
allowed_methods=frozenset(["GET", "POST", "PUT", "DELETE", "PATCH"]),
)
adapter = HTTPAdapter(max_retries=retry)
session.mount("https://", adapter)
session.mount("http://", adapter)
session.headers["swanlab-sdk"] = get_package_version()
return session
24 changes: 0 additions & 24 deletions test/unit/core_python/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@
import nanoid
import pytest
import requests_mock
import responses
from responses import registries

from swanlab.core_python import create_client, Client, CosClient, reset_client
from swanlab.core_python.auth import login_by_key
Expand Down Expand Up @@ -75,28 +73,6 @@ def test_decode_response():
assert data == "test"


@responses.activate(registry=registries.OrderedRegistry)
def test_retry():
"""
测试重试机制
"""
from swanlab.package import get_host_api

url = get_host_api() + "/retry"
rsp1 = responses.get(url, body="Error", status=500)
rsp2 = responses.get(url, body="Error", status=500)
rsp3 = responses.get(url, body="Error", status=500)
rsp4 = responses.get(url, body="OK", status=200)
with UseMockRunState() as run_state:
client = run_state.client
data, _ = client.get("/retry")
assert data == "OK"
assert rsp1.call_count == 1
assert rsp2.call_count == 1
assert rsp3.call_count == 1
assert rsp4.call_count == 1


@pytest.mark.skipif(is_skip_cloud_test, reason="skip cloud test")
class TestCosSuite:
http: Client = None
Expand Down
103 changes: 103 additions & 0 deletions test/unit/core_python/test_session.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
"""
@author: cunyue
@file: test_session.py
@time: 2025/9/9 15:12
@description: $END$
"""

import pytest
import responses
from responses import registries

from swanlab.core_python import create_session
from swanlab.package import get_package_version


@pytest.mark.parametrize("url", ["https://api.example.com/retry", "http://api.example.com/retry"])
@responses.activate(registry=registries.OrderedRegistry)
def test_retry(url):
"""
测试重试机制
"""

[responses.add(responses.GET, url, body="Error", status=500) for _ in range(5)]
responses.add(responses.GET, url, body="Success", status=200)
s = create_session()
resp = s.get(url)
assert resp.text == "Success"
assert len(responses.calls) == 6


@responses.activate(registry=registries.OrderedRegistry)
def test_session_headers():
"""
测试会话是否包含正确的自定义请求头
"""
# 1. 准备测试数据
test_url = "https://api.example.com/test"
expected_sdk_version = get_package_version()

# 2. 模拟响应 - 捕获请求头
captured_headers = {}

def request_callback(request):
# 捕获所有请求头
nonlocal captured_headers
captured_headers = dict(request.headers)
return (200, {}, "OK")

responses.add_callback(responses.GET, test_url, callback=request_callback)

# 3. 创建会话并发送请求
session = create_session()
response = session.get(test_url)

# 4. 验证
assert response.status_code == 200

# 验证自定义头存在且值正确
assert "swanlab-sdk" in captured_headers
assert captured_headers["swanlab-sdk"] == expected_sdk_version

# 验证User-Agent等默认头也存在(可选)
assert "User-Agent" in captured_headers

# 打印所有捕获的请求头(调试用)
print("\n捕获的请求头:", captured_headers)


@responses.activate(registry=registries.OrderedRegistry)
def test_header_merging():
"""
测试请求级别headers与会话级别headers的合并
"""
test_url = "https://api.example.com/merge"
custom_header = {"X-Custom-Request-Header": "test-value"}

captured_headers = {}

def request_callback(request):
nonlocal captured_headers
captured_headers = dict(request.headers)
return (200, {}, "OK")

responses.add_callback(responses.GET, test_url, callback=request_callback)

# 创建会话(自带swanlab-sdk头)
session = create_session()

# 发送带额外请求头的请求
response = session.get(test_url, headers=custom_header)

# 验证
assert response.status_code == 200

# 验证会话头依然存在
assert "swanlab-sdk" in captured_headers

# 验证请求级别头已添加
assert "X-Custom-Request-Header" in captured_headers
assert captured_headers["X-Custom-Request-Header"] == "test-value"

# 验证合并而非覆盖(两个头都存在)
assert len(captured_headers) >= 2