Skip to content

Commit 7c3b7f1

Browse files
authored
feat: update xgboost integration (#1053)
1 parent fa18791 commit 7c3b7f1

File tree

1 file changed

+23
-1
lines changed

1 file changed

+23
-1
lines changed

swanlab/integration/xgboost.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,13 @@
1010

1111

1212
class SwanLabCallback(xgb.callback.TrainingCallback):
13-
def __init__(self):
13+
def __init__(
14+
self,
15+
log_feature_importance: bool = True,
16+
importance_type: str = "gain",
17+
):
18+
self.log_feature_importance = log_feature_importance
19+
self.importance_type = importance_type
1420
# 如果没有注册过实验
1521
swanlab.config["FRAMEWORK"] = "xgboost"
1622
if swanlab.get_run() is None:
@@ -30,6 +36,9 @@ def before_training(self, model: Booster) -> Booster:
3036
def after_training(self, model: Booster) -> Booster:
3137
"""Run after training is finished."""
3238

39+
if self.log_feature_importance:
40+
self._log_feature_importance(model)
41+
3342
# Log the best score and best iteration
3443
if model.attr("best_score") is not None:
3544
swanlab.log(
@@ -52,3 +61,16 @@ def after_iteration(self, model: Booster, epoch: int, evals_log: dict) -> bool:
5261

5362
return False
5463

64+
def _log_feature_importance(self, model: Booster) -> None:
65+
fi = model.get_score(importance_type=self.importance_type)
66+
x = list(fi.keys())
67+
y = list(fi.values())
68+
y = [round(i, 2) for i in y] # 保留两位小数
69+
bar = swanlab.echarts.Bar()
70+
bar.add_xaxis(x)
71+
bar.add_yaxis("Importance", y)
72+
swanlab.log(
73+
{
74+
"Feature Importance": bar
75+
}
76+
)

0 commit comments

Comments
 (0)