Skip to content

Commit

Permalink
perf: speed up surrogate predictions
Browse files Browse the repository at this point in the history
  • Loading branch information
sumny committed Nov 19, 2024
1 parent 012b60c commit ce2e16f
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 2 deletions.
12 changes: 11 additions & 1 deletion R/SurrogateLearner.R
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,17 @@ SurrogateLearner = R6Class("SurrogateLearner",
assert_xdt(xdt)
xdt = fix_xdt_missing(xdt, cols_x = self$cols_x, archive = self$archive)

pred = self$learner$predict_newdata(newdata = xdt)
# speeding up some checks by constructing the predict task directly instead of relying on predict_newdata
task = self$learner$state$train_task$clone()
set(xdt, j = task$target_names, value = NA_real_) # tasks only have features and the target but we have to set the target to NA
newdata = as_data_backend(xdt)
task$backend = newdata
task$row_roles$use = task$backend$rownames
pred = self$learner$predict(task)

# slow
#pred = self$learner$predict_newdata(newdata = xdt)

if (self$learner$predict_type == "se") {
data.table(mean = pred$response, se = pred$se)
} else {
Expand Down
19 changes: 18 additions & 1 deletion R/SurrogateLearnerCollection.R
Original file line number Diff line number Diff line change
Expand Up @@ -130,14 +130,31 @@ SurrogateLearnerCollection = R6Class("SurrogateLearnerCollection",
assert_xdt(xdt)
xdt = fix_xdt_missing(xdt, cols_x = self$cols_x, archive = self$archive)

# speeding up some checks by constructing the predict task directly instead of relying on predict_newdata
preds = lapply(self$learner, function(learner) {
pred = learner$predict_newdata(newdata = xdt)
task = learner$state$train_task$clone()
set(xdt, j = task$target_names, value = NA_real_) # tasks only have features and the target but we have to set the target to NA
newdata = as_data_backend(xdt)
task$backend = newdata
task$row_roles$use = task$backend$rownames
pred = learner$predict(task)
if (learner$predict_type == "se") {
data.table(mean = pred$response, se = pred$se)
} else {
data.table(mean = pred$response)
}
})

# slow
#preds = lapply(self$learner, function(learner) {
# pred = learner$predict_newdata(newdata = xdt)
# if (learner$predict_type == "se") {
# data.table(mean = pred$response, se = pred$se)
# } else {
# data.table(mean = pred$response)
# }
#})

names(preds) = names(self$learner)
preds
}
Expand Down

0 comments on commit ce2e16f

Please sign in to comment.