From af152855cf99acd05b0f1f12ab7a34dcfd64336b Mon Sep 17 00:00:00 2001 From: mb706 Date: Sun, 14 Jan 2024 17:33:48 +0100 Subject: [PATCH] adjust tests --- R/helper.R | 2 +- tests/testthat/helper.R | 11 +++++++++++ tests/testthat/test_AcqOptimizer.R | 12 ++++++------ tests/testthat/test_SurrogateLearner.R | 10 +++++----- tests/testthat/test_mbo_defaults.R | 14 +++++++------- 5 files changed, 30 insertions(+), 19 deletions(-) diff --git a/R/helper.R b/R/helper.R index 6f6bee82..cc6d71c6 100644 --- a/R/helper.R +++ b/R/helper.R @@ -6,7 +6,7 @@ generate_acq_codomain = function(surrogate, id, direction = "same") { if (surrogate$archive$codomain$length > 1L) { stop("Not supported yet.") # FIXME: But should be? } - tags = surrogate$archive$codomain$params[[1L]]$tags + tags = surrogate$archive$codomain$tags[[1L]] tags = tags[tags %in% c("minimize", "maximize")] # only filter out the relevant one } else { tags = direction diff --git a/tests/testthat/helper.R b/tests/testthat/helper.R index 692301e2..fb330231 100644 --- a/tests/testthat/helper.R +++ b/tests/testthat/helper.R @@ -195,3 +195,14 @@ expect_acqfunction = function(acqf) { expect_man_exists(acqf$man) } + +sortnames = function(x) { + if (!is.null(names(x))) { + x <- x[order(names(x), decreasing = TRUE)] + } + x +} + +expect_equal_sorted = function(x, y, ...) { + expect_equal(sortnames(x), sortnames(y), ...) +} diff --git a/tests/testthat/test_AcqOptimizer.R b/tests/testthat/test_AcqOptimizer.R index f86db27b..681a8251 100644 --- a/tests/testthat/test_AcqOptimizer.R +++ b/tests/testthat/test_AcqOptimizer.R @@ -84,12 +84,12 @@ test_that("AcqOptimizer param_set", { acqopt = AcqOptimizer$new(opt("random_search", batch_size = 1L), trm("evals", n_evals = 1L)) expect_r6(acqopt$param_set, "ParamSet") expect_setequal(acqopt$param_set$ids(), c("n_candidates", "logging_level", "warmstart", "warmstart_size", "skip_already_evaluated", "catch_errors")) - expect_r6(acqopt$param_set$params$n_candidates, "ParamInt") - expect_r6(acqopt$param_set$params$logging_level, "ParamFct") - expect_r6(acqopt$param_set$params$warmstart, "ParamLgl") - expect_r6(acqopt$param_set$params$warmstart_size, "ParamInt") - expect_r6(acqopt$param_set$params$skip_already_evaluated, "ParamLgl") - expect_r6(acqopt$param_set$params$catch_errors, "ParamLgl") + expect_equal(acqopt$param_set$class[["n_candidates"]], "ParamInt") + expect_equal(acqopt$param_set$class[["logging_level"]], "ParamFct") + expect_equal(acqopt$param_set$class[["warmstart"]], "ParamLgl") + expect_equal(acqopt$param_set$class[["warmstart_size"]], "ParamInt") + expect_equal(acqopt$param_set$class[["skip_already_evaluated"]], "ParamLgl") + expect_equal(acqopt$param_set$class[["catch_errors"]], "ParamLgl") expect_error({acqopt$param_set = list()}, regexp = "param_set is read-only.") }) diff --git a/tests/testthat/test_SurrogateLearner.R b/tests/testthat/test_SurrogateLearner.R index 4a45b6a9..08f46d7e 100644 --- a/tests/testthat/test_SurrogateLearner.R +++ b/tests/testthat/test_SurrogateLearner.R @@ -17,7 +17,7 @@ test_that("SurrogateLearner API works", { # upgrading error class works surrogate = SurrogateLearner$new(LearnerRegrError$new(), archive = inst$archive) expect_error(surrogate$update(), class = "surrogate_update_error") - + surrogate$param_set$values$catch_errors = FALSE expect_error(surrogate$optimize(), class = "simpleError") @@ -51,10 +51,10 @@ test_that("param_set", { surrogate = SurrogateLearner$new(learner = REGR_FEATURELESS, archive = inst$archive) expect_r6(surrogate$param_set, "ParamSet") expect_setequal(surrogate$param_set$ids(), c("assert_insample_perf", "perf_measure", "perf_threshold", "catch_errors")) - expect_r6(surrogate$param_set$params$assert_insample_perf, "ParamLgl") - expect_r6(surrogate$param_set$params$perf_measure, "ParamUty") - expect_r6(surrogate$param_set$params$perf_threshold, "ParamDbl") - expect_r6(surrogate$param_set$params$catch_errors, "ParamLgl") + expect_equal(surrogate$param_set$class[["assert_insample_perf"]], "ParamLgl") + expect_equal(surrogate$param_set$class[["perf_measure"]], "ParamUty") + expect_equal(surrogate$param_set$class[["perf_threshold"]], "ParamDbl") + expect_equal(surrogate$param_set$class[["catch_errors"]], "ParamLgl") expect_error({surrogate$param_set = list()}, regexp = "param_set is read-only.") }) diff --git a/tests/testthat/test_mbo_defaults.R b/tests/testthat/test_mbo_defaults.R index 26285115..6ae51460 100644 --- a/tests/testthat/test_mbo_defaults.R +++ b/tests/testthat/test_mbo_defaults.R @@ -21,7 +21,7 @@ test_that("default_surrogate", { surrogate = default_surrogate(MAKE_INST_1D()) expect_r6(surrogate, "SurrogateLearner") expect_r6(surrogate$learner, "LearnerRegrKM") - expect_equal(surrogate$learner$param_set$values, + expect_equal_sorted(surrogate$learner$param_set$values, list(covtype = "matern5_2", optim.method = "gen", control = list(trace = FALSE), nugget.stability = 1e-08)) expect_equal(surrogate$learner$encapsulate, c(train = "evaluate", predict = "evaluate")) expect_r6(surrogate$learner$fallback, "LearnerRegrRanger") @@ -30,7 +30,7 @@ test_that("default_surrogate", { surrogate = default_surrogate(MAKE_INST_1D_NOISY()) expect_r6(surrogate, "SurrogateLearner") expect_r6(surrogate$learner, "LearnerRegrKM") - expect_equal(surrogate$learner$param_set$values, + expect_equal_sorted(surrogate$learner$param_set$values, list(covtype = "matern5_2", optim.method = "gen", control = list(trace = FALSE), nugget.estim = TRUE, jitter = 1e-12)) expect_equal(surrogate$learner$encapsulate, c(train = "evaluate", predict = "evaluate")) expect_r6(surrogate$learner$fallback, "LearnerRegrRanger") @@ -39,7 +39,7 @@ test_that("default_surrogate", { surrogate = default_surrogate(MAKE_INST(OBJ_1D_2, search_space = PS_1D)) expect_r6(surrogate, "SurrogateLearnerCollection") expect_list(surrogate$learner, types = "LearnerRegrKM") - expect_equal(surrogate$learner[[1L]]$param_set$values, + expect_equal_sorted(surrogate$learner[[1L]]$param_set$values, list(covtype = "matern5_2", optim.method = "gen", control = list(trace = FALSE), nugget.stability = 1e-08)) expect_equal(surrogate$learner[[1L]]$encapsulate, c(train = "evaluate", predict = "evaluate")) expect_r6(surrogate$learner[[1L]]$fallback, "LearnerRegrRanger") @@ -51,7 +51,7 @@ test_that("default_surrogate", { surrogate = default_surrogate(MAKE_INST(OBJ_1D_2_NOISY, search_space = PS_1D)) expect_r6(surrogate, "SurrogateLearnerCollection") expect_list(surrogate$learner, types = "LearnerRegrKM") - expect_equal(surrogate$learner[[1L]]$param_set$values, + expect_equal_sorted(surrogate$learner[[1L]]$param_set$values, list(covtype = "matern5_2", optim.method = "gen", control = list(trace = FALSE), nugget.estim = TRUE, jitter = 1e-12)) expect_equal(surrogate$learner[[1L]]$encapsulate, c(train = "evaluate", predict = "evaluate")) expect_r6(surrogate$learner[[1L]]$fallback, "LearnerRegrRanger") @@ -63,7 +63,7 @@ test_that("default_surrogate", { surrogate = default_surrogate(MAKE_INST(OBJ_1D_MIXED, search_space = PS_1D_MIXED)) expect_r6(surrogate, "SurrogateLearner") expect_r6(surrogate$learner, "LearnerRegrRanger") - expect_equal(surrogate$learner$param_set$values, + expect_equal_sorted(surrogate$learner$param_set$values, list(num.threads = 1L, num.trees = 100L, keep.inbag = TRUE, se.method = "jack")) expect_equal(surrogate$learner$encapsulate, c(train = "evaluate", predict = "evaluate")) expect_r6(surrogate$learner$fallback, "LearnerRegrRanger") @@ -72,7 +72,7 @@ test_that("default_surrogate", { surrogate = default_surrogate(MAKE_INST(OBJ_1D_2_MIXED, search_space = PS_1D_MIXED)) expect_r6(surrogate, "SurrogateLearnerCollection") expect_list(surrogate$learner, types = "LearnerRegrRanger") - expect_equal(surrogate$learner[[1L]]$param_set$values, + expect_equal_sorted(surrogate$learner[[1L]]$param_set$values, list(num.threads = 1L, num.trees = 100L, keep.inbag = TRUE, se.method = "jack")) expect_equal(surrogate$learner[[1L]]$encapsulate, c(train = "evaluate", predict = "evaluate")) expect_r6(surrogate$learner[[1L]]$fallback, "LearnerRegrRanger") @@ -85,7 +85,7 @@ test_that("default_surrogate", { expect_r6(surrogate, "SurrogateLearner") expect_r6(surrogate$learner, "GraphLearner") expect_equal(surrogate$learner$graph$ids(), c("imputesample", "imputeoor", "colapply", "regr.ranger")) - expect_equal(surrogate$learner$param_set$values, + expect_equal_sorted(surrogate$learner$param_set$values, list(imputesample.affect_columns = mlr3pipelines::selector_type("logical"), imputeoor.min = TRUE, imputeoor.offset = 1,