Skip to content

Commit 4413b8a

Browse files
pat-svrodriguezf
authored andcommitted
Move pkg function calls within ParamSets to helper file (mlr-org#2730)
* Refactor function calls from packages (`<pkg::fun>`) within ParamSets (mlr-org#2730)
1 parent 5149b04 commit 4413b8a

File tree

2 files changed

+28
-12
lines changed

2 files changed

+28
-12
lines changed

R/RLearner_classif_mda.R

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,7 @@ makeRLearner.classif.mda = function() {
1111
makeIntegerLearnerParam(id = "dimension", lower = 1L),
1212
makeNumericLearnerParam(id = "eps", default = .Machine$double.eps, lower = 0),
1313
makeIntegerLearnerParam(id = "iter", default = 5L, lower = 1L),
14-
makeDiscreteLearnerParam(id = "method", default = mda::polyreg,
15-
values = list(polyreg = mda::polyreg, mars = mda::mars, bruto = mda::bruto, gen.ridge = mda::gen.ridge)),
14+
makeDiscreteLearnerParam(id = "method", default = "polyreg", values = list("polyreg", "mars", "bruto", "gen.ridge")),
1615
makeLogicalLearnerParam(id = "keep.fitted", default = TRUE),
1716
makeLogicalLearnerParam(id = "trace", default = FALSE, tunable = FALSE),
1817
makeDiscreteLearnerParam(id = "start.method", default = "kmeans", values = c("kmeans", "lvq")),
@@ -29,9 +28,18 @@ makeRLearner.classif.mda = function() {
2928
}
3029

3130
#' @export
32-
trainLearner.classif.mda = function(.learner, .task, .subset, .weights = NULL, ...) {
31+
trainLearner.classif.mda = function(.learner, .task, .subset, .weights = NULL, method, ...) {
3332
f = getTaskFormula(.task)
34-
mda::mda(f, data = getTaskData(.task, .subset), ...)
33+
args = list(...)
34+
if (!missing(method)) {
35+
if (is.character(methods)) {
36+
args$method = getFromNamespace(method, "mda")
37+
} else {
38+
args$method = method #this allows to set the method if on.par.out.of.bounds is set to "warn" or "quiet"
39+
}
40+
}
41+
args = c(list(f, data = getTaskData(.task, .subset)), args)
42+
do.call(mda::mda, args)
3543
}
3644

3745
#' @export

R/RLearner_regr_mob.R

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,22 +10,21 @@ makeRLearner.regr.mob = function() {
1010
makeNumericLearnerParam(id = "trim", default = 0.1, lower = 0, upper = 1),
1111
makeLogicalLearnerParam(id = "breakties", default = FALSE),
1212
makeLogicalLearnerParam(id = "verbose", default = FALSE, tunable = FALSE),
13-
makeDiscreteLearnerParam(id = "model", default = modeltools::glinearModel,
14-
values = list(glinearModel = modeltools::glinearModel, linearModel = modeltools::linearModel)),
13+
makeDiscreteLearnerParam(id = "model", default = "glinearModel", values = list("glinearModel", "linearModel")),
1514
makeUntypedLearnerParam(id = "part.feats"),
1615
makeUntypedLearnerParam(id = "term.feats")
1716
),
1817
par.vals = list(),
1918
properties = c("numerics", "factors", "weights"),
2019
name = "Model-based Recursive Partitioning Yielding a Tree with Fitted Models Associated with each Terminal Node",
2120
short.name = "mob",
22-
callees = c("mob", "mob_control", "glinearModel", "linearModel")
21+
callees = c("mob", "mob_control")
2322
)
2423
}
2524

2625
#' @export
2726
trainLearner.regr.mob = function(.learner, .task, .subset, .weights = NULL, alpha, bonferroni, minsplit,
28-
trim, breakties, verbose, part.feats, term.feats, ...) {
27+
trim, breakties, verbose, part.feats, term.feats, model, ...) {
2928

3029
cntrl = learnerArgsToControl(party::mob_control, alpha, bonferroni, minsplit, trim, breakties, verbose)
3130

@@ -42,11 +41,20 @@ trainLearner.regr.mob = function(.learner, .task, .subset, .weights = NULL, alph
4241
target = getTaskTargetNames(.task)
4342
f = as.formula(stri_paste(target, "~", collapse(term.feats, sep = " + "), "|", collapse(part.feats, sep = " + "), sep = " "))
4443

45-
if (is.null(.weights)) {
46-
model = party::mob(f, data = getTaskData(.task, .subset), control = cntrl, ...)
47-
} else {
48-
model = party::mob(f, data = getTaskData(.task, .subset), control = cntrl, weights = .weights, ...)
44+
args = list(f, data = getTaskData(.task, .subset), control = cntrl, ...)
45+
if (!is.null(.weights)) {
46+
args$weights = .weights
4947
}
48+
if (!missing(model)) {
49+
if (is.character(model)) {
50+
args$model = getFromNamespace(model, "mda")
51+
} else {
52+
args$model = model #this allows to set the model if on.par.out.of.bounds is set to "warn" or "quiet"
53+
}
54+
}
55+
56+
model = do.call(party::mob, args)
57+
5058
# sometimes mob fails to fit a model but does not signal an exception.
5159
if (anyMissing(coef(model))) {
5260
stop("Failed to fit party::mob. Some coefficients are estimated as NA")

0 commit comments

Comments
 (0)