diff --git a/README.md b/README.md
index aa3bb89e3..20941d607 100644
--- a/README.md
+++ b/README.md
@@ -39,6 +39,7 @@ Below are the features and tasks of this framework:
     - `StarCoderFIM`: which uses the default FIM tokens `"<fim_prefix>", "<fim_middle>", "<fim_suffix>"`, and
     - `SantaCoderFIM`: which uses SantaCoder FIM tokens `"<fim-prefix>", "<fim-middle>", "<fim-suffix>"`
   - [Mercury](https://huggingface.co/datasets/Elfsong/Mercury) for evaluating computational efficiency of **Python** code generation.
+  - [ENAMEL](https://github.com/q-rz/enamel) for evaluating the efficiency ($\textnormal{eff@}k$) of generated **Python** code against **expert-written** reference solutions on HumanEval problems.
 
 More details about each task can be found in the documentation in [`docs/README.md`](https://github.com/bigcode-project/bigcode-evaluation-harness/blob/main/docs/README.md).
 
 ## Setup
diff --git a/bigcode_eval/tasks/__init__.py b/bigcode_eval/tasks/__init__.py
index 8162a5f1a..e94f4099b 100644
--- a/bigcode_eval/tasks/__init__.py
+++ b/bigcode_eval/tasks/__init__.py
@@ -5,7 +5,7 @@
                     concode, ds1000, gsm, humaneval, humanevalplus, humanevalpack,
                     instruct_humaneval, instruct_wizard_humaneval, mbpp, mbppplus,
                     multiple, parity, python_bugs, quixbugs, recode, santacoder_fim,
-                    studenteval, mercury)
+                    studenteval, mercury, enamel)
 
 TASK_REGISTRY = {
     **apps.create_all_tasks(),
@@ -31,6 +31,7 @@
     **santacoder_fim.create_all_tasks(),
     "studenteval": studenteval.StudentEval,
     "mercury": mercury.Mercury,
+    **enamel.create_all_tasks(),
 }
 
 ALL_TASKS = sorted(list(TASK_REGISTRY))
diff --git a/bigcode_eval/tasks/custom_metrics/enamel_eval.py b/bigcode_eval/tasks/custom_metrics/enamel_eval.py
new file mode 100644
index 000000000..7cf1b82e8
--- /dev/null
+++ b/bigcode_eval/tasks/custom_metrics/enamel_eval.py
@@ -0,0 +1,348 @@
+from copy import deepcopy
+import gc
+import pickle
+import time
+
+import io
+import os
+import sys
+import resource
+import platform
+import contextlib
+
+import numpy as np
+
+def calc_exec_time(ts): # Hodges--Lehmann estimator
+    ts = np.array(ts) / 2.
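+    # Hodges--Lehmann estimator of the typical execution time over repeated runs:
+    # after halving, summing every pair below yields the Walsh averages (t_i + t_j) / 2;
+    # the lower triangle (diagonal included) keeps each unordered pair exactly once,
+    # and their median is more robust to timing noise than a plain mean.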
+    ts = ts[None, :] + ts[:, None]
+    ts = ts[np.tril_indices_from(ts)]
+    return np.median(ts)
+
+def calc_eff(elapsed, ref, timeout): # efficiency score of one level; decays linearly and reaches 0 at the timeout
+    return max(0., timeout - elapsed) / (timeout - ref)
+
+def calc_eff_at_k(e, k): # numerically stable implementation
+    n = len(e)
+    lbd = [k / n] # collects comb(i - 1, k - 1) / comb(n, k) for i = n, ..., k, computed iteratively to avoid overflow
+    k_ = k - 1
+    for r in range(n - 1, k_, -1):
+        lbd.append(lbd[-1] * (1 - k_ / r))
+    lbd = np.flip(lbd)
+    e = np.sort(e)[k_ :]
+    return (lbd * e).sum()
+
+def calc_pass_at_k(n, c, k): # from the HumanEval paper
+    if n - c < k: return 1.0
+    return 1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1))
+
+class Test: # a test case
+    def __init__(self, input = None, answer = None, ref = None):
+        self.input = input
+        self.answer = answer
+        self.ref = ref # reference execution time
+
+class Refs: # references for efficiency evaluation
+    def __init__(self, tests, hardness):
+        neg_inf = float('-inf')
+        self.refs = [neg_inf] * len(hardness)
+        self.ref_max = neg_inf
+        self.lid = None # level index of the longest reference execution time
+        self.cid = None # case index of the longest reference execution time
+        # finds the longest reference execution time for calibration
+        for j, (size, tests_j) in enumerate(tests):
+            if hardness[j]:
+                for k, test in enumerate(tests_j):
+                    if self.refs[j] < test.ref:
+                        self.refs[j] = test.ref
+                    if self.ref_max < test.ref:
+                        self.ref_max = test.ref
+                        self.lid = j
+                        self.cid = k
+
+class EnamUnpickler(pickle.Unpickler):
+    CLS_DICT = {'enam.evaluate.Test': Test, 'enam.evaluate.Refs': Refs}
+    def find_class(self, module, name):
+        cls_name = f'{module}.{name}'
+        if cls_name in self.CLS_DICT:
+            return self.CLS_DICT[cls_name]
+        else:
+            return super().find_class(module, name)
+
+TPL_RUN = '''%s
+%s
+__t0 = time.time()
+__output = %s(*__input)
+__t1 = time.time()
+''' # % (prompt, code, entry_point) # this works whether or not `code` already includes the prompt
+TPL_TEST = '''%s
+    pass
+%s
+__accepted = __check(__input, __answer, __output)
+''' # % (prompt, checker)
+
+def evaluate_one(code, problem, tests, refs, k, hardness, n_reps, memory_giga, timeout_factor, tolerence_sec):
+    timeout = timeout_factor * refs.ref_max
+    memory_bytes = memory_giga * (1024 ** 3)
+    n_levels = len(tests)
+    zero_effs = [0. for j in range(n_levels)]
+    effs = []
+    for j, (size, tests_j) in enumerate(tests):
+        n_reps_j = n_reps[j]
+        level_elapsed = []
+        level_break = False
+        for k, test in enumerate(tests_j):
+            elapsed = [None for rep in range(n_reps_j)]
+            for rep in range(n_reps_j):
+                scope = dict(time = time, input = None, print = None, __input = deepcopy(test.input)) # in case the code modifies the input
+                try:
+                    unsafe_timed_execute(TPL_RUN % (problem.prompt, code, problem.entry_point), scope, memory_bytes, timeout + tolerence_sec)
+                    scope['__input'] = test.input
+                    scope['__answer'] = test.answer # the answer is injected only after the generated code has run, so the code cannot read it
+                    unsafe_execute(TPL_TEST % (problem.prompt, problem.checker), scope) # assuming that the checker does not modify the input
+                except TimeoutException as e:
+                    level_break = True
+                    break
+                except MemoryError as e:
+                    level_break = True
+                    break
+                except OverflowError as e:
+                    level_break = True
+                    break
+                except KeyboardInterrupt as e:
+                    raise e
+                except BaseException as e:
+                    return False, zero_effs
+                else:
+                    if '__accepted' in scope and scope['__accepted']:
+                        elapsed[rep] = scope['__t1'] - scope['__t0']
+                    else:
+                        return False, zero_effs
+            if level_break:
+                break
+            else:
+                level_elapsed.append(calc_exec_time(elapsed).item())
+        if level_break:
+            break
+        else:
+            effs.append(calc_eff(elapsed = max(level_elapsed), ref = refs.refs[j], timeout = timeout))
+    if j == 0 and level_break:
+        return False, zero_effs
+    for j in range(len(effs), n_levels):
+        effs.append(0.)
+    return True, effs
+
+def compute_refs(problem, tests, n_reps, hardness): # measures the reference execution times used to calibrate efficiency scores
+    for j in range(len(tests)):
+        if hardness[j]:
+            for k in range(len(tests[j][-1])):
+                test = tests[j][-1][k]
+                n_reps_j = n_reps[j]
+                elapsed = [None for rep in range(n_reps_j)]
+                for rep in range(n_reps_j):
+                    scope = dict(time = time, __input = deepcopy(test.input)) # in case the code modifies the input
+                    unsafe_execute(TPL_RUN % (problem.prompt, problem.reference_solution, problem.entry_point), scope) # assuming that the reference solution is error-free
+                    elapsed[rep] = scope['__t1'] - scope['__t0']
+                test.ref = calc_exec_time(elapsed).item()
+    return Refs(tests = tests, hardness = hardness)
+
+def evaluate_all(problems, codes, tests, k, hardness, n_reps, memory_giga, timeout_factor, tolerence_sec):
+    if isinstance(k, int):
+        k = [k]
+    min_codes = min(len(codes_i) for codes_i in codes)
+    k = sorted({k_ for k_ in k if k_ <= min_codes})
+    passes = [[] for k_ in k]
+    effs = [[] for k_ in k]
+    gc.collect()
+    for problem, codes_i, tests_i in zip(problems, codes, tests):
+        refs_i = compute_refs(problem = problem, tests = tests_i, n_reps = n_reps, hardness = hardness)
+        n_levels = len(tests_i)
+        problem_passes = []
+        problem_effs = []
+        for code in codes_i:
+            passed, code_effs = evaluate_one(
+                code = code, problem = problem, tests = tests_i, refs = refs_i,
+                k = k, hardness = hardness, n_reps = n_reps, memory_giga = memory_giga,
+                timeout_factor = timeout_factor, tolerence_sec = tolerence_sec)
+            problem_passes.append(passed)
+            problem_effs.append(code_effs)
+        for j, k_ in enumerate(k):
+            passes[j].append(calc_pass_at_k(n = len(problem_passes), c = sum(problem_passes), k = k_))
+            effs[j].append(calc_eff_at_k(e = np.average(problem_effs, axis = 1, weights = hardness), k = k_))
+    metrics = dict()
+    for k_, pass_k in zip(k, passes):
+        metrics[f'pass@{k_}'] = np.mean(pass_k).item()
+    for k_, eff_k in zip(k, effs):
+        metrics[f'eff@{k_}'] = np.mean(eff_k).item()
+    return metrics
+
+def might_catch_timeout_signal(generation, pattern_seq = (' while ', ' try:')):
+    i = 0
+    for pattern in pattern_seq:
+        i = generation.find(pattern, i)
+        if i == -1:
+            return False
+        i += len(pattern)
+    return True
+
+might_catch_timeout_signal.WARNING = """\
+We have detected that the generated code samples use `try ... except` within a loop, which might catch \
+our timeout signal and cause an infinite loop. Since resolving this rare issue via `multiprocessing` would \
+significantly slow down the evaluation process for our large-scale inputs, we have decided not to resolve \
+this issue. If this issue does happen, please consider removing the corresponding code samples."""
+
+"""The following functions are adapted from code_eval (@link https://huggingface.co/spaces/evaluate-metric/code_eval)"""
+
+def get_memory_usage():
+    return sys.getsizeof(sys.modules[__name__])
+
+@contextlib.contextmanager
+def set_memory_limit(maximum_memory_bytes = None):
+    try:
+        if maximum_memory_bytes is not None:
+            _not_darwin = (not platform.uname().system == "Darwin")
+            _rlimit_as = resource.getrlimit(resource.RLIMIT_AS)
+            _rlimit_data = resource.getrlimit(resource.RLIMIT_DATA)
+            if _not_darwin:
+                _rlimit_stack = resource.getrlimit(resource.RLIMIT_STACK)
+            memory_limit = int(get_memory_usage() + maximum_memory_bytes)
+            resource.setrlimit(resource.RLIMIT_AS, (memory_limit, _rlimit_as[-1]))
+            resource.setrlimit(resource.RLIMIT_DATA, (memory_limit, _rlimit_data[-1]))
+            if _not_darwin:
+                resource.setrlimit(resource.RLIMIT_STACK, (memory_limit, _rlimit_stack[-1]))
+        yield
+    finally:
+        if maximum_memory_bytes is not None:
+            resource.setrlimit(resource.RLIMIT_AS, _rlimit_as)
+            resource.setrlimit(resource.RLIMIT_DATA, _rlimit_data)
+            if _not_darwin:
+                resource.setrlimit(resource.RLIMIT_STACK, _rlimit_stack)
+
+class TimeoutException(Exception):
+    pass
+
+def timeout_signal_handler(signum, frame):
+    raise TimeoutException("Timed out!")
+
+@contextlib.contextmanager
+def set_time_limit(seconds):
+    import signal
+    signal.setitimer(signal.ITIMER_REAL, seconds)
+    signal.signal(signal.SIGALRM, timeout_signal_handler)
+    try:
+        yield
+    finally:
+        signal.setitimer(signal.ITIMER_REAL, 0)
+
+class WriteOnlyStringIO(io.StringIO):
+    def read(self, *args, **kwargs):
+        raise OSError
+    def readline(self, *args, **kwargs):
+        raise OSError
+    def readlines(self, *args, **kwargs):
+        raise OSError
+    def readable(self, *args, **kwargs):
+        return False
+
+class redirect_stdin(contextlib._RedirectStream): # type: ignore
+    _stream = "stdin"
+
+@contextlib.contextmanager
+def swallow_io():
+    stream = WriteOnlyStringIO()
+    with contextlib.redirect_stdout(stream):
+        with contextlib.redirect_stderr(stream):
+            with redirect_stdin(stream):
+                yield
+
+@contextlib.contextmanager
+def chdir(root):
+    if root == ".":
+        yield
+        return
+    cwd = os.getcwd()
+    os.chdir(root)
+    try:
+        yield
+    except BaseException as exc:
+        raise exc
+    finally:
+        os.chdir(cwd)
+
+@contextlib.contextmanager
+def create_tempdir():
+    import tempfile
+    with tempfile.TemporaryDirectory() as dirname:
+        with chdir(dirname):
+            yield dirname
+
+@contextlib.contextmanager
+def reliability_guard():
+    """
+    This disables various destructive functions and prevents the generated code
+    from interfering with the test (e.g. fork bomb, killing other processes,
+    removing filesystem files, etc.)
+
+    WARNING
+    This function is NOT a security sandbox. Untrusted code, including model-
+    generated code, should not be blindly executed outside of one. See the
+    Codex paper for more information about OpenAI's code sandbox, and proceed
+    with caution.
+    """
+
+    with create_tempdir():
+        with swallow_io():
+            try:
+
+                import faulthandler
+
+                faulthandler.disable()
+
+                import builtins, os, shutil, subprocess
+
+                os.environ["OMP_NUM_THREADS"] = "1"
+
+                _keys = dict(
+                    builtins = ('exit', 'quit'),
+                    os = ('kill', 'system', 'putenv', 'remove', 'removedirs', 'rmdir', 'fchdir', 'setuid', 'fork', 'forkpty', 'killpg', 'rename', 'renames', 'truncate', 'replace', 'unlink', 'fchmod', 'fchown', 'chmod', 'chown', 'chroot', 'lchflags', 'lchmod', 'lchown', 'getcwd', 'chdir'),
+                    shutil = ('rmtree', 'move', 'chown'),
+                    subprocess = ('Popen',),
+                )
+                _baks = dict()
+                for lib, keys in _keys.items():
+                    obj = locals()[lib]
+                    _bak = dict()
+                    for key in keys:
+                        if hasattr(obj, key):
+                            _bak[key] = getattr(obj, key)
+                            setattr(obj, key, None) # disable the destructive function; restored in the finally block below
+                    _baks[lib] = _bak
+
+                #__builtins__["help"] = None
+
+                yield
+            finally:
+                for lib, keys in _keys.items():
+                    obj = locals()[lib]
+                    for key, val in _baks[lib].items():
+                        setattr(obj, key, val)
+
+def unsafe_execute(program: str, exec_globals: dict):
+    try:
+        gc_bak = gc.isenabled()
+        gc.disable()
+        with reliability_guard():
+            exec(program, exec_globals)
+    finally:
+        if gc_bak:
+            gc.enable()
+
+def unsafe_timed_execute(program: str, exec_globals: dict, maximum_memory_bytes: float, time_limit_seconds: float):
+    try:
+        gc_bak = gc.isenabled()
+        gc.disable()
+        with reliability_guard():
+            with set_memory_limit(maximum_memory_bytes):
+                with set_time_limit(time_limit_seconds):
+                    exec(program, exec_globals)
+    finally:
+        if gc_bak:
+            gc.enable()
diff --git a/bigcode_eval/tasks/enamel.py b/bigcode_eval/tasks/enamel.py
new file mode 100644
index 000000000..35c7be212
--- /dev/null
+++ b/bigcode_eval/tasks/enamel.py
@@ -0,0 +1,134 @@
+"""How efficient is LLM-generated code? A rigorous & high-standard benchmark
+https://arxiv.org/pdf/2406.06647
+
+ENAMEL is a rigorous & high-standard benchmark for evaluating the efficiency of generated
+Python code compared with expert-written reference solutions on 142 HumanEval problems.
+
+Homepage: https://github.com/q-rz/enamel
+"""
+
+_CITATION = """
+@article{qiu2024enamel,
+  title={How efficient is {LLM}-generated code? A rigorous \& high-standard benchmark},
+  author={Qiu, Ruizhong and Zeng, Weiliang Will and Tong, Hanghang and Ezick, James and Lott, Christopher},
+  journal={arXiv preprint arXiv:2406.06647},
+  year={2024}
+}
+"""
+
+
+from warnings import warn
+import pickle
+import numpy as np
+from huggingface_hub import hf_hub_download
+from bigcode_eval.tasks.humaneval import GeneralHumanEval
+from bigcode_eval.tasks.custom_metrics.enamel_eval import EnamUnpickler, evaluate_all, might_catch_timeout_signal
+
+
+class GeneralENAMEL(GeneralHumanEval):
+    """A task represents an entire benchmark including its dataset, problems,
+    answers, generation settings and evaluation methods.
+ """ + + DATASET_PATH = "q-rz/enamel" + DATASET_NAME = "default" + DATASET_ALL = "ENAMEL_HumanEval" + + def __init__(self, subset, # list of problem IDs + hardness=[0., 3., 3., 4.], n_reps = 6, memory_giga=10., timeout_factor=2., tolerence_sec=0.01, tests_path="cache/eval~tests.pkl", + strip_prompt=True, k=[1, 10, 100], + ): + super().__init__(strip_prompt=strip_prompt, k=k, num_workers=1, timeout=None) # each problem has a different time limit + self.subset = subset if isinstance(subset, list) else list(subset) + self.n_probs = len(self.subset) + self.dataset = self.dataset[self.DATASET_ALL].to_pandas().iloc[self.subset, :] + self.prob_ids = {row.task_id: i for i, row in enumerate(self.dataset.itertuples(index=False))} + self.hardness = hardness + self.n_levels = len(self.hardness) + self.n_reps = [n_reps if self.hardness[j] else 1 for j in range(self.n_levels)] # no need to repeat if it does not count into the efficiency score + self.memory_giga = memory_giga + self.timeout_factor = timeout_factor + self.tolerence_sec = tolerence_sec + if self.DATASET_PATH != 'q-rz/enamel': + warn(f"Tests are loaded from {self.DATASET_PATH}/{tests_path} by `pickle`. Unpickling files from an unknown provider can be unsafe.") + self.tests_path = hf_hub_download(repo_id = self.DATASET_PATH, filename = tests_path, repo_type = "dataset") + with open(self.tests_path, 'rb') as fi: + tests_all, _ = EnamUnpickler(fi).load() + self.tests = [tests_all[i] for i in self.subset] + + def get_dataset(self): + """Returns dataset as an iterable of namedtuple""" + return list(self.dataset.itertuples(index=False)) + + def get_prompt(self, doc): + """ + :param doc: namedtuple + a row from the dataset + :return: str + """ + return super().get_prompt(doc._asdict()) + + def get_reference(self, doc): + """ + :param doc: namedtuple + a row from the dataset + :return: tuple (problem, tests) + """ + i = self.prob_ids[doc.task_id] + return doc, self.tests[i] + + def postprocess_generation(self, generation, idx): + """ + Defines the postprocessing for a LM generation. + :param generation: str + code generation from LM + :param idx: int (if needed) + index of doc in the dataset to which the generation belongs; not needed here + :return: str + """ + generation = self._stop_at_stop_token(generation, self.stop_words) + if (not self.warned_dead_loop) and might_catch_timeout_signal(generation): + warn(might_catch_timeout_signal.WARNING) + self.warned_dead_loop = True + return generation + + def process_results(self, generations, references): + """ + Takes the list of LM generations and evaluates them against ground truth references, + returning the metric for the generations as in {"metric_name": result}. 
+        :param generations: list(list(str))
+            list of lists containing generations
+        :param references: list(tuple)
+            list of (problem, tests) tuples as returned by get_reference
+        :return: dict[str: float]
+        """
+        problems = []
+        tests = []
+        for problem, tests_i in references:
+            problems.append(problem)
+            tests.append(tests_i)
+        return evaluate_all(
+            problems=problems, codes=generations, tests=tests,
+            k=self.k, hardness=self.hardness, n_reps=self.n_reps,
+            memory_giga=self.memory_giga, timeout_factor=self.timeout_factor, tolerence_sec=self.tolerence_sec,
+        )
+
+
+def create_task(name, subset):
+    class ENAMEL(GeneralENAMEL):
+        __name__ = name
+        __qualname__ = name
+        SUBSET = subset
+        def __init__(self, *args, **kwargs):
+            super().__init__(subset=self.SUBSET, *args, **kwargs)
+    return ENAMEL
+
+def create_all_tasks():
+    """Creates a dictionary of the ENAMEL tasks: the full benchmark and its Algo / Impl subsets.
+    :return: {task_name: task}
+    """
+    return {
+        "enamel": create_task(name="ENAMEL", subset=sorted(set(range(164)) - {2, 23, 41, 45, 53, 60, 71, 92, 97, 99, 102, 123, 124, 135, 137, 138, 144, 148, 156, 157, 159, 160})),
+        "enamel-algo": create_task(name="ENAMEL_Algo", subset=[10, 18, 36, 39, 40, 43, 46, 49, 55, 59, 63, 76, 83, 96, 107, 109, 114, 129, 147, 154]),
+        "enamel-impl": create_task(name="ENAMEL_Impl", subset=[1, 5, 8, 9, 11, 12, 15, 16, 17, 19, 21, 22, 24, 25, 26, 27, 31, 33, 37, 38, 44, 48, 49, 50, 51, 52, 56, 57, 58, 59, 61, 64, 66, 69, 70, 72, 73, 74, 75, 78, 80, 82, 85, 87, 89, 91, 93, 94, 95, 96, 98, 100, 104, 105, 108, 110, 111, 112, 113, 116, 117, 118, 121, 122, 125, 127, 128, 131, 140, 142, 143, 150, 152, 155, 161]),
+    }
diff --git a/docs/README.md b/docs/README.md
index 903c6a122..e652061bf 100644
--- a/docs/README.md
+++ b/docs/README.md
@@ -426,6 +426,27 @@ accelerate launch main.py \
   --metric_output_path <output_path>.json
 ```
 
+## ENAMEL
+
+[ENAMEL](https://github.com/q-rz/enamel) is a rigorous & high-standard benchmark for evaluating the efficiency of generated Python code under large-scale inputs. It supports a new efficiency metric, eff@k, which generalizes the pass@k metric. Besides that, it provides expert-written reference solutions and expert-written test case generators, thus setting a high standard for efficiency evaluation. See [this paper](https://arxiv.org/abs/2406.06647) for details.
+
+**Notice:** It is NOT recommended to use multiple threads or processes during efficiency evaluation, since that can distort the measured execution times and thus the efficiency results.
+
+```python
+accelerate launch main.py \
+  --model <MODEL_NAME> \
+  --max_length_generation 2048 \
+  --tasks enamel \
+  --temperature 0.8 \
+  --top_p 0.95 \
+  --do_sample True \
+  --n_samples 10 \
+  --batch_size 10 \
+  --allow_code_execution
+```
+
+This implementation also supports the two subsets Algo and Impl from the paper: `--tasks enamel-algo` / `--tasks enamel-impl`.
+
 ## Code generation benchmarks without unit tests
 
 For these tasks, we do single generations and compare the generated code against reference solutions and compute BLEU score. For the following tasks, we use a two-shot setting where we include 2 inputs and their solutions in the prompt, all preceded by an instruction such as: ` "Answer the following instructions in a one line SQL query:\n"`. The solutions consist of one line so we stop the generation when a new line is generated. 3 languages are present: Python, SQL and Java.
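For reference, the efficiency score and the eff@k estimator implemented by `calc_eff` and `calc_eff_at_k` above can be written compactly as below; the symbols $t$, $t^{*}$, $T$, $n$, and $e_{(i)}$ are notation introduced here only for illustration and do not appear in the code.

$$
e = \frac{\max\{T - t,\ 0\}}{T - t^{*}},
\qquad
\widehat{\textnormal{eff@}k} = \sum_{i=k}^{n} \frac{\binom{i-1}{k-1}}{\binom{n}{k}}\, e_{(i)},
$$

where $t$ is the measured execution time of a generated sample on a test level, $t^{*}$ is the reference time of that level, $T$ is the timeout (`timeout_factor` times the largest reference time), and $e_{(1)} \le \dots \le e_{(n)}$ are the per-sample scores of one problem (level scores averaged with the `hardness` weights) sorted in increasing order; the reported eff@k is the mean of this estimator over problems.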