diff --git a/bigcode_eval/tasks/shadereval.py b/bigcode_eval/tasks/shadereval.py
index 3e5790b02..73ec6d47a 100644
--- a/bigcode_eval/tasks/shadereval.py
+++ b/bigcode_eval/tasks/shadereval.py
@@ -133,6 +133,8 @@ def __init__(self, prompt="minimal"):
             requires_execution=True, #we run shadercode - could that be harmful? (all in the metric)
         )
         self.prompt = prompt # "minimal" or "full". "minimal" is the function header and comments before/after it, "full" is the whole code up untill the function declaration ends
+        # while this dataset is under development it stays private, so we overwrite the dataset loaded by the parent __init__ here; that load fails silently (its warning doesn't show up)
+        self.dataset = datasets.load_dataset(path=self.DATASET_PATH, name=self.DATASET_NAME, trust_remote_code=True, use_auth_token=True)
 
     def get_dataset(self):
         # TODO replace with subset once that is set up
@@ -160,6 +162,13 @@ def get_prompt(self, doc):
             # only have one alternative, but could be more?
             model_context += doc["model_ctx"]
         return model_context
+
+    def get_prompt_encoder(self, doc):
+        """
+        this is needed for seq2seq models, but not available by default?
+        """
+        enc_prompt = doc["model_ctx"] + ""  # magic token to trigger generation for CodeT5p?
+        return enc_prompt
 
     def get_reference(self, doc):
         # TODO: get the reference solution from a sample `doc` from the dataset
@@ -213,6 +222,7 @@ def postprocess_generation(self, generation, idx):
 
         # from: https://huggingface.co/spaces/Vipitis/ShaderCoder/blob/main/utils/tree_utils.py#L45
         # generation = ShaderCoder.utils.parse_functions(generation)[0].text.decode() #not easily imported...
+        print(generation)
         # assemble into the full code with just the function replaced
         ref = self.dataset["test"][idx]
 
diff --git a/bigcode_eval/utils.py b/bigcode_eval/utils.py
index ff79c0e5f..444eaff6c 100644
--- a/bigcode_eval/utils.py
+++ b/bigcode_eval/utils.py
@@ -281,11 +281,11 @@ def complete_code(
                     )
                 else:
                     generated_tokens = model.generate(
-                        decoder_input_ids=inputs,
+                        # decoder_input_ids=inputs,
                         input_ids=batch["ids_encoder"][:, : batch["input_len_encoder"]],
                         num_return_sequences=batch_size,
-                        decoder_start_token_id=tokenizer.pad_token_id,
-                        eos_token_id=tokenizer.eos_token_id,
+                        # decoder_start_token_id=tokenizer.pad_token_id,
+                        # eos_token_id=tokenizer.eos_token_id,
                         **gen_kwargs,
                     )
             else:
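
Note on the `shadereval.py` change: the added `__init__` lines re-load the dataset with authentication because the task is still backed by a private dataset repo, and the unauthenticated load in the parent `__init__` fails without a visible warning. Below is a minimal sketch of that loading pattern, assuming the machine has a stored Hugging Face token (e.g. via `huggingface-cli login`); the repo and config names are placeholders, not the real `DATASET_PATH`/`DATASET_NAME`:

```python
from datasets import load_dataset

# Placeholder identifiers - the task resolves these from self.DATASET_PATH / self.DATASET_NAME.
DATASET_PATH = "username/private-shader-dataset"
DATASET_NAME = "default"

ds = load_dataset(
    path=DATASET_PATH,
    name=DATASET_NAME,
    trust_remote_code=True,  # the dataset ships its own loading script
    use_auth_token=True,     # read the locally stored HF token (older datasets API; newer versions use token=True)
)
print(ds)  # DatasetDict with the available splits, e.g. ds["test"]
```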
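
Note on the `utils.py` change: commenting out `decoder_input_ids`, `decoder_start_token_id`, and `eos_token_id` in the seq2seq branch leaves those choices to the checkpoint's own `generation_config`, with only the encoder input passed explicitly. A minimal sketch of that call pattern with an encoder-decoder checkpoint (model name and prompt are illustrative, not taken from the harness):

```python
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

checkpoint = "Salesforce/codet5p-220m"  # illustrative seq2seq checkpoint
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)

prompt = "float addTwo(float x) { // add two to x"  # illustrative encoder prompt
inputs = tokenizer(prompt, return_tensors="pt")

with torch.no_grad():
    # Only the encoder input is given; decoder_start_token_id and eos_token_id
    # come from model.generation_config instead of being forced from the tokenizer,
    # which is what commenting out those kwargs in complete_code() relies on.
    generated_tokens = model.generate(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        num_return_sequences=1,
        max_new_tokens=64,
    )

print(tokenizer.decode(generated_tokens[0], skip_special_tokens=True))
```

If a particular checkpoint's `generation_config` does not define these tokens, they would still need to be supplied through `gen_kwargs`.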