Skip to content

Commit

Permalink
comments
Browse files Browse the repository at this point in the history
  • Loading branch information
bpkroth committed Jan 22, 2025
1 parent 1dcfca2 commit ee0883c
Showing 1 changed file with 19 additions and 4 deletions.
23 changes: 19 additions & 4 deletions mlos_bench/mlos_bench/schedulers/forking_worker_pool_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,37 +161,52 @@ def start_optimization_loop(self) -> None:
with Pool(processes=len(self._trial_runners), maxtasksperchild=1) as pool:
while not self.is_done_scheduling() or not self.is_done_running():
# Run any existing trials that aren't currently running.
# Do this first in case we're resuming from a previous run
# (e.g., the real system will have remembered which Trials were
# in progress by reloading them from the Storage backend).

# Avoid modifying the dictionary while iterating over it.
trial_schedule = self._trial_schedule.copy()
for trial_id, (runner_id, suggestion) in trial_schedule.items():
# Skip trials that are already running on their assigned TrialRunner.
if self._trial_runners_status[runner_id] is not None:
continue
# Else, start the Trial on the given TrialRunner in the background.
self._trial_runners_status[runner_id] = pool.apply_async(
TrialRunner(runner_id).run_trial,
args=(trial_id, suggestion),
callback=self._run_trial_finished_callback,
error_callback=self._run_trial_failed_callback,
)
# Now all the available TrialRunners that had work to do should be running.

# Wait until at least one TrialRunner becomes idle (or nothing is scheduled).
# Polling here also gives us a chance to collect multiple results from
# the pool before suggesting new ones.
while len(self._trial_schedule) > 0 and self.get_idle_trial_runners_count() == 0:
# TODO: Make the polling interval here configurable.
sleep(0.5)

# Schedule more trials if we can.
self.schedule_new_trials(num_new_trials=self.get_idle_trial_runners_count() or 1)

# Should be all done starting new trials.
print("Closing the pool.", flush=True)
pool.close()
# FIXME: This sometimes hangs. Not sure why yet.

print("Waiting for all trials to finish.", flush=True)
# FIXME: This sometimes hangs. Not sure why yet.
pool.join()

print("Optimization loop is done.", flush=True)
print("results: " + json.dumps(self._results, indent=2))
print("trial_schedule: " + json.dumps(self._trial_schedule, indent=2))
print("trial_runner_status: " + json.dumps(self._trial_runners_status, indent=2))
assert len(self._results) == self._max_iterations
assert not self._trial_schedule
assert all(x is None for x in self._trial_runners_status.values())
assert len(self._results) == self._max_iterations, "Unexpected number of trials run."
assert not self._trial_schedule, "Some scheduled trials were not started."
assert all(
x is None for x in self._trial_runners_status.values()
), "Some TrialRunners are still running."


def main():
Expand Down

0 comments on commit ee0883c

Please sign in to comment.