diff --git a/R/SurrogateLearner.R b/R/SurrogateLearner.R index 18abc334..0891833e 100644 --- a/R/SurrogateLearner.R +++ b/R/SurrogateLearner.R @@ -113,7 +113,17 @@ SurrogateLearner = R6Class("SurrogateLearner", assert_xdt(xdt) xdt = fix_xdt_missing(xdt, cols_x = self$cols_x, archive = self$archive) - pred = self$learner$predict_newdata(newdata = xdt) + # speeding up some checks by constructing the predict task directly instead of relying on predict_newdata + task = self$learner$state$train_task$clone() + set(xdt, j = task$target_names, value = NA_real_) # tasks only have features and the target but we have to set the target to NA + newdata = as_data_backend(xdt) + task$backend = newdata + task$row_roles$use = task$backend$rownames + pred = self$learner$predict(task) + + # slow + #pred = self$learner$predict_newdata(newdata = xdt) + if (self$learner$predict_type == "se") { data.table(mean = pred$response, se = pred$se) } else { diff --git a/R/SurrogateLearnerCollection.R b/R/SurrogateLearnerCollection.R index 74aacff4..34684c62 100644 --- a/R/SurrogateLearnerCollection.R +++ b/R/SurrogateLearnerCollection.R @@ -130,14 +130,31 @@ SurrogateLearnerCollection = R6Class("SurrogateLearnerCollection", assert_xdt(xdt) xdt = fix_xdt_missing(xdt, cols_x = self$cols_x, archive = self$archive) + # speeding up some checks by constructing the predict task directly instead of relying on predict_newdata preds = lapply(self$learner, function(learner) { - pred = learner$predict_newdata(newdata = xdt) + task = learner$state$train_task$clone() + set(xdt, j = task$target_names, value = NA_real_) # tasks only have features and the target but we have to set the target to NA + newdata = as_data_backend(xdt) + task$backend = newdata + task$row_roles$use = task$backend$rownames + pred = learner$predict(task) if (learner$predict_type == "se") { data.table(mean = pred$response, se = pred$se) } else { data.table(mean = pred$response) } }) + + # slow + #preds = lapply(self$learner, function(learner) { + # pred = learner$predict_newdata(newdata = xdt) + # if (learner$predict_type == "se") { + # data.table(mean = pred$response, se = pred$se) + # } else { + # data.table(mean = pred$response) + # } + #}) + names(preds) = names(self$learner) preds }