46 changes: 46 additions & 0 deletions swanlab/integration/xgboost.py
@@ -0,0 +1,46 @@
import json
from typing import cast
import xgboost as xgb # type: ignore
from xgboost import Booster
import swanlab


class SwanLabCallback(xgb.callback.TrainingCallback):
    def __init__(self):
        # If no experiment has been registered yet, fail fast
        if swanlab.get_run() is None:
            raise RuntimeError("You must call swanlab.init() before SwanLabCallback().")

    def before_training(self, model: Booster) -> Booster:
        """Run before training starts."""
        # Update the SwanLab config with the model's configuration
        config = model.save_config()
        swanlab.config.update(json.loads(config))

        return model

    def after_training(self, model: Booster) -> Booster:
        """Run after training is finished."""

        # Log the best score and best iteration, if they were set
        if model.attr("best_score") is not None:
            swanlab.log(
                {
                    "best_score": float(cast(str, model.attr("best_score"))),
                    "best_iteration": int(cast(str, model.attr("best_iteration"))),
                }
            )

        return model

    def after_iteration(self, model: Booster, epoch: int, evals_log: dict) -> bool:
        """Run after each iteration. Return True when training should stop."""
        # Log the latest value of each metric for every evaluation dataset
        for data, metric in evals_log.items():
            for metric_name, log in metric.items():
                swanlab.log({f"{data}-{metric_name}": log[-1]})

        swanlab.log({"epoch": epoch})

        return False
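
For reference, a minimal sketch of the evals_log structure that after_iteration receives (values are illustrative; the dataset names come from the evals argument passed to xgb.train), showing which keys the callback would log:

# Illustrative evals_log after three boosting rounds (made-up values)
evals_log = {
    "train": {"logloss": [0.62, 0.55, 0.49]},
    "test": {"logloss": [0.64, 0.58, 0.53]},
}
# after_iteration logs the latest value of each series under "<dataset>-<metric>"
# keys, e.g. {"train-logloss": 0.49} and {"test-logloss": 0.53}, plus {"epoch": epoch}.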

59 changes: 59 additions & 0 deletions test/integration/xgboost/train_xgboost.py
@@ -0,0 +1,59 @@
import xgboost as xgb
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report
import swanlab
from swanlab.integration.xgboost import SwanLabCallback

# Initialize swanlab
swanlab.init(project="xgboost-breast-cancer", config={
    "learning_rate": 0.1,
    "max_depth": 3,
    "subsample": 0.8,
    "colsample_bytree": 0.8,
    "num_round": 100
})

# Load the dataset
data = load_breast_cancer()
X = data.data
y = data.target

# Split the dataset into training and test sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Convert to DMatrix, XGBoost's internal data format
dtrain = xgb.DMatrix(X_train, label=y_train)
dtest = xgb.DMatrix(X_test, label=y_test)

# Set the training parameters
params = {
    'objective': 'binary:logistic',  # binary classification task
    'max_depth': 3,                  # maximum tree depth
    'eta': 0.1,                      # learning rate
    'subsample': 0.8,                # row subsampling ratio
    'colsample_bytree': 0.8,         # feature subsampling ratio per tree
    'eval_metric': 'logloss'         # evaluation metric
}

# Train the model
num_round = 100  # number of boosting rounds
bst = xgb.train(params, dtrain, num_round, evals=[(dtrain, 'train'), (dtest, 'test')], callbacks=[SwanLabCallback()])
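
Note that after_training only logs best_score and best_iteration when those attributes are set, which the call above never triggers. A minimal sketch of a variant that does, by adding early stopping (early_stopping_rounds is a standard xgb.train parameter; the value 10 is an arbitrary choice for illustration):

# Variant of the training call with early stopping, so that
# best_score / best_iteration get populated and logged by the callback
bst = xgb.train(
    params, dtrain, num_round,
    evals=[(dtrain, 'train'), (dtest, 'test')],
    early_stopping_rounds=10,
    callbacks=[SwanLabCallback()],
)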

# Make predictions
y_pred = bst.predict(dtest)
y_pred_binary = [round(value) for value in y_pred]  # convert probabilities to binary class labels

# Evaluate the model
accuracy = accuracy_score(y_test, y_pred_binary)
print(f"Accuracy: {accuracy:.4f}")

# Print the classification report
print("Classification Report:")
print(classification_report(y_test, y_pred_binary, target_names=data.target_names))

# Save the model
bst.save_model('xgboost_model.model')

# Finish the swanlab run
swanlab.finish()
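
For completeness, a minimal sketch of loading the saved model back for inference (assuming the X_test split from the script above is still in scope; uses only the standard XGBoost Booster API):

# Reload the saved booster and run inference on the held-out split
loaded_bst = xgb.Booster()
loaded_bst.load_model('xgboost_model.model')
probs = loaded_bst.predict(xgb.DMatrix(X_test))
print(probs[:5])  # predicted probabilities for the first five samples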