Skip to content

Commit

Permalink
refactor: pass extra information of the result in the extra parameter (
Browse files Browse the repository at this point in the history
…#168)

* refactor: pass extra information of the result in the extra parameter

* remove remotes
  • Loading branch information
be-marc authored Nov 7, 2024
1 parent c249a51 commit ddbb5ef
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 7 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ BugReports: https://github.com/mlr-org/mlr3mbo/issues
Depends:
R (>= 3.1.0)
Imports:
bbotk (>= 1.1.1),
bbotk (>= 1.2.0),
checkmate (>= 2.0.0),
data.table,
lgr (>= 0.3.4),
Expand Down
15 changes: 10 additions & 5 deletions R/ResultAssignerArchive.R
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,19 @@ ResultAssignerArchive = R6Class("ResultAssignerArchive",
#' The [bbotk::OptimInstance] the final result should be assigned to.
assign_result = function(instance) {
xydt = instance$archive$best()
xdt = xydt[, instance$search_space$ids(), with = FALSE]
cols_x = instance$archive$cols_x
cols_y = instance$archive$cols_y

xdt = xydt[, cols_x, with = FALSE]
extra = xydt[, !c(cols_x, cols_y), with = FALSE]

if (inherits(instance, "OptimInstanceBatchMultiCrit")) {
ydt = xydt[, instance$archive$cols_y, with = FALSE]
instance$assign_result(xdt, ydt, xydt = xydt)
ydt = xydt[, cols_y, with = FALSE]
instance$assign_result(xdt, ydt, extra = extra)
}
else {
y = unlist(xydt[, instance$archive$cols_y, with = FALSE])
instance$assign_result(xdt, y, xydt = xydt)
y = unlist(xydt[, cols_y, with = FALSE])
instance$assign_result(xdt, y, extra = extra)
}
}
),
Expand Down
3 changes: 2 additions & 1 deletion R/ResultAssignerSurrogate.R
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ ResultAssignerSurrogate = R6Class("ResultAssignerSurrogate",
archive_tmp = archive$clone(deep = TRUE)
archive_tmp$data[, self$surrogate$cols_y := means]
xydt = archive_tmp$best()
extra = xydt[, !c(archive_tmp$cols_x, archive_tmp$cols_y), with = FALSE]
best = xydt[, archive_tmp$cols_x, with = FALSE]

# ys are still the ones originally evaluated
Expand All @@ -68,7 +69,7 @@ ResultAssignerSurrogate = R6Class("ResultAssignerSurrogate",
} else if (inherits(instance, "OptimInstanceBatchMultiCrit")) {
archive$data[best, on = archive$cols_x][, archive$cols_y, with = FALSE]
}
instance$assign_result(xdt = best, best_y, xydt = xydt)
instance$assign_result(xdt = best, best_y, extra = extra)
}
),

Expand Down

0 comments on commit ddbb5ef

Please sign in to comment.