Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
swankit==0.1.3
cos-python-sdk-v5
urllib3>=1.26.0
requests>=2.25.0
setuptools
click
pyyaml
psutil>=5.0.0
pynvml
boto3>=1.35.49
botocore
49 changes: 26 additions & 23 deletions swanlab/api/cos.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,15 @@
@File: cos.py
@IDE: pycharm
@Description:
tencent cos
cloud object storage
"""
# noinspection PyPackageRequirements
from qcloud_cos import CosConfig
# noinspection PyPackageRequirements
from qcloud_cos import CosS3Client
# noinspection PyPackageRequirements
from qcloud_cos.cos_threadpool import SimpleThreadPool
from concurrent.futures.thread import ThreadPoolExecutor
from datetime import datetime, timedelta
from typing import List, Dict, Union
from typing import List

import boto3
from botocore.config import Config as BotocoreConfig

from swanlab.data.modules import MediaBuffer
from swanlab.log import swanlog

Expand All @@ -30,14 +29,20 @@ def __init__(self, data):
self.__prefix = data["prefix"]
self.__bucket = data["bucket"]
credentials = data["credentials"]
config = CosConfig(
Region=data["region"],
SecretId=credentials['tmpSecretId'],
SecretKey=credentials['tmpSecretKey'],
Token=credentials['sessionToken'],
Scheme='https'

# 往期版本适配
# TODO 后续删除
endpoint_url = f"https://cos.{data['region']}.myqcloud.com" if "endpoint" not in data else data["endpoint"]

self.__client = boto3.client(
's3',
endpoint_url=endpoint_url,
api_version='2006-03-01',
aws_access_key_id=credentials['tmpSecretId'],
aws_secret_access_key=credentials['tmpSecretKey'],
aws_session_token=credentials['sessionToken'],
config=BotocoreConfig(signature_version="s3", s3={'addressing_style': 'virtual'}),
)
self.__client = CosS3Client(config)

def upload(self, buffer: MediaBuffer):
"""
Expand All @@ -52,28 +57,26 @@ def upload(self, buffer: MediaBuffer):
Bucket=self.__bucket,
Key=key,
Body=buffer.getvalue(),
EnableMD5=False,
# 一年
CacheControl="max-age=31536000",
)
except Exception as e:
swanlog.error("Upload error: {}".format(e))

def upload_files(self, buffers: List[MediaBuffer]) -> Dict[str, Union[bool, List]]:
def upload_files(self, buffers: List[MediaBuffer]):
"""
批量上传文件,keys和local_paths的长度应该相等
:param buffers: 本地文件的二进制对象集合
"""
pool = SimpleThreadPool()
for buffer in buffers:
self.upload(buffer)
pool.wait_completion()
result = pool.get_result()
return result
with ThreadPoolExecutor(max_workers=10) as executor:
futures = [executor.submit(self.upload, buffer) for buffer in buffers]
for future in futures:
future.result()

@property
def should_refresh(self):
# cos传递的是北京时间,需要添加8小时
# FIXME Use timezone-aware objects to represent datetimes in UTC; e.g. by calling .now(datetime.UTC)
now = datetime.utcnow() + timedelta(hours=8)
# 过期时间减去当前时间小于刷新时间,需要注意为负数的情况
if self.__expired_time < now:
Expand Down
Loading