Skip to content

Commit

Permalink
..
Browse files Browse the repository at this point in the history
  • Loading branch information
sumny committed Jun 18, 2024
1 parent 4018fc5 commit d509304
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 2 deletions.
1 change: 0 additions & 1 deletion NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@ importFrom(R6,R6Class)
importFrom(stats,dnorm)
importFrom(stats,pnorm)
importFrom(stats,quantile)
importFrom(stats,rexp)
importFrom(stats,runif)
importFrom(stats,setNames)
importFrom(utils,bibentry)
Expand Down
1 change: 1 addition & 0 deletions R/SurrogateLearner.R
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,7 @@ SurrogateLearner = R6Class("SurrogateLearner",
# Train learner with new data.
# Also calculates the insample performance based on the `perf_measure` hyperparameter if `assert_insample_perf = TRUE`.
.update = function() {
xydt = self$archive$data[, c(self$cols_x, self$cols_y), with = FALSE]
task = TaskRegr$new(id = "surrogate_task", backend = xydt, target = self$cols_y)
assert_learnable(task, learner = self$learner)
self$learner$train(task)
Expand Down
11 changes: 11 additions & 0 deletions tests/testthat/helper.R
Original file line number Diff line number Diff line change
Expand Up @@ -195,3 +195,14 @@ expect_acqfunction = function(acqf) {
expect_man_exists(acqf$man)
}

sortnames = function(x) {
if (!is.null(names(x))) {
x = x[order(names(x), decreasing = TRUE)]
}
x
}

expect_equal_sorted = function(x, y, ...) {
expect_equal(sortnames(x), sortnames(y), ...)
}

2 changes: 1 addition & 1 deletion tests/testthat/test_SurrogateLearner.R
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ test_that("param_set", {
inst = MAKE_INST_1D()
surrogate = SurrogateLearner$new(learner = REGR_FEATURELESS, archive = inst$archive)
expect_r6(surrogate$param_set, "ParamSet")
expect_setequal(surrogate$param_set$ids(), c("assert_insample_perf", "perf_measure", "perf_threshold", "catch_errors", "impute_missings"))
expect_setequal(surrogate$param_set$ids(), c("assert_insample_perf", "perf_measure", "perf_threshold", "catch_errors"))
expect_equal(surrogate$param_set$class[["assert_insample_perf"]], "ParamLgl")
expect_equal(surrogate$param_set$class[["perf_measure"]], "ParamUty")
expect_equal(surrogate$param_set$class[["perf_threshold"]], "ParamDbl")
Expand Down

0 comments on commit d509304

Please sign in to comment.