Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ Imports:
mlr3misc (>= 0.1.4),
paradox,
R6,
R.cache,
withr
Suggests:
ggplot2,
Expand Down
88 changes: 86 additions & 2 deletions R/Graph.R
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,15 @@
#' (and therefore their `$param_set$values`) and a hash of `$edges`.
#' * `keep_results` :: `logical(1)` \cr
#' Whether to store intermediate results in the [`PipeOp`]'s `$.result` slot, mostly for debugging purposes. Default `FALSE`.
#' * `cache` :: `logical(1)` \cr
#' Whether to cache individual [`PipeOp`]'s during "train" and "predict". Default `FALSE`.
#' Caching is performed using the [`R.cache`][R.cache::R.cache] package.
#' Caching can be disabled/enabled globally using `getOption("R.cache.enabled", TRUE)`.
#' By default, files are cached in `R.cache::getCacheRootPath()`.
#' For more information on how to set the cache path or retrieve cached items,
#' please consult the [`R.cache`][R.cache::R.cache] documentation.
#' Caching can be fine-controlled for each [`PipeOp`] by adjusting individual [`PipeOp`]'s
#' `cache`, `cache_state` and `stochastic` fields.
#'
#' @section Methods:
#' * `ids(sorted = FALSE)` \cr
Expand Down Expand Up @@ -407,6 +416,13 @@ Graph = R6Class("Graph",
} else {
map(self$pipeops, "state")
}
},
cache = function(val) {
  # Active binding for the Graph-level caching flag: PipeOps inside this Graph
  # are only cached when this flag AND the individual PipeOp's `cache` are TRUE.
  if (!missing(val)) {
    # Setter: `val` must be a single non-NA logical.
    private$.cache = assert_flag(val)
  } else {
    # Getter: current flag; the private field defaults to FALSE.
    private$.cache
  }
}
),

Expand All @@ -419,7 +435,8 @@ Graph = R6Class("Graph",
value
)
},
.param_set = NULL
.param_set = NULL,
.cache = FALSE
)
)

Expand Down Expand Up @@ -539,7 +556,7 @@ graph_reduce = function(self, input, fun, single_input) {
input = input_tbl$payload
names(input) = input_tbl$name

output = op[[fun]](input)
output = cached_pipeop_eval(self, op, fun, input)
if (self$keep_results) {
op$.result = output
}
Expand Down Expand Up @@ -609,3 +626,70 @@ predict.Graph = function(object, newdata, ...) {
}
result
}

# Cached train/predict of a PipeOp.
# 1) Caching of a PipeOp is only performed if graph and po have `cache = TRUE`,
#    i.e. both the Graph AND the PipeOp want to be cached.
# 2) Additionally, caching is only performed if 'train' or 'predict' is not stochastic
# for a given PipeOp. This can be obtained from `.$stochastic` and can be set
# for each PipeOp.
# 3) During training we have two options
# Each PipeOp stores whether it wants to do I. or II. in `.$cache_state`.
# I. Cache only state:
# This is possible if the train transform is the same as the predict transform
# and predict is comparatively cheap (i.e. filters).
# II. Cache state and output
# (All other cases)

cached_pipeop_eval = function(self, op, fun, input) {
  # Evaluate `op[[fun]](input)` (fun is "train" or "predict"), transparently
  # caching via R.cache when both the Graph (`self$cache`) and the PipeOp
  # (`op$cache`) opt in and `fun` is not listed in `op$stochastic`.
  #
  # During "train" there are two modes, selected by `op$cache_state`:
  #   I.  cache only the state (possible when the train transform equals the
  #       predict transform; avoids storing large intermediate outputs on disk)
  #   II. cache state and output together (all other cases)
  #
  # NOTE: R.cache::evalWithMemoization() does NOT evaluate the expression on a
  # cache hit; assignments made inside the expression are then skipped, so the
  # cached value must be taken from its *return value* (the original code read
  # locals assigned inside the expression, which do not exist on a hit).
  if (self$cache && op$cache) {
    require_namespaces("R.cache")
    # Key: hashes of all inputs plus the PipeOp's hash (class, id, param values).
    cache_key = list(map_chr(input, get_hash), op$hash)
    if (fun == "train") {
      if (fun %nin% op$stochastic) {
        if (op$cache_state) {
          # I. cache only the state; the last expression's value is memoized.
          state = R.cache::evalWithMemoization({
            op[[fun]](input)
            op$state
          }, key = cache_key)
          # On a cache hit "train" was not called, so restore the state here.
          if (is.null(op$state)) op$state = state
          # Predict on the train inputs instead of storing train outputs on disk.
          return(cached_pipeop_eval(self, op, "predict", input))
        } else {
          # II. cache state and output together.
          result = R.cache::evalWithMemoization({
            list(output = op[[fun]](input), state = op$state)
          }, key = cache_key)
          # On a cache hit "train" was not called, so restore the state here.
          if (is.null(op$state)) op$state = result$state
          return(result$output)
        }
      }
    } else if (fun == "predict" && !op$cache_state) {
      # During predict, only cache if the output is not cheaply reproducible
      # from the state (cache_state is FALSE) and the op is not stochastic.
      if (fun %nin% op$stochastic) {
        output = R.cache::evalWithMemoization({
          op[[fun]](input)
        }, key = cache_key)
        return(output)
      }
    }
  }
  # No-caching fallback: anything not covered by the conditions above.
  return(op[[fun]](input))
}

# Hash helper for cache keys: prefer an object's own `$hash` field (e.g. Tasks,
# PipeOps); fall back to hashing the object itself when `$hash` is absent or
# accessing it errors (e.g. atomic vectors).
get_hash = function(x) {
  hash = tryCatch(x$hash, error = function(e) NULL)
  if (is.null(hash)) {
    hash = digest(x, algo = "xxhash64")
  }
  hash
}
41 changes: 39 additions & 2 deletions R/PipeOp.R
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,21 @@
#' If the [`Graph`]'s `$keep_results` flag is set to `TRUE`, then the intermediate Results of `$train()` and `$predict()`
#' are saved to this slot, exactly as they are returned by these functions. This is mainly for debugging purposes
#' and done, if requested, by the [`Graph`] backend itself; it should *not* be done explicitly by `private$.train()` or `private$.predict()`.
#' * `cache` :: `logical(1)` \cr
#' Whether to cache the [`PipeOp`]'s state and/or output during "train" and "predict". Defaults to `TRUE`.
#' See the `cache` field in [`Graph`] for more detailed information on caching, as well as `cache_state` and
#' `stochastic` below.
#' * `cache_state` :: `logical(1)` \cr
#' Whether the [`PipeOp`]'s behaviour during training is equal to its behaviour during
#' prediction (other than setting a state). In this case, only the [`PipeOp`]'s state is cached.
#' This avoids caching possibly large intermediate results.
#' Defaults to `TRUE`.
#' * `stochastic` :: `character` \cr
#' Whether a [`PipeOp`] is stochastic during `"train"`, `"predict"`, or not at all: `character(0)`.
#' Defaults to `character(0)` (deterministic). Stochastic [`PipeOp`]s are not cached during the
#' respective phase.
#' A [`PipeOp`] is only cached if it is deterministic.
#'
#'
#' @section Methods:
#' * `train(input)`\cr
Expand Down Expand Up @@ -254,7 +269,6 @@ PipeOp = R6Class("PipeOp",
if (is_noop(self$state)) {
stopf("Pipeop %s got NO_OP during train but no NO_OP during predict.", self$id)
}

input = check_types(self, input, "input", "predict")
output = private$.predict(input)
output = check_types(self, output, "output", "predict")
Expand Down Expand Up @@ -296,6 +310,26 @@ PipeOp = R6Class("PipeOp",
hash = function() {
digest(list(class(self), self$id, self$param_set$values),
algo = "xxhash64")
},
cache = function(val) {
  # Active binding for this PipeOp's caching opt-in (private default TRUE).
  # Effective caching additionally requires the enclosing Graph's `cache` flag.
  if (!missing(val)) {
    # Setter: `val` must be a single non-NA logical.
    private$.cache = assert_flag(val)
  } else {
    private$.cache
  }
},
cache_state = function(val) {
  # Active binding (read-only): TRUE when only the state needs to be cached,
  # avoiding storage of possibly large intermediate outputs (private default TRUE).
  if (!missing(val)) {
    stop("cache_state is read-only!")
  }
  private$.cache_state
},
stochastic = function(val) {
if (!missing(val)) {
private$.stochastic = assert_subset(val, c("train", "predict"))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shouldn't this be read-only and set during initialization?

} else {
private$.stochastic
}
}
),

Expand All @@ -318,7 +352,10 @@ PipeOp = R6Class("PipeOp",
.predict = function(input) stop("abstract"),
.param_set = NULL,
.param_set_source = NULL,
.id = NULL
.id = NULL,
.cache = TRUE,
.cache_state = TRUE,
.stochastic = character(0)
)
)

Expand Down
3 changes: 2 additions & 1 deletion R/PipeOpBranch.R
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,8 @@ PipeOpBranch = R6Class("PipeOpBranch",
ret = named_list(self$output$name, NO_OP)
ret[[self$param_set$values$selection]] = inputs[[1]]
ret
}
},
.cache = FALSE
)
)

Expand Down
14 changes: 13 additions & 1 deletion R/PipeOpChunk.R
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,17 @@ PipeOpChunk = R6Class("PipeOpChunk",
)
}
),
active = list(
stochastic = function(val) {
  # Active binding: this op counts as stochastic during "train" iff its
  # `shuffle` parameter is set; otherwise it is deterministic (character(0)).
  if (!missing(val)) {
    assert_subset(val, c("train", "predict"))
    # NOTE(review): this assignment is never read back -- the getter below
    # derives the value from `shuffle` only, so setting `stochastic` on a
    # PipeOpChunk has no observable effect; confirm whether the setter should
    # instead be an error.
    private$.stochastic = val
  } else {
    if (self$param_set$values$shuffle) return("train")
    character(0)
  }
}
),
private = list(
.train = function(inputs) {
self$state = list()
Expand All @@ -88,7 +99,8 @@ PipeOpChunk = R6Class("PipeOpChunk",
},
.predict = function(inputs) {
rep(inputs, self$outnum)
}
},
.cache = FALSE
)
)

Expand Down
5 changes: 4 additions & 1 deletion R/PipeOpClassBalancing.R
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,10 @@ PipeOpClassBalancing = R6Class("PipeOpClassBalancing",
task_filter_ex(task, new_ids)
},

.predict_task = identity
.predict_task = identity,
.cache = FALSE,
.stochastic = "train",
.cache_state = FALSE
)
)

Expand Down
3 changes: 2 additions & 1 deletion R/PipeOpCopy.R
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,8 @@ PipeOpCopy = R6Class("PipeOpCopy",
},
.predict = function(inputs) {
rep_len(inputs, self$outnum)
}
},
.cache = FALSE
)
)

Expand Down
4 changes: 3 additions & 1 deletion R/PipeOpImputeHist.R
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,9 @@ PipeOpImputeHist = R6Class("PipeOpImputeHist",
}
feature[is.na(feature)] = sampled
feature
}
},
.cache = FALSE,
.stochastic = c("train", "predict")
)
)

Expand Down
4 changes: 3 additions & 1 deletion R/PipeOpImputeSample.R
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,9 @@ PipeOpImputeSample = R6Class("PipeOpImputeSample",
feature[is.na(feature)] = sample(model, outlen, replace = TRUE)
}
feature
}
},
.cache = FALSE,
.stochastic = c("train", "predict")
)
)

Expand Down
3 changes: 2 additions & 1 deletion R/PipeOpNOP.R
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,8 @@ PipeOpNOP = R6Class("PipeOpNOP",

.predict = function(inputs) {
inputs
}
},
.cache = FALSE
)
)

Expand Down
29 changes: 29 additions & 0 deletions R/PipeOpProxy.R
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,35 @@ PipeOpProxy = R6Class("PipeOpProxy",
)
}
),
active = list(
cache = function(val) {
  # Active binding delegating the `cache` flag to the proxied object
  # (the PipeOp or Graph stored in the `content` parameter value).
  if (!missing(val)) {
    # Setter: `val` must be a single non-NA logical; forwarded to the content.
    self$param_set$values$content$cache = assert_flag(val)
  } else {
    self$param_set$values$content$cache
  }
},
stochastic = function(val) {
  # Active binding delegating `stochastic` to the proxied content.
  # A Graph content has no `stochastic` field, so it is reported as
  # deterministic (character(0)) and attempting to set it is an error.
  if (!missing(val)) {
    assert_subset(val, c("train", "predict"))
    if (inherits(self$param_set$values$content, "Graph")) {
      # Fixed ungrammatical message ("not be set" -> "can not be set").
      stop("'stochastic' can not be set when content is a Graph!")
    } else {
      self$param_set$values$content$stochastic = val
    }
  } else {
    if (inherits(self$param_set$values$content, "Graph")) return(character(0))
    self$param_set$values$content$stochastic
  }
},
cache_state = function(val) {
  # Active binding (read-only): a Graph content reports TRUE; otherwise the
  # value is taken from the proxied PipeOp's own `cache_state`.
  if (!missing(val)) {
    stop("cache_state is read-only!")
  } else {
    if (inherits(self$param_set$values$content, "Graph")) return(TRUE)
    self$param_set$values$content$cache_state
  }
}
),
private = list(
.param_set = NULL,
.param_set_source = NULL,
Expand Down
4 changes: 3 additions & 1 deletion R/PipeOpSmote.R
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,9 @@ PipeOpSmote = R6Class("PipeOpSmote",
}
setnames(st, "class", task$target_names)
task$rbind(st)
}
},
.cache = FALSE,
.stochastic = "train"
)
)

Expand Down
4 changes: 2 additions & 2 deletions R/PipeOpSubsample.R
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,8 @@ PipeOpSubsample = R6Class("PipeOpSubsample",
self$state = list()
task_filter_ex(task, keep)
},

.predict_task = identity
.predict_task = identity,
.cache_state = FALSE
)
)

Expand Down
4 changes: 3 additions & 1 deletion R/PipeOpThreshold.R
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,9 @@ PipeOpThreshold = R6Class("PipeOpThreshold",
}

list(prd$set_threshold(thr))
}
},
.cache = FALSE,
.cache_state = FALSE
)
)

Expand Down
3 changes: 2 additions & 1 deletion R/PipeOpUnbranch.R
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,8 @@ PipeOpUnbranch = R6Class("PipeOpUnbranch",
},
.predict = function(inputs) {
filter_noop(inputs)
}
},
.cache = FALSE
)
)

Expand Down
Loading