Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

check with new paradox #136

Merged
merged 19 commits into from
Feb 29, 2024
Merged
3 changes: 2 additions & 1 deletion .github/workflows/dev-cmd-check.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ jobs:
- {os: ubuntu-latest, r: 'release', dev-package: 'mlr-org/bbotk'}
- {os: ubuntu-latest, r: 'release', dev-package: 'mlr-org/mlr3'}
- {os: ubuntu-latest, r: 'release', dev-package: 'mlr-org/mlr3tuning'}
- {os: ubuntu-latest, r: 'release', dev-package: "mlr-org/mlr3tuning', 'mlr-org/mlr3learners', 'mlr-org/mlr3pipelines', 'mlr-org/bbotk', 'mlr-org/paradox"}

steps:
- uses: actions/checkout@v3
Expand All @@ -43,7 +44,7 @@ jobs:
needs: check

- name: Install dev versions
run: pak::pkg_install('${{ matrix.config.dev-package }}')
run: pak::pkg_install(c('${{ matrix.config.dev-package }}'))
shell: Rscript {0}

- uses: r-lib/actions/check-r-package@v2
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# mlr3mbo 0.2.1.9000

* refactor: compatibility with upcoming paradox upgrade.
* feat: `OptimizerMbo` and `TunerMbo` now update the `Surrogate` a final time after the optimization process finished to
ensure that the `Surrogate` correctly reflects the state of being trained on all data seen during optimization.
* fix: `AcqFunction` domain construction now respects `Surrogate` cols_x field.
Expand Down
3 changes: 2 additions & 1 deletion R/AcqFunction.R
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,8 @@ AcqFunction = R6Class("AcqFunction",
codomain = generate_acq_codomain(rhs, id = self$id, direction = self$direction)
self$surrogate_max_to_min = surrogate_mult_max_to_min(rhs)
domain = generate_acq_domain(rhs)
self$codomain = Codomain$new(codomain$params) # lazy initialization requires this
# lazy initialization requires this:
self$codomain = Codomain$new(get0("domains", codomain, ifnotfound = codomain$params)) # get0 for old paradox
self$domain = domain
}
},
Expand Down
8 changes: 4 additions & 4 deletions R/AcqFunctionEHVIGH.R
Original file line number Diff line number Diff line change
Expand Up @@ -93,10 +93,10 @@ AcqFunctionEHVIGH = R6Class("AcqFunctionEHVIGH",
assert_r6(surrogate, "SurrogateLearnerCollection", null.ok = TRUE)
assert_int(k, lower = 2L)

constants = ParamSet$new(list(
ParamInt$new("k", lower = 2L, default = 15L),
ParamDbl$new("r", lower = 0, upper = 1, default = 0.2)
))
constants = ps(
k = p_int(lower = 2L, default = 15L),
r = p_dbl(lower = 0, upper = 1, default = 0.2)
)
constants$values$k = k
constants$values$r = r

Expand Down
8 changes: 4 additions & 4 deletions R/AcqFunctionSmsEgo.R
Original file line number Diff line number Diff line change
Expand Up @@ -92,10 +92,10 @@ AcqFunctionSmsEgo = R6Class("AcqFunctionSmsEgo",
assert_number(lambda, lower = 1, finite = TRUE)
assert_number(epsilon, lower = 0, finite = TRUE, null.ok = TRUE)

constants = ParamSet$new(list(
ParamDbl$new("lambda", lower = 0, default = 1),
ParamDbl$new("epsilon", lower = 0, default = NULL, special_vals = list(NULL)) # for NULL, it will be calculated dynamically
))
constants = ps(
lambda = p_dbl(lower = 0, default = 1),
epsilon = p_dbl(lower = 0, default = NULL, special_vals = list(NULL)) # for NULL, it will be calculated dynamically
)
constants$values$lambda = lambda
constants$values$epsilon = epsilon

Expand Down
14 changes: 7 additions & 7 deletions R/AcqOptimizer.R
Original file line number Diff line number Diff line change
Expand Up @@ -103,13 +103,13 @@ AcqOptimizer = R6Class("AcqOptimizer",
self$optimizer = assert_r6(optimizer, "Optimizer")
self$terminator = assert_r6(terminator, "Terminator")
self$acq_function = assert_r6(acq_function, "AcqFunction", null.ok = TRUE)
ps = ParamSet$new(list(
ParamInt$new("n_candidates", lower = 1, default = 1L),
ParamFct$new("logging_level", levels = c("fatal", "error", "warn", "info", "debug", "trace"), default = "warn"),
ParamLgl$new("warmstart", default = FALSE),
ParamInt$new("warmstart_size", lower = 1L, special_vals = list("all")),
ParamLgl$new("skip_already_evaluated", default = TRUE),
ParamLgl$new("catch_errors", default = TRUE))
ps = ps(
n_candidates = p_int(lower = 1, default = 1L),
logging_level = p_fct(levels = c("fatal", "error", "warn", "info", "debug", "trace"), default = "warn"),
warmstart = p_lgl(default = FALSE),
warmstart_size = p_int(lower = 1L, special_vals = list("all")),
skip_already_evaluated = p_lgl(default = TRUE),
catch_errors = p_lgl(default = TRUE)
)
ps$values = list(n_candidates = 1, logging_level = "warn", warmstart = FALSE, skip_already_evaluated = TRUE, catch_errors = TRUE)
ps$add_dep("warmstart_size", on = "warmstart", cond = CondEqual$new(TRUE))
Expand Down
2 changes: 1 addition & 1 deletion R/Surrogate.R
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ Surrogate = R6Class("Surrogate",
private$.cols_x = assert_character(cols_x, min.len = 1L, null.ok = TRUE)
private$.cols_y = cols_y = assert_character(cols_y, min.len = 1L, null.ok = TRUE)
assert_r6(param_set, classes = "ParamSet")
assert_r6(param_set$params$catch_errors, classes = "ParamLgl")
stopifnot(param_set$class[["catch_errors"]] == "ParamLgl")
private$.param_set = param_set
},

Expand Down
10 changes: 5 additions & 5 deletions R/SurrogateLearner.R
Original file line number Diff line number Diff line change
Expand Up @@ -83,11 +83,11 @@ SurrogateLearner = R6Class("SurrogateLearner",
assert_character(cols_x, min.len = 1L, null.ok = TRUE)
assert_string(col_y, null.ok = TRUE)

ps = ParamSet$new(list(
ParamLgl$new("assert_insample_perf"),
ParamUty$new("perf_measure", custom_check = function(x) check_r6(x, classes = "MeasureRegr")), # FIXME: actually want check_measure
ParamDbl$new("perf_threshold", lower = -Inf, upper = Inf),
ParamLgl$new("catch_errors"))
ps = ps(
assert_insample_perf = p_lgl(),
perf_measure = p_uty(custom_check = function(x) check_r6(x, classes = "MeasureRegr")), # FIXME: actually want check_measure
perf_threshold = p_dbl(lower = -Inf, upper = Inf),
catch_errors = p_lgl()
)
ps$values = list(assert_insample_perf = FALSE, catch_errors = TRUE)
ps$add_dep("perf_measure", on = "assert_insample_perf", cond = CondEqual$new(TRUE))
Expand Down
10 changes: 5 additions & 5 deletions R/SurrogateLearnerCollection.R
Original file line number Diff line number Diff line change
Expand Up @@ -96,11 +96,11 @@ SurrogateLearnerCollection = R6Class("SurrogateLearnerCollection",
assert_character(cols_x, min.len = 1L, null.ok = TRUE)
assert_character(cols_y, len = length(learners), null.ok = TRUE)

ps = ParamSet$new(list(
ParamLgl$new("assert_insample_perf"),
ParamUty$new("perf_measures", custom_check = function(x) check_list(x, types = "MeasureRegr", any.missing = FALSE, len = length(learners))), # FIXME: actually want check_measures
ParamUty$new("perf_thresholds", custom_check = function(x) check_double(x, lower = -Inf, upper = Inf, any.missing = FALSE, len = length(learners))),
ParamLgl$new("catch_errors"))
ps = ps(
assert_insample_perf = p_lgl(),
perf_measures = p_uty(custom_check = function(x) check_list(x, types = "MeasureRegr", any.missing = FALSE, len = length(learners))), # FIXME: actually want check_measures
perf_thresholds = p_uty(custom_check = function(x) check_double(x, lower = -Inf, upper = Inf, any.missing = FALSE, len = length(learners))),
catch_errors = p_lgl()
)
ps$values = list(assert_insample_perf = FALSE, catch_errors = TRUE)
ps$add_dep("perf_measures", on = "assert_insample_perf", cond = CondEqual$new(TRUE))
Expand Down
23 changes: 15 additions & 8 deletions R/helper.R
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,28 @@ generate_acq_codomain = function(surrogate, id, direction = "same") {
if (surrogate$archive$codomain$length > 1L) {
stop("Not supported yet.") # FIXME: But should be?
}
tags = surrogate$archive$codomain$params[[1L]]$tags
tags = surrogate$archive$codomain$tags[[1L]]
tags = tags[tags %in% c("minimize", "maximize")] # only filter out the relevant one
} else {
tags = direction
}
codomain = ParamSet$new(list(
ParamDbl$new(id, tags = tags)
))
codomain
do.call(ps, structure(list(p_dbl(tags = tags)), names = id))
}

# Build the optimization domain for an acquisition function: the subset of
# the surrogate archive's search space restricted to the surrogate's feature
# columns (cols_x), with any parameter transformation (trafo) stripped so the
# acquisition function is optimized on the untransformed scale.
generate_acq_domain = function(surrogate) {
assert_r6(surrogate$archive, classes = "Archive")
if ("set_id" %in% names(ps())) {
# old paradox: detected at runtime via the legacy "set_id" field on ParamSet;
# use the old clone/subset API and clear the set-level trafo
domain = surrogate$archive$search_space$clone(deep = TRUE)$subset(surrogate$cols_x)
domain$trafo = NULL
} else {
# get "domain" objects, set their .trafo-entry to NULL individually
# NOTE(review): assumes each entry of $domains carries a one-row ".trafo"
# list column holding the per-parameter trafo — confirm against paradox internals
dms = lapply(surrogate$archive$search_space$domains[surrogate$cols_x], function(x) {
x$.trafo[1] = list(NULL)
x
})
domain = do.call(ps, dms)
}
domain
}

Expand Down Expand Up @@ -130,7 +137,7 @@ check_learner_surrogate = function(learner) {
return(TRUE)
}
}

"Must inherit from class 'Learner' or be a list of elements inheriting from class 'Learner'"
}

Expand Down
3 changes: 2 additions & 1 deletion man-roxygen/field_param_classes.R
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#' @field param_classes (`character()`)\cr
#' Supported parameter classes that the optimizer can optimize.
#' Determined based on the `surrogate` and the `acq_optimizer`.
#' Subclasses of [paradox::Param].
#' This corresponds to the values given by a [paradox::ParamSet]'s
#' `$class` field.
3 changes: 2 additions & 1 deletion man/mlr_optimizers_mbo.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion man/mlr_tuners_mbo.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

57 changes: 32 additions & 25 deletions tests/testthat/helper.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,19 @@ lapply(list.files(system.file("testthat", package = "mlr3"),
pattern = "^helper.*\\.[rR]", full.names = TRUE), source)

# Simple 1D Functions
PS_1D = ParamSet$new(list(
ParamDbl$new("x", lower = -1, upper = 1)
))
# Search space for the simple 1D test functions: a single numeric
# parameter x in [-1, 1].
PS_1D = ps(
x = p_dbl(lower = -1, upper = 1)
)
# Simple 1D objective: returns the square of the single input value as the
# codomain entry "y". `xs` is a named list with one numeric element.
FUN_1D = function(xs) {
  value = as.numeric(xs)
  list(y = value^2)
}
FUN_1D_CODOMAIN = ParamSet$new(list(ParamDbl$new("y", tags = "minimize")))
# Codomain of FUN_1D: a single objective "y" that is minimized.
FUN_1D_CODOMAIN = ps(y = p_dbl(tags = "minimize"))
# Single-crit objective wrapping FUN_1D over the PS_1D search space.
OBJ_1D = ObjectiveRFun$new(fun = FUN_1D, domain = PS_1D, codomain = FUN_1D_CODOMAIN, properties = "single-crit")

# Simple bi-objective 1D function: y1 is the squared input, y2 is the
# negated square root of its absolute value. `xs` is a named list with one
# numeric element.
FUN_1D_2 = function(xs) {
  value = as.numeric(xs)
  list(y1 = value^2, y2 = -sqrt(abs(value)))
}
FUN_1D_2_CODOMAIN = ParamSet$new(list(ParamDbl$new("y1", tags = "minimize"), ParamDbl$new("y2", tags = "minimize")))
# Codomain of FUN_1D_2: two objectives "y1" and "y2", both minimized.
FUN_1D_2_CODOMAIN = ps(y1 = p_dbl(tags = "minimize"), y2 = p_dbl(tags = "minimize"))
# Multi-crit objective wrapping FUN_1D_2 over the PS_1D search space.
OBJ_1D_2 = ObjectiveRFun$new(fun = FUN_1D_2, domain = PS_1D, codomain = FUN_1D_2_CODOMAIN, properties = "multi-crit")

# Simple 1D Functions with noise
Expand All @@ -29,12 +29,12 @@ FUN_1D_2_NOISY = function(xs) {
OBJ_1D_2_NOISY = ObjectiveRFun$new(fun = FUN_1D_2, domain = PS_1D, codomain = FUN_1D_2_CODOMAIN, properties = c("multi-crit", "noisy"))

# Mixed 1D Functions
PS_1D_MIXED = ParamSet$new(list(
ParamDbl$new("x1", -5, 5),
ParamFct$new("x2", levels = c("a", "b", "c")),
ParamInt$new("x3", 1L, 2L),
ParamLgl$new("x4")
))
# Mixed search space with numeric, factor, integer and logical parameters,
# used to exercise surrogates/acquisition functions on non-numeric inputs.
PS_1D_MIXED = ps(
x1 = p_dbl(-5, 5),
x2 = p_fct(c("a", "b", "c")),
x3 = p_int(1L, 2L),
x4 = p_lgl()
)
# Same space with a dependency added: x2 is only active when x4 == TRUE.
PS_1D_MIXED_DEPS = PS_1D_MIXED$clone(deep = TRUE)
PS_1D_MIXED_DEPS$add_dep("x2", on = "x4", cond = CondEqual$new(TRUE))

Expand All @@ -50,23 +50,19 @@ FUN_1D_2_MIXED = function(xs) {
OBJ_1D_2_MIXED = ObjectiveRFun$new(fun = FUN_1D_2_MIXED, domain = PS_1D_MIXED, codomain = FUN_1D_2_CODOMAIN, properties = "multi-crit")

# Simple 2D Functions
PS_2D = ParamSet$new(list(
ParamDbl$new("x1", lower = -1, upper = 1),
ParamDbl$new("x2", lower = -1, upper = 1)
))
PS_2D_trafo = ParamSet$new(list(
ParamDbl$new("x1", lower = -1, upper = 1),
ParamDbl$new("x2", lower = -1, upper = 1)
))
PS_2D_trafo$trafo = function(x, param_set) {
x$x2 = x$x2 ^ 2
x
}
# Search space for the 2D test functions: x1, x2 in [-1, 1].
PS_2D = ps(
x1 = p_dbl(lower = -1, upper = 1),
x2 = p_dbl(lower = -1, upper = 1)
)
# Same bounds but with a transformation on x2 (squared); the trafo is
# attached per-parameter via p_dbl(trafo = ...) rather than set on the
# whole ParamSet.
PS_2D_trafo = ps(
x1 = p_dbl(lower = -1, upper = 1),
x2 = p_dbl(lower = -1, upper = 1, trafo = function(x) x^2)
)
# Simple 2D objective: sum of squares of all input values, returned as the
# codomain entry "y". `xs` is a named list of numeric scalars.
FUN_2D = function(xs) {
  squares = as.numeric(xs)^2
  list(y = sum(squares))
}
FUN_2D_CODOMAIN = ParamSet$new(list(ParamDbl$new("y", tags = c("minimize", "random_tag"))))
# Codomain of FUN_2D; the extra "random_tag" exercises the filtering of
# non-direction tags (only "minimize"/"maximize" are kept) when the
# acquisition codomain is generated.
FUN_2D_CODOMAIN = ps(y = p_dbl(tags = c("minimize", "random_tag")))
# Single-crit objective; codomain is left to the default here.
OBJ_2D = ObjectiveRFun$new(fun = FUN_2D, domain = PS_2D, properties = "single-crit")

# Simple 2D Function with noise
Expand Down Expand Up @@ -113,7 +109,7 @@ OptimizerError = R6Class("OptimizerError",

initialize = function() {
super$initialize(
param_set = ParamSet$new(),
param_set = ps(),
param_classes = c("ParamLgl", "ParamInt", "ParamDbl", "ParamFct"),
properties = c("dependencies", "single-crit", "multi-crit")
)
Expand Down Expand Up @@ -199,3 +195,14 @@ expect_acqfunction = function(acqf) {
expect_man_exists(acqf$man)
}


# Reorder a named vector or list by its names (decreasing), so two
# collections holding the same name -> value pairs compare equal regardless
# of element order. Unnamed input is returned unchanged.
# Uses `=` assignment for consistency with the rest of this file.
sortnames = function(x) {
  if (!is.null(names(x))) {
    x = x[order(names(x), decreasing = TRUE)]
  }
  x
}

# Variant of expect_equal() that ignores the order of named elements:
# both objects are reordered by name via sortnames() before comparison.
# Extra arguments are forwarded to expect_equal().
expect_equal_sorted = function(x, y, ...) {
  lhs = sortnames(x)
  rhs = sortnames(y)
  expect_equal(lhs, rhs, ...)
}
12 changes: 6 additions & 6 deletions tests/testthat/test_AcqOptimizer.R
Original file line number Diff line number Diff line change
Expand Up @@ -84,12 +84,12 @@ test_that("AcqOptimizer param_set", {
acqopt = AcqOptimizer$new(opt("random_search", batch_size = 1L), trm("evals", n_evals = 1L))
expect_r6(acqopt$param_set, "ParamSet")
expect_setequal(acqopt$param_set$ids(), c("n_candidates", "logging_level", "warmstart", "warmstart_size", "skip_already_evaluated", "catch_errors"))
expect_r6(acqopt$param_set$params$n_candidates, "ParamInt")
expect_r6(acqopt$param_set$params$logging_level, "ParamFct")
expect_r6(acqopt$param_set$params$warmstart, "ParamLgl")
expect_r6(acqopt$param_set$params$warmstart_size, "ParamInt")
expect_r6(acqopt$param_set$params$skip_already_evaluated, "ParamLgl")
expect_r6(acqopt$param_set$params$catch_errors, "ParamLgl")
expect_equal(acqopt$param_set$class[["n_candidates"]], "ParamInt")
expect_equal(acqopt$param_set$class[["logging_level"]], "ParamFct")
expect_equal(acqopt$param_set$class[["warmstart"]], "ParamLgl")
expect_equal(acqopt$param_set$class[["warmstart_size"]], "ParamInt")
expect_equal(acqopt$param_set$class[["skip_already_evaluated"]], "ParamLgl")
expect_equal(acqopt$param_set$class[["catch_errors"]], "ParamLgl")
expect_error({acqopt$param_set = list()}, regexp = "param_set is read-only.")
})

Expand Down
10 changes: 5 additions & 5 deletions tests/testthat/test_SurrogateLearner.R
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ test_that("SurrogateLearner API works", {
# upgrading error class works
surrogate = SurrogateLearner$new(LearnerRegrError$new(), archive = inst$archive)
expect_error(surrogate$update(), class = "surrogate_update_error")

surrogate$param_set$values$catch_errors = FALSE
expect_error(surrogate$optimize(), class = "simpleError")

Expand Down Expand Up @@ -51,10 +51,10 @@ test_that("param_set", {
surrogate = SurrogateLearner$new(learner = REGR_FEATURELESS, archive = inst$archive)
expect_r6(surrogate$param_set, "ParamSet")
expect_setequal(surrogate$param_set$ids(), c("assert_insample_perf", "perf_measure", "perf_threshold", "catch_errors"))
expect_r6(surrogate$param_set$params$assert_insample_perf, "ParamLgl")
expect_r6(surrogate$param_set$params$perf_measure, "ParamUty")
expect_r6(surrogate$param_set$params$perf_threshold, "ParamDbl")
expect_r6(surrogate$param_set$params$catch_errors, "ParamLgl")
expect_equal(surrogate$param_set$class[["assert_insample_perf"]], "ParamLgl")
expect_equal(surrogate$param_set$class[["perf_measure"]], "ParamUty")
expect_equal(surrogate$param_set$class[["perf_threshold"]], "ParamDbl")
expect_equal(surrogate$param_set$class[["catch_errors"]], "ParamLgl")
expect_error({surrogate$param_set = list()}, regexp = "param_set is read-only.")
})

Expand Down
10 changes: 5 additions & 5 deletions tests/testthat/test_SurrogateLearnerCollection.R
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ test_that("SurrogateLearnerCollection API works", {
# upgrading error class works
surrogate = SurrogateLearnerCollection$new(learners = list(LearnerRegrError$new(), LearnerRegrError$new()), archive = inst$archive)
expect_error(surrogate$update(), class = "surrogate_update_error")

surrogate$param_set$values$catch_errors = FALSE
expect_error(surrogate$optimize(), class = "simpleError")

Expand Down Expand Up @@ -61,10 +61,10 @@ test_that("param_set", {
surrogate = SurrogateLearnerCollection$new(learner = list(REGR_FEATURELESS, REGR_FEATURELESS$clone(deep = TRUE)), archive = inst$archive)
expect_r6(surrogate$param_set, "ParamSet")
expect_setequal(surrogate$param_set$ids(), c("assert_insample_perf", "perf_measures", "perf_thresholds", "catch_errors"))
expect_r6(surrogate$param_set$params$assert_insample_perf, "ParamLgl")
expect_r6(surrogate$param_set$params$perf_measure, "ParamUty")
expect_r6(surrogate$param_set$params$perf_threshold, "ParamUty")
expect_r6(surrogate$param_set$params$catch_errors, "ParamLgl")
expect_equal(surrogate$param_set$class[["assert_insample_perf"]], "ParamLgl")
expect_equal(surrogate$param_set$class[["perf_measures"]], "ParamUty")
expect_equal(surrogate$param_set$class[["perf_thresholds"]], "ParamUty")
expect_equal(surrogate$param_set$class[["catch_errors"]], "ParamLgl")
expect_error({surrogate$param_set = list()}, regexp = "param_set is read-only.")
})

Expand Down
Loading
Loading