Skip to content

Commit b202083

Browse files
committed
update:增加模型预热机制,解决第一次推理慢的问题
1 parent 236095b commit b202083

15 files changed

+412
-106
lines changed

app/api/dependencies.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from fastapi.security import HTTPBearer
33
from typing import Annotated
44
from ..core.security import verify_token
5-
from ..core.logging import get_logger
5+
from ..core.logger import get_logger
66

77
logger = get_logger(__name__)
88

app/api/v1/health.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from fastapi import APIRouter, HTTPException, Query
22
import time
33
from ...services.voiceprint_service import voiceprint_service
4-
from ...core.logging import get_logger
4+
from ...core.logger import get_logger
55
from ...core.config import settings
66

77
logger = get_logger(__name__)
@@ -31,7 +31,7 @@ async def health_check(
3131
HTTPException: 当密钥不正确时返回401错误
3232
"""
3333
start_time = time.time()
34-
logger.info("收到健康检查请求")
34+
logger.start("健康检查请求")
3535

3636
# 验证密钥
3737
key_check_start = time.time()
@@ -49,9 +49,9 @@ async def health_check(
4949
logger.info(f"声纹统计信息获取完成,总数: {count},耗时: {count_time:.3f}秒")
5050

5151
total_time = time.time() - start_time
52-
logger.info(f"健康检查请求完成,总耗时: {total_time:.3f}秒")
52+
logger.complete("健康检查请求", total_time)
5353
return {"total_voiceprints": count, "status": "healthy"}
5454
except Exception as e:
5555
total_time = time.time() - start_time
56-
logger.error(f"获取统计信息异常,总耗时: {total_time:.3f}秒,错误: {e}")
56+
logger.fail(f"获取统计信息异常,总耗时: {total_time:.3f}秒,错误: {e}")
5757
raise HTTPException(status_code=500, detail=f"获取统计信息失败: {str(e)}")

app/api/v1/voiceprint.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from ...models.voiceprint import VoiceprintRegisterResponse, VoiceprintIdentifyResponse
66
from ...services.voiceprint_service import voiceprint_service
77
from ...api.dependencies import AuthorizationToken
8-
from ...core.logging import get_logger
8+
from ...core.logger import get_logger
99

1010
# 创建安全模式
1111
security = HTTPBearer(description="接口令牌")
@@ -57,7 +57,7 @@ async def register_voiceprint(
5757
except HTTPException:
5858
raise
5959
except Exception as e:
60-
logger.error(f"声纹注册异常: {e}")
60+
logger.fail(f"声纹注册异常: {e}")
6161
raise HTTPException(status_code=500, detail=f"声纹注册失败: {str(e)}")
6262

6363

app/application.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
1-
from fastapi import FastAPI
1+
from fastapi import FastAPI, Request
22
from fastapi.security import HTTPBearer
33
from fastapi.middleware.cors import CORSMiddleware
44
from fastapi.responses import RedirectResponse
55
from fastapi.openapi.docs import get_swagger_ui_html, get_redoc_html
66
from fastapi.openapi.utils import get_openapi
77

88
from .api.v1.api import api_router
9+
from loguru import logger
10+
from .core.version import VERSION
11+
import time
912

1013

1114
def create_app() -> FastAPI:
@@ -18,7 +21,7 @@ def create_app() -> FastAPI:
1821
app = FastAPI(
1922
title="3D-Speaker 声纹识别API",
2023
description="基于3D-Speaker的声纹注册与识别服务",
21-
version="2.0.0",
24+
version=VERSION,
2225
docs_url=None, # 禁用默认的docs路径
2326
redoc_url=None, # 禁用默认的redoc路径
2427
)

app/core/logger.py

Lines changed: 275 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,275 @@
1+
"""
2+
日志管理模块 - 统一的日志配置和记录器
3+
"""
4+
5+
import os
6+
import sys
7+
import logging
8+
import warnings
9+
from typing import Optional
10+
from loguru import logger
11+
from .config import settings
12+
from .version import VERSION
13+
14+
# 移除默认的loguru处理器
15+
logger.remove()
16+
17+
18+
class LoggingHandler(logging.Handler):
19+
"""拦截logging日志并转发到loguru"""
20+
21+
def emit(self, record):
22+
# 获取对应的loguru level
23+
try:
24+
level = logger.level(record.levelname).name
25+
except ValueError:
26+
level = record.levelno
27+
28+
# 获取logger名称
29+
logger_name = record.name
30+
if logger_name.startswith("uvicorn"):
31+
logger_name = "uvicorn"
32+
elif logger_name.startswith("fastapi"):
33+
logger_name = "fastapi"
34+
elif logger_name.startswith("modelscope"):
35+
logger_name = "modelscope"
36+
elif logger_name.startswith("torch"):
37+
logger_name = "torch"
38+
elif logger_name.startswith("pydantic"):
39+
logger_name = "pydantic"
40+
elif logger_name.startswith("app."):
41+
logger_name = record.name
42+
43+
# 转发到loguru
44+
logger.opt(exception=record.exc_info).bind(
45+
name=logger_name, version=VERSION
46+
).log(level, record.getMessage())
47+
48+
49+
class WarningHandler:
50+
"""拦截warnings并转发到loguru"""
51+
52+
def __init__(self):
53+
self.original_showwarning = warnings.showwarning
54+
55+
def showwarning(self, message, category, filename, lineno, file=None, line=None):
56+
# 转发到loguru
57+
logger.bind(name="warnings", version=VERSION).warning(
58+
f"{category.__name__}: {message}"
59+
)
60+
61+
62+
class StderrHandler:
63+
"""拦截stderr输出"""
64+
65+
def __init__(self):
66+
self.original_stderr = sys.stderr
67+
68+
def write(self, text):
69+
if text.strip(): # 忽略空行
70+
# 尝试解析uvicorn格式的日志
71+
if (
72+
text.startswith("INFO:")
73+
or text.startswith("WARNING:")
74+
or text.startswith("ERROR:")
75+
):
76+
# 这是uvicorn格式,转换为我们的格式
77+
parts = text.strip().split(":", 1)
78+
if len(parts) == 2:
79+
level = parts[0].strip()
80+
message = parts[1].strip()
81+
logger.bind(name="uvicorn", version=VERSION).info(message)
82+
else:
83+
# 其他stderr输出
84+
logger.bind(name="stderr", version=VERSION).warning(text.strip())
85+
86+
def flush(self):
87+
self.original_stderr.flush()
88+
89+
90+
def setup_logging(level: Optional[str] = None) -> None:
91+
"""
92+
设置应用日志配置,使用loguru实现优雅的分段颜色显示
93+
94+
格式: 时间[青色] 版本号[蓝色] 模块[灰色]-级别[彩色]-消息[绿色]
95+
示例: 250705 13:33:23[0.6.2][core.utils.modules_initialize]-INFO-初始化组件: intent成功
96+
97+
Args:
98+
level: 日志级别
99+
"""
100+
# 获取配置
101+
log_level = level or settings.logging.get("level", "INFO")
102+
103+
# 确保日志目录存在
104+
log_dir = "logs"
105+
os.makedirs(log_dir, exist_ok=True)
106+
107+
# 控制台输出格式 - 分段颜色显示
108+
console_format = (
109+
"<cyan>{time:YYMMDD HH:mm:ss}</cyan>"
110+
"<blue>[{extra[version]}]</blue>"
111+
"<light-black>[{name}]</light-black>-"
112+
"<level>{level}</level>-"
113+
"<green>{message}</green>"
114+
)
115+
116+
# 文件输出格式 - 无颜色,保持相同格式
117+
file_format = (
118+
"{time:YYMMDD HH:mm:ss}" "[{extra[version]}]" "[{name}]-" "{level}-" "{message}"
119+
)
120+
121+
# 添加控制台处理器
122+
logger.add(
123+
sys.stdout,
124+
format=console_format,
125+
level=log_level,
126+
colorize=True,
127+
backtrace=True,
128+
diagnose=True,
129+
enqueue=True,
130+
)
131+
132+
# 添加文件处理器
133+
logger.add(
134+
os.path.join(log_dir, "voiceprint_api.log"),
135+
format=file_format,
136+
level=log_level,
137+
rotation="10 MB",
138+
retention="7 days",
139+
compression="gz",
140+
encoding="utf-8",
141+
backtrace=True,
142+
diagnose=True,
143+
enqueue=True,
144+
)
145+
146+
# 拦截所有logging日志
147+
# 1. 移除root logger的所有handler
148+
for handler in logging.root.handlers[:]:
149+
logging.root.removeHandler(handler)
150+
151+
# 2. 设置root logger只使用我们的handler
152+
logging.basicConfig(handlers=[LoggingHandler()], level=0, force=True)
153+
154+
# 3. 强制替换所有已存在的logger的handler
155+
intercept_handler = LoggingHandler()
156+
for name in logging.root.manager.loggerDict:
157+
log = logging.getLogger(name)
158+
# 移除所有现有handler
159+
for handler in log.handlers[:]:
160+
log.removeHandler(handler)
161+
# 添加我们的handler
162+
log.addHandler(intercept_handler)
163+
# 设置propagate为False,防止重复输出
164+
log.propagate = False
165+
166+
# 设置第三方库的日志级别
167+
logger.bind(version=VERSION).info(f"日志系统初始化完成,级别: {log_level}")
168+
169+
170+
class Logger:
171+
"""优雅的日志记录器 - 基于loguru"""
172+
173+
def __init__(self, name: str):
174+
self._name = name
175+
# 直接使用loguru的logger,绑定模块名和版本
176+
self._logger = logger.bind(name=name, version=VERSION)
177+
178+
def debug(self, message: str, *args, **kwargs):
179+
"""调试日志"""
180+
self._logger.debug(message, *args, **kwargs)
181+
182+
def info(self, message: str, *args, **kwargs):
183+
"""信息日志"""
184+
self._logger.info(message, *args, **kwargs)
185+
186+
def warning(self, message: str, *args, **kwargs):
187+
"""警告日志"""
188+
self._logger.warning(message, *args, **kwargs)
189+
190+
def error(self, message: str, *args, **kwargs):
191+
"""错误日志"""
192+
self._logger.error(message, *args, **kwargs)
193+
194+
def critical(self, message: str, *args, **kwargs):
195+
"""严重错误日志"""
196+
self._logger.critical(message, *args, **kwargs)
197+
198+
def success(self, message: str, *args, **kwargs):
199+
"""成功日志(使用INFO级别但语义更清晰)"""
200+
self._logger.info(f"✅ {message}", *args, **kwargs)
201+
202+
def fail(self, message: str, *args, **kwargs):
203+
"""失败日志(使用ERROR级别但语义更清晰)"""
204+
self._logger.error(f"❌ {message}", *args, **kwargs)
205+
206+
def start(self, operation: str, *args, **kwargs):
207+
"""开始操作日志"""
208+
self._logger.info(f"🚀 开始: {operation}", *args, **kwargs)
209+
210+
def complete(
211+
self, operation: str, duration: Optional[float] = None, *args, **kwargs
212+
):
213+
"""完成操作日志"""
214+
if duration is not None:
215+
self._logger.info(
216+
f"✅ 完成: {operation} (耗时: {duration:.3f}秒)", *args, **kwargs
217+
)
218+
else:
219+
self._logger.info(f"✅ 完成: {operation}", *args, **kwargs)
220+
221+
def init_component(
222+
self, component_name: str, status: str = "成功", *args, **kwargs
223+
):
224+
"""组件初始化日志"""
225+
if status.lower() in ["成功", "success", "ok"]:
226+
self._logger.info(
227+
f"🔧 初始化组件: {component_name} {status}", *args, **kwargs
228+
)
229+
else:
230+
self._logger.error(
231+
f"🔧 初始化组件: {component_name} {status}", *args, **kwargs
232+
)
233+
234+
235+
def get_logger(name: str) -> Logger:
236+
"""
237+
获取优雅的日志记录器
238+
239+
Args:
240+
name: 日志记录器名称
241+
242+
Returns:
243+
Logger: 日志记录器实例
244+
"""
245+
return Logger(name)
246+
247+
248+
# 便捷函数
249+
def log_success(message: str, logger_name: str = __name__):
250+
"""记录成功日志"""
251+
get_logger(logger_name).success(message)
252+
253+
254+
def log_fail(message: str, logger_name: str = __name__):
255+
"""记录失败日志"""
256+
get_logger(logger_name).fail(message)
257+
258+
259+
def log_start(operation: str, logger_name: str = __name__):
260+
"""记录开始操作"""
261+
get_logger(logger_name).start(operation)
262+
263+
264+
def log_complete(
265+
operation: str, duration: Optional[float] = None, logger_name: str = __name__
266+
):
267+
"""记录完成操作"""
268+
get_logger(logger_name).complete(operation, duration)
269+
270+
271+
def log_init_component(
272+
component_name: str, status: str = "成功", logger_name: str = __name__
273+
):
274+
"""记录组件初始化"""
275+
get_logger(logger_name).init_component(component_name, status)

0 commit comments

Comments
 (0)