Skip to content

Commit a0616f1

Browse files
authored
add ability to set multi-valued booster params (#164)
1 parent 22a0239 commit a0616f1

File tree

3 files changed

+13
-1
lines changed

3 files changed

+13
-1
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "XGBoost"
22
uuid = "009559a3-9522-5dbb-924b-0b6ed2b22bb9"
3-
version = "2.2.2"
3+
version = "2.2.3"
44

55
[deps]
66
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"

src/booster.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,15 @@ end
8181
setparam!(b::Booster, name::AbstractString, val) = setparam!(b, name, string(val))
8282
setparam!(b::Booster, name::Symbol, val) = setparam!(b, string(name), val)
8383

84+
setmultiparams!(b::Booster, name::Union{Symbol,AbstractString}, vals) = foreach(v -> setparam!(b, name, v), vals)
85+
86+
# the API for some parameters involves multiple separate calls to XGBoosterSetParam
87+
# multi methods for resolving ambiguities
88+
setparam!(b::Booster, name::Symbol, vals::AbstractVector) = setmultiparams!(b, name, vals)
89+
setparam!(b::Booster, name::AbstractString, vals::AbstractVector) = setmultiparams!(b, name, vals)
90+
setparam!(b::Booster, name::Symbol, vals::Tuple) = setmultiparams!(b, name, vals)
91+
setparam!(b::Booster, name::AbstractString, vals::Tuple) = setmultiparams!(b, name, vals)
92+
8493
"""
8594
setparams!(b::Booster; kw...)
8695

test/runtests.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,8 @@ end
101101
watchlist=watchlist,
102102
η=1, max_depth=2,
103103
objective="binary:logistic",
104+
# check that we can set multiple param values
105+
eval_metric=["rmse", "rmsle"],
104106
)
105107
end
106108

@@ -171,6 +173,7 @@ end
171173
η=1.0, max_depth=2,
172174
objective="binary:logistic",
173175
watchlist=Dict(),
176+
eval_metric=("mae", "mape"),
174177
)
175178
preds = predict(bst, dtest)
176179
XGBoost.save(bst, model_file)

0 commit comments

Comments
 (0)