Skip to content

Commit 016b2d2

Browse files
authored
feat: integration xgboost (#745)
1 parent 8a692c7 commit 016b2d2

File tree

2 files changed

+105
-0
lines changed

2 files changed

+105
-0
lines changed

swanlab/integration/xgboost.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
import json
2+
from typing import cast
3+
import xgboost as xgb # type: ignore
4+
from xgboost import Booster
5+
import swanlab
6+
7+
8+
class SwanLabCallback(xgb.callback.TrainingCallback):
9+
def __init__(self):
10+
# 如果没有注册过实验
11+
if swanlab.get_run() is None:
12+
raise RuntimeError("You must call swanlab.init() before SwanLabCallback(). 你必须在SwanLabCallback()之前,调用swanlab.init().")
13+
14+
def before_training(self, model: Booster) -> Booster:
15+
"""Run before training is finished."""
16+
# Update SwanLab config
17+
config = model.save_config()
18+
swanlab.config.update(json.loads(config))
19+
20+
return model
21+
22+
def after_training(self, model: Booster) -> Booster:
23+
"""Run after training is finished."""
24+
25+
# Log the best score and best iteration
26+
if model.attr("best_score") is not None:
27+
swanlab.log(
28+
{
29+
"best_score": float(cast(str, model.attr("best_score"))),
30+
"best_iteration": int(cast(str, model.attr("best_iteration"))),
31+
}
32+
)
33+
34+
return model
35+
36+
def after_iteration(self, model: Booster, epoch: int, evals_log: dict) -> bool:
37+
"""Run after each iteration. Return True when training should stop."""
38+
# Log metrics
39+
for data, metric in evals_log.items():
40+
for metric_name, log in metric.items():
41+
swanlab.log({f"{data}-{metric_name}": log[-1]})
42+
43+
swanlab.log({"epoch": epoch})
44+
45+
return False
46+
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
import xgboost as xgb
2+
from sklearn.datasets import load_breast_cancer
3+
from sklearn.model_selection import train_test_split
4+
from sklearn.metrics import accuracy_score, classification_report
5+
import swanlab
6+
from swanlab.integration.xgboost import SwanLabCallback
7+
8+
# 初始化swanlab
9+
swanlab.init(project="xgboost-breast-cancer", config={
10+
"learning_rate": 0.1,
11+
"max_depth": 3,
12+
"subsample": 0.8,
13+
"colsample_bytree": 0.8,
14+
"num_round": 100
15+
})
16+
17+
# 加载数据集
18+
data = load_breast_cancer()
19+
X = data.data
20+
y = data.target
21+
22+
# 将数据集分为训练集和测试集
23+
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
24+
25+
# 转换为DMatrix格式,这是XGBoost的内部数据格式
26+
dtrain = xgb.DMatrix(X_train, label=y_train)
27+
dtest = xgb.DMatrix(X_test, label=y_test)
28+
29+
# 设置参数
30+
params = {
31+
'objective': 'binary:logistic', # 二分类任务
32+
'max_depth': 3, # 树的最大深度
33+
'eta': 0.1, # 学习率
34+
'subsample': 0.8, # 样本采样比例
35+
'colsample_bytree': 0.8, # 特征采样比例
36+
'eval_metric': 'logloss' # 评估指标
37+
}
38+
39+
# 训练模型
40+
num_round = 100 # 迭代次数
41+
bst = xgb.train(params, dtrain, num_round, evals=[(dtrain, 'train'), (dtest, 'test')], callbacks=[SwanLabCallback()])
42+
43+
# 进行预测
44+
y_pred = bst.predict(dtest)
45+
y_pred_binary = [round(value) for value in y_pred] # 将概率转换为二分类结果
46+
47+
# 评估模型
48+
accuracy = accuracy_score(y_test, y_pred_binary)
49+
print(f"Accuracy: {accuracy:.4f}")
50+
51+
# 打印分类报告
52+
print("Classification Report:")
53+
print(classification_report(y_test, y_pred_binary, target_names=data.target_names))
54+
55+
# 保存模型
56+
bst.save_model('xgboost_model.model')
57+
58+
# 结束swanlab会话
59+
swanlab.finish()

0 commit comments

Comments
 (0)