Skip to content

Commit 6408b80

Browse files
berndbischlvrodriguezf
authored andcommitted
Make sure that optimized hyperparameters are applied in the performance level of a CV (mlr-org#2479)
1 parent bb46079 commit 6408b80

File tree

3 files changed

+44
-1
lines changed

3 files changed

+44
-1
lines changed

NEWS.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,9 @@ In this case, the package name is omitted.
5858
* regr.h2o.gbm: Various parameters added, `"h2o.use.data.table" = TRUE` is now the default (@j-hartshorn, #2508)
5959
* h2o learners now support getting feature importance (@markusdumke, #2434)
6060

61+
## learners - fixes
62+
* In some cases the optimized hyperparameters were not applied in the performance level of a nested CV (@berndbischl, #2479)
63+
6164
## featSel - general
6265
* The FeatSelResult object now contains an additional slot `x.bit.names` that stores the optimal bits
6366
* The slot `x` now always contains the real feature names and not the bit.names

R/TuneWrapper.R

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,8 +74,17 @@ trainLearner.TuneWrapper = function(.learner, .task, .subset = NULL, ...) {
7474

7575
#' @export
predictLearner.TuneWrapper = function(.learner, .model, .newdata, ...) {
  # setHyperPars here is just for completeness; the hyperparameters that are
  # actually applied at predict time are the ones passed through arglist below.
  lrn = setHyperPars(.learner$next.learner, par.vals = .model$learner.model$opt.result$x)
  arglist = list(.learner = lrn, .model = .model$learner.model$next.model, .newdata = .newdata)
  arglist = insert(arglist, list(...))

  # Take x from the opt result and keep only those params that are relevant
  # for prediction (when == "predict" or "both").
  opt.x = .model$learner.model$opt.result$x
  ps = getParamSet(lrn)
  ns = Filter(function(x) ps$pars[[x]]$when %in% c("both", "predict"), getParamIds(ps))
  # Only index opt.x with ids that were actually tuned: subsetting a list with
  # missing names would create NULL entries named NA and corrupt the arglist.
  arglist = insert(arglist, opt.x[intersect(ns, names(opt.x))])
  do.call(predictLearner, arglist)
}
8089

8190
#' @export

tests/testthat/test_base_TuneWrapper.R

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,3 +120,34 @@ test_that("TuneWrapper with glmnet (#958)", {
120120
expect_error(pred, NA)
121121
})
122122

test_that("TuneWrapper respects train parameters (#2472)", {

  # Regression task whose target is constantly 0, so any prediction error
  # comes purely from bad hyperparameter values.
  tsk = makeRegrTask("dummy", data = data.frame(y = rep(0L, 100), x = rep(1L, 100)), target = "y")

  # One param per 'when' setting, to check that each is forwarded correctly.
  ps = makeParamSet(
    makeNumericLearnerParam("p1", when = "train", lower = 0, upper = 10),
    makeNumericLearnerParam("p2", when = "predict", lower = 0, upper = 10),
    makeNumericLearnerParam("p3", when = "both", lower = 0, upper = 10)
  )

  lrn = makeLearner("regr.__mlrmocklearners__4", predict.type = "response", p1 = 10, p2 = 10, p3 = 10)
  # The mock learner's prediction is always:
  #   train_part = p1 + p3              (computed at train time)
  #   y_hat      = train_part + p2 + p3 (computed at predict time)
  # therefore p1 = p2 = p3 = 0 is the optimal setting.
  # We start with the bad values p1 = p2 = p3 = 10, for which |y_hat - y| would be 40.

  lrn2 = makeTuneWrapper(lrn, resampling = makeResampleDesc("Holdout"),
    par.set = ps,
    control = makeTuneControlGrid(resolution = 2L))
  mod = train(lrn2, tsk)
  # We expect that the optimal parameters are found by the grid search.
  expect_equal(mod$learner.model$opt.result$x, list(p1 = 0, p2 = 0, p3 = 0))
  expect_true(mod$learner.model$opt.result$y == 0)
  pred = predict(mod, tsk)
  # We expect that the optimal parameters are also applied at prediction time,
  # so y_hat = p1 + p2 + p3 + p3 should be 0.
  expect_true(all(pred$data$response == 0))
})
152+
153+

0 commit comments

Comments
 (0)