Skip to content

Commit

Permalink
Allow delete_gs() to delete multiple with kwarg (#3292)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #3292

Sometimes we get in a weird state with multiple GSs on an experiment and that was blocking deletion and loading

Reviewed By: andycylmeta

Differential Revision: D68905693

fbshipit-source-id: 167b9637c12bdf2e498521d32725555929379559
  • Loading branch information
Daniel Cohen authored and facebook-github-bot committed Jan 31, 2025
1 parent 9bd4c16 commit 9028fcf
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 8 deletions.
37 changes: 30 additions & 7 deletions ax/storage/sqa_store/delete.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,34 +34,57 @@ def delete_experiment(exp_name: str) -> None:
)


def delete_generation_strategy(exp_name: str, config: SQAConfig | None = None) -> None:
def delete_generation_strategy(
exp_name: str, config: SQAConfig | None = None, max_gs_to_delete: int = 1
) -> None:
"""Delete the generation strategy associated with an experiment
Args:
exp_name: Name of the experiment for which the generation strategy
should be deleted.
config: The SQAConfig.
max_gs_to_delete: There is never supposed to be more than one generation
strategy associated with an experiment. However, we've seen cases where
there are, and we don't know why. This parameter allows us to delete
multiple generation strategies, but we raise an error if there are more
than `max_gs_to_delete` generation strategies associated with the
experiment.
This is a safeguard in case we have a bug in this code that deletes
all generation strategies.
"""
config = config or SQAConfig()
decoder = Decoder(config=config)
exp_sqa_class = decoder.config.class_to_sqa_class[Experiment]
gs_sqa_class = decoder.config.class_to_sqa_class[GenerationStrategy]
# get the generation strategy's db_id
with session_scope() as session:
sqa_gs_id = (
sqa_gs_ids = (
session.query(gs_sqa_class.id) # pyre-ignore[16]
.join(exp_sqa_class.generation_strategy) # pyre-ignore[16]
# pyre-fixme[16]: `SQABase` has no attribute `name`.
.filter(exp_sqa_class.name == exp_name)
.one_or_none()
.all()
)

if sqa_gs_id is None:
if sqa_gs_ids is None:
logger.info(f"No generation strategy found for experiment {exp_name}.")
return None

gs_id = sqa_gs_id[0]
if len(sqa_gs_ids) > max_gs_to_delete:
raise ValueError(
f"Found {len(sqa_gs_ids)} generation strategies for experiment {exp_name}. "
"If you are sure you want to delete all of them, please set "
f"`max_gs_to_delete` (currently {max_gs_to_delete}) to a higher value."
)

# delete generation strategy
with session_scope() as session:
gs = session.query(gs_sqa_class).filter_by(id=gs_id).one_or_none()
session.delete(gs)
gs_list = (
session.query(gs_sqa_class)
.filter(gs_sqa_class.id.in_([id[0] for id in sqa_gs_ids]))
.all()
)
for gs in gs_list:
session.delete(gs)
session.flush()
54 changes: 53 additions & 1 deletion ax/storage/sqa_store/tests/test_sqa_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
session_scope,
)
from ax.storage.sqa_store.decoder import Decoder
from ax.storage.sqa_store.delete import delete_experiment
from ax.storage.sqa_store.delete import delete_experiment, delete_generation_strategy
from ax.storage.sqa_store.encoder import Encoder
from ax.storage.sqa_store.load import (
_get_experiment_immutable_opt_config_and_search_space,
Expand Down Expand Up @@ -2248,3 +2248,55 @@ def test_AnalysisCard(self) -> None:
loaded_analysis_cards[2].blob,
plotly_analysis_card.blob,
)

def test_delete_generation_strategy(self) -> None:
# GIVEN an experiment with a generation strategy
experiment = get_branin_experiment()
generation_strategy = choose_generation_strategy(experiment.search_space)
generation_strategy.experiment = experiment
save_experiment(experiment)
save_generation_strategy(generation_strategy=generation_strategy)

# AND GIVEN another experiment with a generation strategy
experiment2 = get_branin_experiment()
experiment2.name = "experiment2"
generation_strategy2 = choose_generation_strategy(experiment2.search_space)
generation_strategy2.experiment = experiment2
save_experiment(experiment2)
save_generation_strategy(generation_strategy=generation_strategy2)

# WHEN I delete the generation strategy
delete_generation_strategy(exp_name=experiment.name, max_gs_to_delete=2)

# THEN the generation strategy is deleted
with self.assertRaises(ObjectNotFoundError):
load_generation_strategy_by_experiment_name(experiment.name)

# AND the other generation strategy is still there
loaded_generation_strategy = load_generation_strategy_by_experiment_name(
experiment2.name
)
# Full GS fails the equality check
self.assertEqual(str(generation_strategy2), str(loaded_generation_strategy))

def test_delete_generation_strategy_max_gs_to_delete(self) -> None:
# GIVEN an experiment with a generation strategy
experiment = get_branin_experiment()
generation_strategy = choose_generation_strategy(experiment.search_space)
generation_strategy.experiment = experiment
save_experiment(experiment)
save_generation_strategy(generation_strategy=generation_strategy)

# WHEN I delete the generation strategy with max_gs_to_delete=0
with self.assertRaisesRegex(
ValueError,
"Found 1 generation strategies",
):
delete_generation_strategy(exp_name=experiment.name, max_gs_to_delete=0)

# THEN the generation strategy is not deleted
loaded_generation_strategy = load_generation_strategy_by_experiment_name(
experiment.name
)
# Full GS fails the equality check
self.assertEqual(str(generation_strategy), str(loaded_generation_strategy))

0 comments on commit 9028fcf

Please sign in to comment.