@@ -10,22 +10,21 @@ makeRLearner.regr.mob = function() {
10
10
makeNumericLearnerParam(id = " trim" , default = 0.1 , lower = 0 , upper = 1 ),
11
11
makeLogicalLearnerParam(id = " breakties" , default = FALSE ),
12
12
makeLogicalLearnerParam(id = " verbose" , default = FALSE , tunable = FALSE ),
13
- makeDiscreteLearnerParam(id = " model" , default = modeltools :: glinearModel ,
14
- values = list (glinearModel = modeltools :: glinearModel , linearModel = modeltools :: linearModel )),
13
+ makeDiscreteLearnerParam(id = " model" , default = " glinearModel" , values = list (" glinearModel" , " linearModel" )),
15
14
makeUntypedLearnerParam(id = " part.feats" ),
16
15
makeUntypedLearnerParam(id = " term.feats" )
17
16
),
18
17
par.vals = list (),
19
18
properties = c(" numerics" , " factors" , " weights" ),
20
19
name = " Model-based Recursive Partitioning Yielding a Tree with Fitted Models Associated with each Terminal Node" ,
21
20
short.name = " mob" ,
22
- callees = c(" mob" , " mob_control" , " glinearModel " , " linearModel " )
21
+ callees = c(" mob" , " mob_control" )
23
22
)
24
23
}
25
24
26
25
# ' @export
27
26
trainLearner.regr.mob = function (.learner , .task , .subset , .weights = NULL , alpha , bonferroni , minsplit ,
28
- trim , breakties , verbose , part.feats , term.feats , ... ) {
27
+ trim , breakties , verbose , part.feats , term.feats , model , ... ) {
29
28
30
29
cntrl = learnerArgsToControl(party :: mob_control , alpha , bonferroni , minsplit , trim , breakties , verbose )
31
30
@@ -42,11 +41,20 @@ trainLearner.regr.mob = function(.learner, .task, .subset, .weights = NULL, alph
42
41
target = getTaskTargetNames(.task )
43
42
f = as.formula(stri_paste(target , " ~" , collapse(term.feats , sep = " + " ), " |" , collapse(part.feats , sep = " + " ), sep = " " ))
44
43
45
- if (is.null(.weights )) {
46
- model = party :: mob(f , data = getTaskData(.task , .subset ), control = cntrl , ... )
47
- } else {
48
- model = party :: mob(f , data = getTaskData(.task , .subset ), control = cntrl , weights = .weights , ... )
44
+ args = list (f , data = getTaskData(.task , .subset ), control = cntrl , ... )
45
+ if (! is.null(.weights )) {
46
+ args $ weights = .weights
49
47
}
48
+ if (! missing(model )) {
49
+ if (is.character(model )) {
50
+ args $ model = getFromNamespace(model , " mda" )
51
+ } else {
52
+ args $ model = model # this allows to set the model if on.par.out.of.bounds is set to "warn" or "quiet"
53
+ }
54
+ }
55
+
56
+ model = do.call(party :: mob , args )
57
+
50
58
# sometimes mob fails to fit a model but does not signal an exception.
51
59
if (anyMissing(coef(model ))) {
52
60
stop(" Failed to fit party::mob. Some coefficients are estimated as NA" )
0 commit comments