diff --git a/bigcode_eval/utils.py b/bigcode_eval/utils.py index ff79c0e5f..a1bbddaa7 100644 --- a/bigcode_eval/utils.py +++ b/bigcode_eval/utils.py @@ -2,6 +2,7 @@ import math import re import warnings +import numpy as np from collections import defaultdict from typing import List, Optional @@ -220,6 +221,13 @@ def _parse_instruction(code, instruction_tokens): shift = len("```python") return code[idx + shift :] +def _remove_rightpad_tokens(generated_tokens, pad_token): + for i in range(len(generated_tokens)-1, -1, -1): + if generated_tokens[i] == pad_token: + generated_tokens = np.delete(generated_tokens, i) + else: + break + return generated_tokens def complete_code( task, @@ -315,6 +323,7 @@ def complete_code( generated_tasks = generated_tasks.cpu().numpy() for sample, generated_tokens in zip(generated_tasks, generated_tokens): + generated_tokens = _remove_rightpad_tokens(generated_tokens, tokenizer.pad_token_id) gen_token_dict[sample].append(generated_tokens) if save_every_k_tasks >= 1 and (step + 1) % save_every_k_tasks == 0: