diff --git a/src/lighteval/tasks/lighteval_task.py b/src/lighteval/tasks/lighteval_task.py
index 09341c947..6f460d2bb 100644
--- a/src/lighteval/tasks/lighteval_task.py
+++ b/src/lighteval/tasks/lighteval_task.py
@@ -41,7 +41,7 @@


 @dataclass
-class CustomEvaluationTask:
+class LightevalTaskConfig:
     name: str
     prompt_function: str
     hf_repo: str
@@ -95,7 +95,7 @@ def __post_init__(self):


 class LightevalTask:
-    def __init__(self, name: str, cfg: dict, cache_dir: Optional[str] = None, custom_tasks_module=None):
+    def __init__(self, name: str, cfg: LightevalTaskConfig, cache_dir: Optional[str] = None, custom_tasks_module=None):
         """
         Initialize a LightEval task.

@@ -115,8 +115,8 @@ def __init__(self, name: str, cfg: dict, cache_dir: Optional[str] = None, custom
         self._cfg = cfg

         # Dataset info
-        self.hf_repo = cfg["hf_repo"]
-        self.hf_subset = cfg["hf_subset"]
+        self.hf_repo = cfg.hf_repo
+        self.hf_subset = cfg.hf_subset
         self.dataset_path = self.hf_repo
         self.dataset_config_name = self.hf_subset
         self.dataset = None  # Delayed download
@@ -125,22 +125,22 @@ def __init__(self, name: str, cfg: dict, cache_dir: Optional[str] = None, custom
         self._docs = None

         # Managing splits and few shot
-        self.all_available_splits = as_list(cfg["hf_avail_splits"])
-        if cfg.get("evaluation_splits", None) is None:
+        self.all_available_splits = as_list(cfg.hf_avail_splits)
+        if cfg.evaluation_splits is None:
             raise ValueError(f"The evaluation split for task {self.name} is None. Please select a valid split.")

-        self.evaluation_split = as_list(cfg["evaluation_splits"])
-        if cfg.get("few_shots_split", None) is not None:
-            self.fewshot_split = as_list(cfg["few_shots_split"])
+        self.evaluation_split = as_list(cfg.evaluation_splits)
+        if cfg.few_shots_split is not None:
+            self.fewshot_split = as_list(cfg.few_shots_split)
         else:
             self.fewshot_split = as_list(self.get_first_possible_fewshot_splits())
         self.fewshot_sampler = FewShotSampler(
-            few_shots_select=cfg["few_shots_select"], few_shots_split=self.fewshot_split
+            few_shots_select=cfg.few_shots_select, few_shots_split=self.fewshot_split
         )

         # Metrics
-        self.metrics = as_list(cfg["metric"])
-        self.suite = as_list(cfg["suite"])
+        self.metrics = as_list(cfg.metric)
+        self.suite = as_list(cfg.suite)
         ignored = [metric for metric in self.metrics if Metrics[metric].value.category == MetricCategory.IGNORED]
         if len(ignored) > 0:
             hlog_warn(f"[WARNING] Not implemented yet: ignoring the metric {' ,'.join(ignored)} for task {self.name}.")
@@ -150,20 +150,20 @@ def __init__(self, name: str, cfg: dict, cache_dir: Optional[str] = None, custom
         # Data processing
         # to use once prompt formatting is managed as a module
         if custom_tasks_module is None:
-            self.formatter = getattr(tasks_prompt_formatting, cfg["prompt_function"])
-        elif hasattr(custom_tasks_module, cfg["prompt_function"]):
+            self.formatter = getattr(tasks_prompt_formatting, cfg.prompt_function)
+        elif hasattr(custom_tasks_module, cfg.prompt_function):
             # If we have a prompt in both the custom_tasks_module and our tasks_prompt_formatting
             # We take the prompt from the custom_tasks_module
-            if hasattr(tasks_prompt_formatting, cfg["prompt_function"]):
+            if hasattr(tasks_prompt_formatting, cfg.prompt_function):
                 hlog_warn(
-                    f"Be careful you are using custom prompt function {cfg['prompt_function']} and not the default one."
+                    f"Be careful you are using custom prompt function {cfg.prompt_function} and not the default one."
                 )
-            self.formatter = getattr(custom_tasks_module, cfg["prompt_function"])
+            self.formatter = getattr(custom_tasks_module, cfg.prompt_function)
         else:
-            self.formatter = getattr(tasks_prompt_formatting, cfg["prompt_function"])
-        self.generation_size = cfg["generation_size"]
-        self.stop_sequence = cfg["stop_sequence"]
-        self.output_regex = cfg["output_regex"]
+            self.formatter = getattr(tasks_prompt_formatting, cfg.prompt_function)
+        self.generation_size = cfg.generation_size
+        self.stop_sequence = cfg.stop_sequence
+        self.output_regex = cfg.output_regex

         # Save options
         self.save_queries: bool = False
diff --git a/src/lighteval/tasks/registry.py b/src/lighteval/tasks/registry.py
index 6738aad5a..1e7db339e 100644
--- a/src/lighteval/tasks/registry.py
+++ b/src/lighteval/tasks/registry.py
@@ -10,7 +10,7 @@
 from datasets.load import dataset_module_factory

 from lighteval.logging.hierarchical_logger import hlog, hlog_warn
-from lighteval.tasks.lighteval_task import LightevalTask
+from lighteval.tasks.lighteval_task import LightevalTask, LightevalTaskConfig


 # original is the reimplementation of original evals
@@ -202,7 +202,7 @@ def create_config_tasks(
         Dict[str, LightevalTask]: A dictionary of task names mapped to their corresponding LightevalTask classes.
     """

-    def create_task(name, cfg, cache_dir):
+    def create_task(name, cfg: LightevalTaskConfig, cache_dir: str):
         class LightevalTaskFromConfig(LightevalTask):
             def __init__(self, custom_tasks_module=None):
                 super().__init__(name, cfg, cache_dir=cache_dir, custom_tasks_module=custom_tasks_module)
@@ -222,6 +222,6 @@ def __init__(self, custom_tasks_module=None):
                continue
        for suite in line["suite"]:
            if suite in DEFAULT_SUITES:
-                tasks_with_config[f"{suite}|{line['name']}"] = line
+                tasks_with_config[f"{suite}|{line['name']}"] = LightevalTaskConfig(**line)

     return {task: create_task(task, cfg, cache_dir=cache_dir) for task, cfg in tasks_with_config.items()}
diff --git a/tasks_examples/custom_tasks/custom_evaluation_tasks.py b/tasks_examples/custom_tasks/custom_evaluation_tasks.py
index f579b7dd5..0ed928e59 100644
--- a/tasks_examples/custom_tasks/custom_evaluation_tasks.py
+++ b/tasks_examples/custom_tasks/custom_evaluation_tasks.py
@@ -9,38 +9,38 @@
 from typing import Dict, List, Tuple

 from lighteval.metrics import Metrics
-from lighteval.tasks.lighteval_task import CustomEvaluationTask
+from lighteval.tasks.lighteval_task import LightevalTaskConfig
 from lighteval.tasks.requests import Doc
 from lighteval.tasks.tasks_prompt_formatting import LETTER_INDICES


-_TASKS_STRINGS: List[Tuple[CustomEvaluationTask, str]] = []
-_TASKS: List[CustomEvaluationTask] = []
+_TASKS_STRINGS: List[Tuple[LightevalTaskConfig, str]] = []
+_TASKS: List[LightevalTaskConfig] = []

 ## COMMON_SENSE_REASONING_TASKS ##
 COMMON_SENSE_REASONING_TASKS = [
-    CustomEvaluationTask(
+    LightevalTaskConfig(
         name="hellaswag",
         prompt_function="hellaswag_prompt",
         hf_repo="hellaswag",
         hf_subset="default",
         metric=["loglikelihood_acc", "loglikelihood_acc_norm_nospace"],
     ),
-    CustomEvaluationTask(
+    LightevalTaskConfig(
         name="winogrande",
         prompt_function="winogrande",
         hf_repo="winogrande",
         hf_subset="winogrande_xl",
         metric=["loglikelihood_acc", "loglikelihood_acc_norm_nospace"],
     ),
-    CustomEvaluationTask(
+    LightevalTaskConfig(
         name="piqa",
         prompt_function="piqa_harness",
         hf_repo="piqa",
         hf_subset="plain_text",
         metric=["loglikelihood_acc", "loglikelihood_acc_norm_nospace"],
     ),
-    CustomEvaluationTask(
+    LightevalTaskConfig(
         name="siqa",
         prompt_function="siqa_prompt",
         hf_repo="lighteval/siqa",
@@ -48,14 +48,14 @@
         hf_avail_splits=["train", "validation"],
         metric=["loglikelihood_acc", "loglikelihood_acc_norm_nospace"],
     ),
-    CustomEvaluationTask(
+    LightevalTaskConfig(
         name="openbookqa",
         prompt_function="openbookqa",
         hf_repo="openbookqa",
         hf_subset="main",
         metric=["loglikelihood_acc", "loglikelihood_acc_norm_nospace"],
     ),
-    CustomEvaluationTask(
+    LightevalTaskConfig(
         name="arc:easy",
         prompt_function="arc",
         hf_repo="ai2_arc",
@@ -64,7 +64,7 @@
         generation_size=1,
         metric=["loglikelihood_acc", "loglikelihood_acc_norm_nospace"],
     ),
-    CustomEvaluationTask(
+    LightevalTaskConfig(
         name="arc:challenge",
         prompt_function="arc",
         hf_repo="ai2_arc",
@@ -73,7 +73,7 @@
         generation_size=1,
         metric=["loglikelihood_acc", "loglikelihood_acc_norm_nospace"],
     ),
-    CustomEvaluationTask(
+    LightevalTaskConfig(
         name="commonsense_qa",
         prompt_function="commonsense_qa_prompt",
         hf_repo="commonsense_qa",
@@ -131,7 +131,7 @@ def preprocess(text):

 ## WORLD_KNOWLEDGE_TASKS ##
 WORLD_KNOWLEDGE_TASKS = [
-    CustomEvaluationTask(
+    LightevalTaskConfig(
         name="trivia_qa",
         prompt_function="triviaqa",
         hf_repo="trivia_qa",
@@ -140,7 +140,7 @@ def preprocess(text):
         generation_size=20,
         stop_sequence=["\n", ".", ","],
     ),
-    CustomEvaluationTask(
+    LightevalTaskConfig(
         name="natural_questions",
         prompt_function="natural_questions_prompt",
         hf_repo="lighteval/natural_questions_clean",
@@ -170,14 +170,14 @@ def natural_questions_prompt(line, task_name: str = None):

 ## Reading comprehension ##
 READING_COMP_TASKS = [
-    CustomEvaluationTask(
+    LightevalTaskConfig(
         name="super_glue:boolq",
         prompt_function="boolq_prompt",
         hf_repo="super_glue",
         hf_subset="boolq",
         metric=["target_perplexity"],
     ),
-    CustomEvaluationTask(
+    LightevalTaskConfig(
         name="quac",
         prompt_function="quac",
         hf_repo="lighteval/quac_helm",
@@ -204,7 +204,7 @@ def boolq_prompt(line, task_name: str = None):


 ## MATH ##
-class CustomMathEvaluationTask(CustomEvaluationTask):
+class CustomMathEvaluationTask(LightevalTaskConfig):
     """Custom class for math tasks with all the defaults set"""

     def __init__(
@@ -251,7 +251,7 @@ def __init__(
     CustomMathEvaluationTask(name="math:prealgebra", hf_subset="prealgebra"),
     CustomMathEvaluationTask(name="math:precalculus", hf_subset="precalculus"),
 ]
-GSM8K = CustomEvaluationTask(
+GSM8K = LightevalTaskConfig(
     name="gsm8k",
     prompt_function="gsm8k",
     hf_repo="gsm8k",
@@ -272,7 +272,7 @@ def __init__(


 ## MMLU ##
-class CustomMMLUEvaluationTask(CustomEvaluationTask):
+class CustomMMLUEvaluationTask(LightevalTaskConfig):
     def __init__(
         self,
         name,
@@ -415,7 +415,7 @@ def mmlu_prompt(line, task_name: str = None):


 ## BBH ##
-class CustomBBHEvaluationTask(CustomEvaluationTask):
+class CustomBBHEvaluationTask(LightevalTaskConfig):
     def __init__(
         self,
         name,
@@ -506,7 +506,7 @@ def bbh_prompt(line, task_name: str = None):


 ## AGI eval ##
-class CustomAGIEvalEvaluationTask(CustomEvaluationTask):
+class CustomAGIEvalEvaluationTask(LightevalTaskConfig):
     def __init__(
         self,
         name,
@@ -617,7 +617,7 @@ def agi_eval_prompt_no_letters(line, task_name: str = None):


 ## HUMAN EVAL ##
-# human_eval = CustomEvaluationTask(
+# human_eval = LightevalTaskConfig(
 #     name="human_eval",
 #     prompt_function="human_eval",
 #     hf_repo="lighteval/human_eval",
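Note for downstream task authors (this note and the sketch below are not part of the patch): after the rename, custom task files declare LightevalTaskConfig objects, the registry builds the same objects via LightevalTaskConfig(**line), and LightevalTask reads configuration by attribute instead of by dict key. A minimal usage sketch, reusing the hellaswag definition from this diff; the variable name and the assumption that omitted fields (splits, suite, generation settings) fall back to the dataclass defaults in lighteval_task.py are illustrative only.

from lighteval.tasks.lighteval_task import LightevalTaskConfig  # previously CustomEvaluationTask

# Declarative task definition, as in tasks_examples/custom_tasks/custom_evaluation_tasks.py.
hellaswag = LightevalTaskConfig(
    name="hellaswag",
    prompt_function="hellaswag_prompt",
    hf_repo="hellaswag",
    hf_subset="default",
    metric=["loglikelihood_acc", "loglikelihood_acc_norm_nospace"],
)

# LightevalTask now reads these fields as attributes rather than keys:
assert hellaswag.hf_repo == "hellaswag"  # was cfg["hf_repo"]
assert hellaswag.metric == ["loglikelihood_acc", "loglikelihood_acc_norm_nospace"]  # was cfg["metric"]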