Skip to content

Commit

Permalink
feat: add AcqFunctionMulti that can wrap multiple acquisition functio…
Browse files Browse the repository at this point in the history
…ns resulting in a multi-objective acquisition function problem

feat: adjusted AcqOptimizer to be more robust (get_best) functionality but also handle AcqFunctionMulti
  • Loading branch information
sumny committed Aug 19, 2024
1 parent eb1be03 commit 53473c4
Show file tree
Hide file tree
Showing 27 changed files with 593 additions and 113 deletions.
1 change: 1 addition & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ Collate:
'AcqFunctionEI.R'
'AcqFunctionEIPS.R'
'AcqFunctionMean.R'
'AcqFunctionMulti.R'
'AcqFunctionPI.R'
'AcqFunctionSD.R'
'AcqFunctionSmsEgo.R'
Expand Down
2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ export(AcqFunctionEHVIGH)
export(AcqFunctionEI)
export(AcqFunctionEIPS)
export(AcqFunctionMean)
export(AcqFunctionMulti)
export(AcqFunctionPI)
export(AcqFunctionSD)
export(AcqFunctionSmsEgo)
Expand All @@ -25,6 +26,7 @@ export(SurrogateLearner)
export(SurrogateLearnerCollection)
export(TunerMbo)
export(acqf)
export(acqfs)
export(acqo)
export(bayesopt_ego)
export(bayesopt_emo)
Expand Down
172 changes: 118 additions & 54 deletions R/AcqFunctionMulti.R
Original file line number Diff line number Diff line change
@@ -1,22 +1,84 @@
#' @title Acquisition Function Wrapping Multiple Acquisition Functions
#'
#' @include AcqFunction.R
#' @name mlr_acqfunctions_multi
#'
#' @templateVar id multi
#' @template section_dictionary_acqfunctions
#'
#' @description
#' Wrapping multiple [AcqFunction]s resulting in a multi-objective acquisition function composed of the individual ones.
#' Note that the optimization direction of each wrapped acquisition function is corrected for maximization.
#'
#' @family Acquisition Function
#' @export
#' @examples
#' if (requireNamespace("mlr3learners") &&
#' requireNamespace("DiceKriging") &&
#' requireNamespace("rgenoud")) {
#' library(bbotk)
#' library(paradox)
#' library(mlr3learners)
#' library(data.table)
#'
#' fun = function(xs) {
#' list(y = xs$x ^ 2)
#' }
#' domain = ps(x = p_dbl(lower = -10, upper = 10))
#' codomain = ps(y = p_dbl(tags = "minimize"))
#' objective = ObjectiveRFun$new(fun = fun, domain = domain, codomain = codomain)
#'
#' instance = OptimInstanceBatchSingleCrit$new(
#' objective = objective,
#' terminator = trm("evals", n_evals = 5))
#'
#' instance$eval_batch(data.table(x = c(-6, -5, 3, 9)))
#'
#' learner = default_gp()
#'
#' surrogate = srlrn(learner, archive = instance$archive)
#'
#' acq_function = acqf("multi",
#' acq_functions = acqfs(c("ei", "pi", "cb")),
#' surrogate = surrogate
#' )
#'
#' acq_function$surrogate$update()
#' acq_function$update()
#' acq_function$eval_dt(data.table(x = c(-1, 0, 1)))
#' }
AcqFunctionMulti = R6Class("AcqFunctionMulti",
inherit = AcqFunction,

public = list(

initialize = function(acq_functions) {
#' @description
#' Creates a new instance of this [R6][R6::R6Class] class.
#'
#' @param acq_functions (list of [AcqFunction]s).
#' @param surrogate (`NULL` | [Surrogate]).
initialize = function(acq_functions, surrogate = NULL) {
assert_list(acq_functions, "AcqFunction", min.len = 2L)
acq_function_ids = map_chr(acq_functions, function(acq_function) acq_function$id)
assert_character(acq_function_ids, unique = TRUE)
acq_functions = setNames(acq_functions, nm = acq_function_ids)
acq_function_directions = map_chr(acq_functions, function(acq_function) acq_function$direction)
private$.acq_functions = acq_functions
# FIXME: check for unique ids
id = paste0(c("acq", map_chr(acq_functions, function(acq_function) gsub("acq_", replacement = "", x = acq_function$id))), collapse = "_")
private$.acq_function_ids = acq_function_ids
private$.acq_function_directions = acq_function_directions
id = paste0(c("acq", map_chr(acq_function_ids, function(id) gsub("acq_", replacement = "", x = id))), collapse = "_")
label = paste0("Multi Acquisition Function of ", paste0(map_chr(acq_functions, function(acq_function) acq_function$label), collapse = ", "))
# FIXME: constant ids must be prefixed by acqf_id
constants = do.call(c, map(acq_functions, function(acq_function) acq_function$constants))
constants = ps()
domains = map(acq_functions, function(acq_function) acq_function$domain)
assert_true(all(map_lgl(domains[-1L], function(domain) all.equal(domains[[1L]]$data, domain$data))))
# FIXME: surrogates could be the same or different, how to handle this with $update() of the surrogate?
surrogates = map(acq_functions, function(acq_function) acq_function$surrogate)
assert_true(length(unique(map_chr(surrogates, function(surrogate) address(surrogate)))) == 1L)
surrogate = surrogates[[1L]]
if (is.null(surrogate)) {
surrogates = map(acq_functions, function(acq_function) acq_function$surrogate)
assert_list(surrogates, types = c("Surrogate", "NULL"))
if (length(unique(map_chr(surrogates, function(surrogate) address(surrogate)))) > 1L) {
stop("Acquisition functions must rely on the same surrogate model.")
}
surrogate = surrogates[[1L]]
}
requires_predict_type_se = any(map_lgl(acq_functions, function(acq_function) acq_function$requires_predict_type_se))
packages = unique(unlist(map(acq_functions, function(acq_function) acq_function$packages)))
properties = character()
Expand All @@ -25,7 +87,7 @@ AcqFunctionMulti = R6Class("AcqFunctionMulti",

private$.requires_predict_type_se = requires_predict_type_se
private$.packages = packages
self$direction = "minimize"
self$direction = "maximize"
if (is.null(surrogate)) {
domain = ParamSet$new()
codomain = ParamSet$new()
Expand All @@ -34,7 +96,10 @@ AcqFunctionMulti = R6Class("AcqFunctionMulti",
stopf("Acquisition function '%s' requires the surrogate to have `\"se\"` as `$predict_type`.", sprintf("<%s:%s>", "AcqFunction", id))
}
private$.surrogate = surrogate
private$.archive = assert_r6(surrogate$archive, classes = "Archive")
private$.archive = assert_archive(surrogate$archive)
for (acq_function in private$.acq_functions) {
acq_function$surrogate = surrogate
}
codomain = generate_acq_multi_codomain(surrogate, acq_functions = acq_functions)
self$surrogate_max_to_min = surrogate_mult_max_to_min(surrogate)
domain = generate_acq_domain(surrogate)
Expand All @@ -59,7 +124,10 @@ AcqFunctionMulti = R6Class("AcqFunctionMulti",
#' @description
#' Update each of the wrapped acquisition functions.
update = function() {
for (acq_function in private$.acq_functions) {
if (length(unique(map_chr(self$acq_functions, function(acq_function) address(acq_function$surrogate)))) > 1L) {
stop("Acquisition functions must rely on the same surrogate model.")
}
for (acq_function in self$acq_functions) {
acq_function$update()
}
}
Expand All @@ -72,75 +140,71 @@ AcqFunctionMulti = R6Class("AcqFunctionMulti",
if (missing(rhs)) {
private$.surrogate
} else {
# FIXME: assign surrogate to all acqfs?
assert_r6(rhs, classes = "Surrogate")
if (self$requires_predict_type_se && rhs$predict_type != "se") {
stopf("Acquisition function '%s' requires the surrogate to have `\"se\"` as `$predict_type`.", format(self))
}
private$.surrogate = rhs
private$.archive = assert_archive(rhs$archive)
codomain = generate_acq_multi_codomain(surrogate, acq_functions = acq_functions)
for (acq_function in self$acq_functions) {
acq_function$surrogate = rhs
}
codomain = generate_acq_multi_codomain(rhs, acq_functions = self$acq_functions)
self$surrogate_max_to_min = surrogate_mult_max_to_min(rhs)
domain = generate_acq_domain(rhs)
# lazy initialization requires this:
self$codomain = Codomain$new(get0("domains", codomain, ifnotfound = codomain$params)) # get0 for old paradox
self$domain = domain
}
},

#' @field acq_functions (list of [AcqFunction])\cr
#' Points to the list of the individual acquisition functions.
#' This field is read-only: assigning anything other than the stored list raises an error.
acq_functions = function(rhs) {
  stored = private$.acq_functions
  # A plain read, or a write of the identical list, yields the stored list.
  if (missing(rhs) || identical(rhs, stored)) {
    return(stored)
  }
  stop("$acq_functions is read-only.")
},

#' @field acq_function_ids (character())\cr
#' Points to the ids of the individual acquisition functions.
#' This field is read-only: assigning anything other than the stored ids raises an error.
acq_function_ids = function(rhs) {
  stored = private$.acq_function_ids
  # A plain read, or a write of the identical ids, yields the stored ids.
  if (missing(rhs) || identical(rhs, stored)) {
    return(stored)
  }
  stop("$acq_function_ids is read-only.")
}
),

private = list(
.acq_functions = NULL,

.fun = function(xdt, ...) {
constants = list(...)
# FIXME: prefixed constants matching; needed at all?
values = map_dtc(private$.acq_functions, function(acq_function) acq_function$eval_dt(xdt))
ids = map_chr(private$.acq_functions, function(acq_function) acq_function$id)
directions = map_chr(private$.acq_functions, function(acq_function) acq_function$direction)
.acq_function_ids = NULL,

.acq_function_directions = NULL,

# NOTE: this is currently slower than it could be because, whenever an acquisition function is evaluated,
# the mean and se predictions for each point are computed again using the surrogate of that acquisition function.
# However, as all acquisition functions must share the same surrogate, this is redundant.
# It might be sensible to have a customized eval function for acquisition functions to which the mean and se
# predictions are passed directly (along with xdt), so that they need not be recomputed over and over again.
# This would, however, also depend on learners being fully deterministic.
.fun = function(xdt) {
values = map_dtc(self$acq_functions, function(acq_function) acq_function$eval_dt(xdt))
ids = private$.acq_function_ids
directions = private$.acq_function_directions
if (any(directions == "same")) {
directions[directions == "same"] = self$surrogate$archive$codomain$tags[[1L]]
}
change_sign = ids[directions == "maximize"]
change_sign = ids[directions == "minimize"]
for (j in change_sign) {
set(values, j = j, value = - values[[j]])
}
# FIXME: standardize column ranges to [0, 1]
values
}
)
)

# FIXME: test with multi objective also?
# FIXME: currently there is overhead because each acqf predicts with the surrogate
# but if the surrogate is always the same and shared, can we save time by predicting for xdt
# and having for each acqf a variant of fun that directly uses mean and se

if (FALSE) {
fun = function(xs) {
list(y = xs$x ^ 2)
}
domain = ps(x = p_dbl(lower = -10, upper = 10))
codomain = ps(y = p_dbl(tags = "minimize"))
objective = ObjectiveRFun$new(fun = fun, domain = domain, codomain = codomain)

instance = OptimInstanceBatchSingleCrit$new(
objective = objective,
terminator = trm("evals", n_evals = 5))

instance$eval_batch(data.table(x = c(-6, -5, 3, 9)))

learner = default_gp()

surrogate = srlrn(learner, archive = instance$archive)

ei = acqf("ei", surrogate = surrogate)
lcb = acqf("cb", surrogate = surrogate)
pi = acqf("pi", surrogate = surrogate)

acqf = AcqFunctionMulti$new(list(ei, lcb, pi))
acqf$surrogate$update()
acqf$update()
acqf$eval_dt(data.table(x = c(-1, 0, 1)))
}
mlr_acqfunctions$add("multi", AcqFunctionMulti)

33 changes: 23 additions & 10 deletions R/AcqOptimizer.R
Original file line number Diff line number Diff line change
Expand Up @@ -141,23 +141,29 @@ AcqOptimizer = R6Class("AcqOptimizer",
#'
#' @return [data.table::data.table()] with 1 row per optimum and x as columns.
optimize = function() {
# FIXME: currently only supports singlecrit acquisition functions
if (self$acq_function$codomain$length > 1L) {
stop("Multi-objective acquisition functions are currently not supported.")
}
is_multi_acq_function = self$acq_function$codomain$length > 1L

logger = lgr::get_logger("bbotk")
old_threshold = logger$threshold
logger$set_threshold(self$param_set$values$logging_level)
on.exit(logger$set_threshold(old_threshold))

instance = OptimInstanceBatchSingleCrit$new(objective = self$acq_function, search_space = self$acq_function$domain, terminator = self$terminator, check_values = FALSE, callbacks = self$callbacks)
if (is_multi_acq_function) {
instance = OptimInstanceBatchMultiCrit$new(objective = self$acq_function, search_space = self$acq_function$domain, terminator = self$terminator, check_values = FALSE, callbacks = self$callbacks)
} else {
instance = OptimInstanceBatchSingleCrit$new(objective = self$acq_function, search_space = self$acq_function$domain, terminator = self$terminator, check_values = FALSE, callbacks = self$callbacks)
}

# warmstart
if (self$param_set$values$warmstart) {
warmstart_size = if (isTRUE(self$param_set$values$warmstart_size == "all")) Inf else self$param_set$values$warmstart_size %??% 1L # default is 1L
n_select = min(nrow(self$acq_function$archive$data), warmstart_size)
instance$eval_batch(self$acq_function$archive$best(n_select = n_select)[, instance$search_space$ids(), with = FALSE])
warmstart_xdt = if (is_multi_acq_function) {
self$acq_function$archive$nds_selection(n_select = n_select)[, instance$search_space$ids(), with = FALSE]
} else {
self$acq_function$archive$best(n_select = n_select)[, instance$search_space$ids(), with = FALSE]
}
instance$eval_batch(warmstart_xdt)
}

# acquisition function optimization
Expand All @@ -166,7 +172,7 @@ AcqOptimizer = R6Class("AcqOptimizer",
tryCatch(
{
self$optimizer$optimize(instance)
get_best_not_evaluated(instance, evaluated = self$acq_function$archive$data, n_select = self$param_set$values$n_candidates)
get_best(instance, is_multi_acq_function = is_multi_acq_function, evaluated = self$acq_function$archive$data, n_select = self$param_set$values$n_candidates, not_already_evaluated = TRUE)
}, error = function(error_condition) {
lg$warn(error_condition$message)
stop(set_class(list(message = error_condition$message, call = NULL),
Expand All @@ -175,14 +181,14 @@ AcqOptimizer = R6Class("AcqOptimizer",
)
} else {
self$optimizer$optimize(instance)
get_best_not_evaluated(instance, evaluated = self$acq_function$archive$data, n_select = self$param_set$values$n_candidates)
get_best(instance, is_multi_acq_function = is_multi_acq_function, evaluated = self$acq_function$archive$data, n_select = self$param_set$values$n_candidates, not_already_evaluated = TRUE)
}
} else {
if (self$param_set$values$catch_errors) {
tryCatch(
{
self$optimizer$optimize(instance)
instance$archive$best(n_select = self$param_set$values$n_candidates)
get_best(instance, is_multi_acq_function = is_multi_acq_function, evaluated = self$acq_function$archive$data, n_select = self$param_set$values$n_candidates, not_already_evaluated = FALSE)
}, error = function(error_condition) {
lg$warn(error_condition$message)
stop(set_class(list(message = error_condition$message, call = NULL),
Expand All @@ -191,9 +197,16 @@ AcqOptimizer = R6Class("AcqOptimizer",
)
} else {
self$optimizer$optimize(instance)
instance$archive$best(n_select = self$param_set$values$n_candidates)
get_best(instance, is_multi_acq_function = is_multi_acq_function, evaluated = self$acq_function$archive$data, n_select = self$param_set$values$n_candidates, not_already_evaluated = FALSE)
}
}
#if (is_multi_acq_function) {
# set(xdt, j = instance$objective$id, value = apply(xdt[, instance$objective$acq_function_ids, with = FALSE], MARGIN = 1L, FUN = c, simplify = FALSE))
# for (acq_function_id in instance$objective$acq_function_ids) {
# set(xdt, j = acq_function_id, value = NULL)
# }
# setcolorder(xdt, c(instance$archive$cols_x, "x_domain", instance$objective$id))
#}
xdt
}
),
Expand Down
Loading

0 comments on commit 53473c4

Please sign in to comment.