From 2d9907701972edde7415b1f762d446f6884c5785 Mon Sep 17 00:00:00 2001 From: Lennart Schneider Date: Tue, 13 Aug 2024 17:47:21 +0200 Subject: [PATCH] feat: allow EI to be adjusted by epsilon to strengthen exploration (#154) --- DESCRIPTION | 2 +- R/AcqFunctionCB.R | 4 ++-- R/AcqFunctionEI.R | 23 +++++++++++++++++++---- man/mlr_acqfunctions_ei.Rd | 15 ++++++++++++++- tests/testthat/test_AcqFunctionCB.R | 3 +++ tests/testthat/test_AcqFunctionEHVIGH.R | 3 +++ tests/testthat/test_AcqFunctionEI.R | 3 +++ tests/testthat/test_AcqFunctionSmsEgo.R | 3 +++ 8 files changed, 48 insertions(+), 8 deletions(-) diff --git a/DESCRIPTION b/DESCRIPTION index 2a58168e..5cf62f96 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -72,7 +72,7 @@ Config/testthat/edition: 3 Config/testthat/parallel: false NeedsCompilation: yes Roxygen: list(markdown = TRUE, r6 = TRUE) -RoxygenNote: 7.3.1 +RoxygenNote: 7.3.2 Collate: 'mlr_acqfunctions.R' 'AcqFunction.R' diff --git a/R/AcqFunctionCB.R b/R/AcqFunctionCB.R index 14a980c3..81328e05 100644 --- a/R/AcqFunctionCB.R +++ b/R/AcqFunctionCB.R @@ -76,8 +76,8 @@ AcqFunctionCB = R6Class("AcqFunctionCB", constants = list(...) lambda = constants$lambda p = self$surrogate$predict(xdt) - res = p$mean - self$surrogate_max_to_min * lambda * p$se - data.table(acq_cb = res) + cb = p$mean - self$surrogate_max_to_min * lambda * p$se + data.table(acq_cb = cb) } ) ) diff --git a/R/AcqFunctionEI.R b/R/AcqFunctionEI.R index 1782b972..f03ec39b 100644 --- a/R/AcqFunctionEI.R +++ b/R/AcqFunctionEI.R @@ -9,6 +9,13 @@ #' @description #' Expected Improvement. #' +#' @section Parameters: +#' * `"epsilon"` (`numeric(1)`)\cr +#' \eqn{\epsilon} value used to determine the amount of exploration. +#' Higher values result in the importance of improvements predicted by the posterior mean +#' decreasing relative to the importance of potential improvements in regions of high predictive uncertainty. +#' Defaults to `0` (standard Expected Improvement). 
+#' #' @references #' * `r format_bib("jones_1998")` #' @@ -60,9 +67,15 @@ AcqFunctionEI = R6Class("AcqFunctionEI", #' Creates a new instance of this [R6][R6::R6Class] class. #' #' @param surrogate (`NULL` | [SurrogateLearner]). - initialize = function(surrogate = NULL) { + #' @param epsilon (`numeric(1)`). + initialize = function(surrogate = NULL, epsilon = 0) { assert_r6(surrogate, "SurrogateLearner", null.ok = TRUE) - super$initialize("acq_ei", surrogate = surrogate, requires_predict_type_se = TRUE, direction = "maximize", label = "Expected Improvement", man = "mlr3mbo::mlr_acqfunctions_ei") + assert_number(epsilon, lower = 0, finite = TRUE) + + constants = ps(epsilon = p_dbl(lower = 0, default = 0)) + constants$values$epsilon = epsilon + + super$initialize("acq_ei", constants = constants, surrogate = surrogate, requires_predict_type_se = TRUE, direction = "maximize", label = "Expected Improvement", man = "mlr3mbo::mlr_acqfunctions_ei") }, #' @description @@ -73,14 +86,16 @@ AcqFunctionEI = R6Class("AcqFunctionEI", ), private = list( - .fun = function(xdt) { + .fun = function(xdt, ...) { if (is.null(self$y_best)) { stop("$y_best is not set. Missed to call $update()?") } + constants = list(...) + epsilon = constants$epsilon p = self$surrogate$predict(xdt) mu = p$mean se = p$se - d = self$y_best - self$surrogate_max_to_min * mu + d = (self$y_best - self$surrogate_max_to_min * mu) - epsilon d_norm = d / se ei = d * pnorm(d_norm) + se * dnorm(d_norm) ei = ifelse(se < 1e-20, 0, ei) diff --git a/man/mlr_acqfunctions_ei.Rd b/man/mlr_acqfunctions_ei.Rd index 700e72ab..4357d844 100644 --- a/man/mlr_acqfunctions_ei.Rd +++ b/man/mlr_acqfunctions_ei.Rd @@ -17,6 +17,17 @@ acqf("ei") }\if{html}{\out{}} } +\section{Parameters}{ + +\itemize{ +\item \code{"epsilon"} (\code{numeric(1)})\cr +\eqn{\epsilon} value used to determine the amount of exploration. 
+Higher values result in the importance of improvements predicted by the posterior mean +decreasing relative to the importance of potential improvements in regions of high predictive uncertainty. +Defaults to \code{0} (standard Expected Improvement). +} +} + \examples{ if (requireNamespace("mlr3learners") & requireNamespace("DiceKriging") & @@ -110,13 +121,15 @@ In the case of maximization, this already includes the necessary change of sign. \subsection{Method \code{new()}}{ Creates a new instance of this \link[R6:R6Class]{R6} class. \subsection{Usage}{ -\if{html}{\out{
<div class="r">}}\preformatted{AcqFunctionEI$new(surrogate = NULL)}\if{html}{\out{</div>
}} +\if{html}{\out{
<div class="r">}}\preformatted{AcqFunctionEI$new(surrogate = NULL, epsilon = 0)}\if{html}{\out{</div>
}} } \subsection{Arguments}{ \if{html}{\out{
<div class="arguments">}} \describe{ \item{\code{surrogate}}{(\code{NULL} | \link{SurrogateLearner}).} + +\item{\code{epsilon}}{(\code{numeric(1)}).} } \if{html}{\out{</div>
}} } diff --git a/tests/testthat/test_AcqFunctionCB.R b/tests/testthat/test_AcqFunctionCB.R index 1749b3d2..eecaa100 100644 --- a/tests/testthat/test_AcqFunctionCB.R +++ b/tests/testthat/test_AcqFunctionCB.R @@ -12,6 +12,9 @@ test_that("AcqFunctionCB works", { expect_learner(acqf$surrogate$learner) expect_true(acqf$requires_predict_type_se) + expect_r6(acqf$constants, "ParamSet") + expect_equal(acqf$constants$ids(), "lambda") + design = MAKE_DESIGN(inst) inst$eval_batch(design) diff --git a/tests/testthat/test_AcqFunctionEHVIGH.R b/tests/testthat/test_AcqFunctionEHVIGH.R index 801fd841..58273cca 100644 --- a/tests/testthat/test_AcqFunctionEHVIGH.R +++ b/tests/testthat/test_AcqFunctionEHVIGH.R @@ -15,6 +15,9 @@ test_that("AcqFunctionEHVIGH works", { expect_true(acqf$requires_predict_type_se) expect_setequal(acqf$packages, c("emoa", "fastGHQuad")) + expect_r6(acqf$constants, "ParamSet") + expect_equal(acqf$constants$ids(), c("k", "r")) + design = MAKE_DESIGN(inst) inst$eval_batch(design) diff --git a/tests/testthat/test_AcqFunctionEI.R b/tests/testthat/test_AcqFunctionEI.R index 59eae9ef..dd8be11a 100644 --- a/tests/testthat/test_AcqFunctionEI.R +++ b/tests/testthat/test_AcqFunctionEI.R @@ -13,6 +13,9 @@ test_that("AcqFunctionEI works", { expect_learner(acqf$surrogate$learner) expect_true(acqf$requires_predict_type_se) + expect_r6(acqf$constants, "ParamSet") + expect_equal(acqf$constants$ids(), "epsilon") + design = MAKE_DESIGN(inst) inst$eval_batch(design) diff --git a/tests/testthat/test_AcqFunctionSmsEgo.R b/tests/testthat/test_AcqFunctionSmsEgo.R index 020a946f..b6cdece9 100644 --- a/tests/testthat/test_AcqFunctionSmsEgo.R +++ b/tests/testthat/test_AcqFunctionSmsEgo.R @@ -12,6 +12,9 @@ test_that("AcqFunctionSmsEgo works", { expect_list(acqf$surrogate$learner, types = "Learner") expect_true(acqf$requires_predict_type_se) + expect_r6(acqf$constants, "ParamSet") + expect_equal(acqf$constants$ids(), c("lambda", "epsilon")) + design = MAKE_DESIGN(inst) 
inst$eval_batch(design)