diff --git a/metagpt/ext/sela/runner/mcts.py b/metagpt/ext/sela/runner/mcts.py index 8b6c14100..63e8f82f0 100644 --- a/metagpt/ext/sela/runner/mcts.py +++ b/metagpt/ext/sela/runner/mcts.py @@ -12,7 +12,7 @@ class MCTSRunner(Runner): result_path: str = "results/mcts" - def __init__(self, args, tree_mode=None, **kwargs): + def __init__(self, args, data_config=None, tree_mode=None, **kwargs): if args.special_instruction == "image": self.start_task_id = 1 # start from datapreprocessing if it is image task else: @@ -23,7 +23,7 @@ def __init__(self, args, tree_mode=None, **kwargs): elif args.eval_func == "mlebench": self.eval_func = node_evaluate_score_mlebench - super().__init__(args, **kwargs) + super().__init__(args, data_config=None, **kwargs) self.tree_mode = tree_mode async def run_experiment(self): diff --git a/metagpt/ext/sela/runner/runner.py b/metagpt/ext/sela/runner/runner.py index 4b5504e09..ac28f0cd7 100644 --- a/metagpt/ext/sela/runner/runner.py +++ b/metagpt/ext/sela/runner/runner.py @@ -16,14 +16,14 @@ class Runner: data_config = DATA_CONFIG start_task_id = 1 - def __init__(self, args, **kwargs): + def __init__(self, args, data_config=None,**kwargs): self.args = args self.start_time_raw = datetime.datetime.now() self.start_time = self.start_time_raw.strftime("%Y%m%d%H%M") self.state = create_initial_state( self.args.task, start_task_id=self.start_task_id, - data_config=self.data_config, + data_config=data_config if data_config is not None else self.data_config, args=self.args, )