From 0dcd8c0f6b14ab6f36e4259075268bdcca051383 Mon Sep 17 00:00:00 2001 From: Brian Kroth Date: Tue, 21 Jan 2025 14:22:25 -0600 Subject: [PATCH] more enum fixups --- mlos_bench/mlos_bench/optimizers/mlos_core_optimizer.py | 4 ++-- .../tests/tunables/tunable_to_configspace_test.py | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/mlos_bench/mlos_bench/optimizers/mlos_core_optimizer.py b/mlos_bench/mlos_bench/optimizers/mlos_core_optimizer.py index fc2b9a7abb4..0a3d7a82d00 100644 --- a/mlos_bench/mlos_bench/optimizers/mlos_core_optimizer.py +++ b/mlos_bench/mlos_bench/optimizers/mlos_core_optimizer.py @@ -187,8 +187,8 @@ def _to_df(self, configs: Sequence[dict[str, TunableValue]]) -> pd.DataFrame: (special_name, type_name) = special_param_names(tunable.name) tunables_names += [special_name, type_name] is_special = df_configs[tunable.name].apply(tunable.special.__contains__) - df_configs[type_name] = TunableValueKind.RANGE - df_configs.loc[is_special, type_name] = TunableValueKind.SPECIAL + df_configs[type_name] = TunableValueKind.RANGE.value + df_configs.loc[is_special, type_name] = TunableValueKind.SPECIAL.value if tunable.type == "int": # Make int column NULLABLE: df_configs[tunable.name] = df_configs[tunable.name].astype("Int64") diff --git a/mlos_bench/mlos_bench/tests/tunables/tunable_to_configspace_test.py b/mlos_bench/mlos_bench/tests/tunables/tunable_to_configspace_test.py index 55bc1301226..0c95f5a8d61 100644 --- a/mlos_bench/mlos_bench/tests/tunables/tunable_to_configspace_test.py +++ b/mlos_bench/mlos_bench/tests/tunables/tunable_to_configspace_test.py @@ -87,9 +87,9 @@ def configuration_space() -> ConfigurationSpace: ), CategoricalHyperparameter( name=kernel_sched_migration_cost_ns_type, - choices=[TunableValueKind.SPECIAL, TunableValueKind.RANGE], + choices=[TunableValueKind.SPECIAL.value, TunableValueKind.RANGE.value], weights=[0.5, 0.5], - default_value=TunableValueKind.SPECIAL, + default_value=TunableValueKind.SPECIAL.value, ), ] ) @@ -98,12 +98,12 @@ def configuration_space() -> ConfigurationSpace: EqualsCondition( spaces[kernel_sched_migration_cost_ns_special], spaces[kernel_sched_migration_cost_ns_type], - TunableValueKind.SPECIAL, + TunableValueKind.SPECIAL.value, ), EqualsCondition( spaces["kernel_sched_migration_cost_ns"], spaces[kernel_sched_migration_cost_ns_type], - TunableValueKind.RANGE, + TunableValueKind.RANGE.value, ), ] )