Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: bbotk async compatibility #146

Merged
merged 27 commits into from
Jun 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions .github/workflows/dev-cmd-check.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@ jobs:
with:
r-version: ${{ matrix.config.r }}

- uses: supercharge/[email protected]
with:
redis-version: 7

- uses: r-lib/actions/setup-r-dependencies@v2
with:
extra-packages: any::rcmdcheck
Expand All @@ -48,3 +52,6 @@ jobs:
shell: Rscript {0}

- uses: r-lib/actions/check-r-package@v2
with:
args: 'c("--no-manual")' # "--as-cran" prevents to start external processes

6 changes: 6 additions & 0 deletions .github/workflows/r-cmd-check.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,15 @@ jobs:
with:
r-version: ${{ matrix.config.r }}

- uses: supercharge/[email protected]
with:
redis-version: 7

- uses: r-lib/actions/setup-r-dependencies@v2
with:
extra-packages: any::rcmdcheck
needs: check

- uses: r-lib/actions/check-r-package@v2
with:
args: 'c("--no-manual")' # "--as-cran" prevents to start external processes
9 changes: 6 additions & 3 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,13 @@ BugReports: https://github.com/mlr-org/mlr3mbo/issues
Depends:
R (>= 3.1.0)
Imports:
bbotk (>= 0.5.4),
bbotk (>= 0.8.0.9000),
checkmate (>= 2.0.0),
data.table,
lgr (>= 0.3.4),
mlr3 (>= 0.14.0),
mlr3misc (>= 0.11.0),
mlr3tuning (>= 0.14.0),
mlr3tuning (>= 0.20.0.9000),
paradox (>= 0.10.0),
spacefillr,
R6 (>= 2.4.1)
Expand All @@ -65,7 +65,10 @@ Suggests:
rmarkdown,
rpart,
stringi,
testthat (>= 3.0.0),
testthat (>= 3.0.0)
Remotes:
mlr-org/bbotk,
mlr-org/mlr3tuning,
ByteCompile: no
Encoding: UTF-8
Config/testthat/edition: 3
Expand Down
2 changes: 1 addition & 1 deletion R/AcqFunction.R
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ AcqFunction = R6Class("AcqFunction",
stopf("Acquisition function '%s' requires the surrogate to have `\"se\"` as `$predict_type`.", format(self))
}
private$.surrogate = rhs
private$.archive = assert_r6(rhs$archive, classes = "Archive")
private$.archive = assert_archive(rhs$archive)
codomain = generate_acq_codomain(rhs, id = self$id, direction = self$direction)
self$surrogate_max_to_min = surrogate_mult_max_to_min(rhs)
domain = generate_acq_domain(rhs)
Expand Down
2 changes: 1 addition & 1 deletion R/AcqFunctionAEI.R
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
#' codomain = codomain,
#' properties = "noisy")
#'
#' instance = OptimInstanceSingleCrit$new(
#' instance = OptimInstanceBatchSingleCrit$new(
#' objective = objective,
#' terminator = trm("evals", n_evals = 5))
#'
Expand Down
2 changes: 1 addition & 1 deletion R/AcqFunctionCB.R
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
#' codomain = ps(y = p_dbl(tags = "minimize"))
#' objective = ObjectiveRFun$new(fun = fun, domain = domain, codomain = codomain)
#'
#' instance = OptimInstanceSingleCrit$new(
#' instance = OptimInstanceBatchSingleCrit$new(
#' objective = objective,
#' terminator = trm("evals", n_evals = 5))
#'
Expand Down
4 changes: 2 additions & 2 deletions R/AcqFunctionEHVI.R
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
#' codomain = ps(y1 = p_dbl(tags = "minimize"), y2 = p_dbl(tags = "minimize"))
#' objective = ObjectiveRFun$new(fun = fun, domain = domain, codomain = codomain)
#'
#' instance = OptimInstanceMultiCrit$new(
#' instance = OptimInstanceBatchMultiCrit$new(
#' objective = objective,
#' terminator = trm("evals", n_evals = 5))
#'
Expand Down Expand Up @@ -117,7 +117,7 @@ AcqFunctionEHVI = R6Class("AcqFunctionEHVI",
}

columns = colnames(self$ys_front_augmented)

ps = self$surrogate$predict(xdt)
means = map_dtc(ps, "mean")

Expand Down
10 changes: 5 additions & 5 deletions R/AcqFunctionEHVIGH.R
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
#' codomain = ps(y1 = p_dbl(tags = "minimize"), y2 = p_dbl(tags = "minimize"))
#' objective = ObjectiveRFun$new(fun = fun, domain = domain, codomain = codomain)
#'
#' instance = OptimInstanceMultiCrit$new(
#' instance = OptimInstanceBatchMultiCrit$new(
#' objective = objective,
#' terminator = trm("evals", n_evals = 5))
#'
Expand Down Expand Up @@ -175,20 +175,20 @@ adjust_gh_data = function(gh_data, mu, sigma, r) {
idx = as.matrix(expand.grid(rep(list(1:n), n_obj)))
nodes = matrix(gh_data[idx, 1L], nrow = nrow(idx), ncol = n_obj)
weights = apply(matrix(gh_data[idx, 2L], nrow = nrow(idx), ncol = n_obj), MARGIN = 1L, FUN = prod)
# pruning with pruning rate r

# pruning with pruning rate r
if (r > 0) {
weights_quantile = quantile(weights, probs = r)
nodes = nodes[weights > weights_quantile, ]
weights = weights[weights > weights_quantile]
}

# rotate, scale, translate nodes with error catching
# rotation will not have an effect unless we support surrogate models modelling correlated objectives
# for now we still support this more general case and scaling is useful anyways
nodes = tryCatch(
{
eigen_decomp = eigen(sigma)
eigen_decomp = eigen(sigma)
rotation = eigen_decomp$vectors %*% diag(sqrt(eigen_decomp$values))
nodes = t(rotation %*% t(nodes) + mu)
}, error = function(ec) nodes
Expand Down
2 changes: 1 addition & 1 deletion R/AcqFunctionEI.R
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
#' codomain = ps(y = p_dbl(tags = "minimize"))
#' objective = ObjectiveRFun$new(fun = fun, domain = domain, codomain = codomain)
#'
#' instance = OptimInstanceSingleCrit$new(
#' instance = OptimInstanceBatchSingleCrit$new(
#' objective = objective,
#' terminator = trm("evals", n_evals = 5))
#'
Expand Down
6 changes: 3 additions & 3 deletions R/AcqFunctionEIPS.R
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@
#' @description
#' Expected Improvement per Second.
#'
#' It is assumed that calculations are performed on an [bbotk::OptimInstanceSingleCrit].
#' It is assumed that calculations are performed on an [bbotk::OptimInstanceBatchSingleCrit].
#' Additionally to target values of the codomain that should be minimized or maximized, the
#' [bbotk::Objective] of the [bbotk::OptimInstanceSingleCrit] should return time values.
#' [bbotk::Objective] of the [bbotk::OptimInstanceBatchSingleCrit] should return time values.
#' The column names of the target variable and time variable must be passed as `cols_y` in the
#' order `(target, time)` when constructing the [SurrogateLearnerCollection] that is being used as a
#' surrogate.
Expand All @@ -37,7 +37,7 @@
#' codomain = ps(y = p_dbl(tags = "minimize"), time = p_dbl(tags = "time"))
#' objective = ObjectiveRFun$new(fun = fun, domain = domain, codomain = codomain)
#'
#' instance = OptimInstanceSingleCrit$new(
#' instance = OptimInstanceBatchSingleCrit$new(
#' objective = objective,
#' terminator = trm("evals", n_evals = 5))
#'
Expand Down
2 changes: 1 addition & 1 deletion R/AcqFunctionMean.R
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
#' codomain = ps(y = p_dbl(tags = "minimize"))
#' objective = ObjectiveRFun$new(fun = fun, domain = domain, codomain = codomain)
#'
#' instance = OptimInstanceSingleCrit$new(
#' instance = OptimInstanceBatchSingleCrit$new(
#' objective = objective,
#' terminator = trm("evals", n_evals = 5))
#'
Expand Down
2 changes: 1 addition & 1 deletion R/AcqFunctionPI.R
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
#' codomain = ps(y = p_dbl(tags = "minimize"))
#' objective = ObjectiveRFun$new(fun = fun, domain = domain, codomain = codomain)
#'
#' instance = OptimInstanceSingleCrit$new(
#' instance = OptimInstanceBatchSingleCrit$new(
#' objective = objective,
#' terminator = trm("evals", n_evals = 5))
#'
Expand Down
2 changes: 1 addition & 1 deletion R/AcqFunctionSD.R
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
#' codomain = ps(y = p_dbl(tags = "minimize"))
#' objective = ObjectiveRFun$new(fun = fun, domain = domain, codomain = codomain)
#'
#' instance = OptimInstanceSingleCrit$new(
#' instance = OptimInstanceBatchSingleCrit$new(
#' objective = objective,
#' terminator = trm("evals", n_evals = 5))
#'
Expand Down
2 changes: 1 addition & 1 deletion R/AcqFunctionSmsEgo.R
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
#' codomain = ps(y1 = p_dbl(tags = "minimize"), y2 = p_dbl(tags = "minimize"))
#' objective = ObjectiveRFun$new(fun = fun, domain = domain, codomain = codomain)
#'
#' instance = OptimInstanceMultiCrit$new(
#' instance = OptimInstanceBatchMultiCrit$new(
#' objective = objective,
#' terminator = trm("evals", n_evals = 5))
#'
Expand Down
4 changes: 2 additions & 2 deletions R/AcqOptimizer.R
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@
#' codomain = ps(y = p_dbl(tags = "minimize"))
#' objective = ObjectiveRFun$new(fun = fun, domain = domain, codomain = codomain)
#'
#' instance = OptimInstanceSingleCrit$new(
#' instance = OptimInstanceBatchSingleCrit$new(
#' objective = objective,
#' terminator = trm("evals", n_evals = 5))
#'
Expand Down Expand Up @@ -146,7 +146,7 @@ AcqOptimizer = R6Class("AcqOptimizer",
logger$set_threshold(self$param_set$values$logging_level)
on.exit(logger$set_threshold(old_threshold))

instance = OptimInstanceSingleCrit$new(objective = self$acq_function, search_space = self$acq_function$domain, terminator = self$terminator, check_values = FALSE, keep_evals = "all")
instance = OptimInstanceBatchSingleCrit$new(objective = self$acq_function, search_space = self$acq_function$domain, terminator = self$terminator, check_values = FALSE)

# warmstart
if (self$param_set$values$warmstart) {
Expand Down
6 changes: 3 additions & 3 deletions R/OptimizerMbo.R
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
#' codomain = ps(y = p_dbl(tags = "minimize"))
#' objective = ObjectiveRFun$new(fun = fun, domain = domain, codomain = codomain)
#'
#' instance = OptimInstanceSingleCrit$new(
#' instance = OptimInstanceBatchSingleCrit$new(
#' objective = objective,
#' terminator = trm("evals", n_evals = 5))
#'
Expand Down Expand Up @@ -75,7 +75,7 @@
#' codomain = ps(y1 = p_dbl(tags = "minimize"), y2 = p_dbl(tags = "minimize"))
#' objective = ObjectiveRFun$new(fun = fun, domain = domain, codomain = codomain)
#'
#' instance = OptimInstanceMultiCrit$new(
#' instance = OptimInstanceBatchMultiCrit$new(
#' objective = objective,
#' terminator = trm("evals", n_evals = 5))
#'
Expand All @@ -89,7 +89,7 @@
#' }
#' }
OptimizerMbo = R6Class("OptimizerMbo",
inherit = bbotk::Optimizer,
inherit = bbotk::OptimizerBatch,

public = list(
#' @description
Expand Down
2 changes: 1 addition & 1 deletion R/ResultAssigner.R
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ ResultAssigner = R6Class("ResultAssigner",
#' @description
#' Assigns the result, i.e., the final point(s) to the instance.
#'
#' @param instance ([bbotk::OptimInstanceSingleCrit] | [bbotk::OptimInstanceMultiCrit])\cr
#' @param instance ([bbotk::OptimInstanceBatchSingleCrit] | [bbotk::OptimInstanceBatchMultiCrit])\cr
#' The [bbotk::OptimInstance] the final result should be assigned to.
assign_result = function(instance) {
stop("Abstract.")
Expand Down
4 changes: 2 additions & 2 deletions R/ResultAssignerArchive.R
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,12 @@ ResultAssignerArchive = R6Class("ResultAssignerArchive",
#' @description
#' Assigns the result, i.e., the final point(s) to the instance.
#'
#' @param instance ([bbotk::OptimInstanceSingleCrit] | [bbotk::OptimInstanceMultiCrit])\cr
#' @param instance ([bbotk::OptimInstanceBatchSingleCrit] | [bbotk::OptimInstanceBatchMultiCrit])\cr
#' The [bbotk::OptimInstance] the final result should be assigned to.
assign_result = function(instance) {
res = instance$archive$best()
xdt = res[, instance$search_space$ids(), with = FALSE]
if (inherits(instance, "OptimInstanceMultiCrit")) {
if (inherits(instance, "OptimInstanceBatchMultiCrit")) {
ydt = res[, instance$archive$cols_y, with = FALSE]
instance$assign_result(xdt, ydt)
}
Expand Down
12 changes: 6 additions & 6 deletions R/ResultAssignerSurrogate.R
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
#' Result assigner that chooses the final point(s) based on a surrogate mean prediction of all evaluated points in the [bbotk::Archive].
#' This is especially useful in the case of noisy objective functions.
#'
#' In the case of operating on an [bbotk::OptimInstanceMultiCrit] the [SurrogateLearnerCollection] must use as many learners as there are objective functions.
#' In the case of operating on an [bbotk::OptimInstanceBatchMultiCrit] the [SurrogateLearnerCollection] must use as many learners as there are objective functions.
#'
#' @family Result Assigner
#' @export
Expand All @@ -32,15 +32,15 @@ ResultAssignerSurrogate = R6Class("ResultAssignerSurrogate",
#' Assigns the result, i.e., the final point(s) to the instance.
#' If `$surrogate` is `NULL`, `default_surrogate(instance)` is used and also assigned to `$surrogate`.
#'
#' @param instance ([bbotk::OptimInstanceSingleCrit] | [bbotk::OptimInstanceMultiCrit])\cr
#' @param instance ([bbotk::OptimInstanceBatchSingleCrit] | [bbotk::OptimInstanceBatchMultiCrit])\cr
#' The [bbotk::OptimInstance] the final result should be assigned to.
assign_result = function(instance) {
if (is.null(self$surrogate)) {
self$surrogate = default_surrogate(instance)
}
if (inherits(instance, "OptimInstanceSingleCrit")) {
if (inherits(instance, "OptimInstanceBatchSingleCrit")) {
assert_r6(self$surrogate, classes = "SurrogateLearner")
} else if (inherits(instance, "OptimInstanceMultiCrit")) {
} else if (inherits(instance, "OptimInstanceBatchMultiCrit")) {
assert_r6(self$surrogate, classes = "SurrogateLearnerCollection")
if (self$surrogate$n_learner != instance$objective$ydim) {
stopf("Surrogate used within the result assigner uses %i learners but the optimization instance has %i objective functions", self$surrogate$n_learner, instance$objective$ydim)
Expand All @@ -62,9 +62,9 @@ ResultAssignerSurrogate = R6Class("ResultAssignerSurrogate",
best = archive_tmp$best()[, archive_tmp$cols_x, with = FALSE]

# ys are still the ones originally evaluated
best_y = if (inherits(instance, "OptimInstanceSingleCrit")) {
best_y = if (inherits(instance, "OptimInstanceBatchSingleCrit")) {
unlist(archive$data[best, on = archive$cols_x][, archive$cols_y, with = FALSE])
} else if (inherits(instance, "OptimInstanceMultiCrit")) {
} else if (inherits(instance, "OptimInstanceBatchMultiCrit")) {
archive$data[best, on = archive$cols_x][, archive$cols_y, with = FALSE]
}
instance$assign_result(xdt = best, best_y)
Expand Down
2 changes: 1 addition & 1 deletion R/Surrogate.R
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ Surrogate = R6Class("Surrogate",
if (missing(rhs)) {
private$.archive
} else {
private$.archive = assert_r6(rhs, classes = "Archive")
private$.archive = assert_archive(rhs, null_ok = TRUE)
invisible(private$.archive)
}
},
Expand Down
2 changes: 1 addition & 1 deletion R/SurrogateLearner.R
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
#' codomain = ps(y = p_dbl(tags = "minimize"))
#' objective = ObjectiveRFun$new(fun = fun, domain = domain, codomain = codomain)
#'
#' instance = OptimInstanceSingleCrit$new(
#' instance = OptimInstanceBatchSingleCrit$new(
#' objective = objective,
#' terminator = trm("evals", n_evals = 5))
#'
Expand Down
2 changes: 1 addition & 1 deletion R/SurrogateLearnerCollection.R
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
#' codomain = ps(y1 = p_dbl(tags = "minimize"), y2 = p_dbl(tags = "minimize"))
#' objective = ObjectiveRFun$new(fun = fun, domain = domain, codomain = codomain)
#'
#' instance = OptimInstanceMultiCrit$new(
#' instance = OptimInstanceBatchMultiCrit$new(
#' objective = objective,
#' terminator = trm("evals", n_evals = 5))
#' xdt = generate_design_random(instance$search_space, n = 4)$data
Expand Down
8 changes: 4 additions & 4 deletions R/TunerMbo.R
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#' @title Tuner using Model Based Optimization
#' @title TunerBatch using Model Based Optimization
#'
#' @name mlr_tuners_mbo
#'
Expand All @@ -23,7 +23,7 @@
#' resampling = rsmp("cv", folds = 3)
#' measure = msr("classif.acc")
#'
#' instance = TuningInstanceSingleCrit$new(
#' instance = TuningInstanceBatchSingleCrit$new(
#' task = task,
#' learner = learner,
#' resampling = resampling,
Expand All @@ -38,7 +38,7 @@
#' resampling = rsmp("cv", folds = 3)
#' measures = msrs(c("classif.acc", "selected_features"))
#'
#' instance = TuningInstanceMultiCrit$new(
#' instance = TuningInstanceBatchMultiCrit$new(
#' task = task,
#' learner = learner,
#' resampling = resampling,
Expand All @@ -50,7 +50,7 @@
#' }
#' }
TunerMbo = R6Class("TunerMbo",
inherit = mlr3tuning::TunerFromOptimizer,
inherit = mlr3tuning::TunerBatchFromOptimizerBatch,

public = list(
#' @description
Expand Down
Loading
Loading