Commit

janky seq2seq support (will be reverted)
Vipitis committed Jun 25, 2024
1 parent: c0a1569 · commit: 2d8ab1f
Showing 2 changed files with 13 additions and 3 deletions.
10 changes: 10 additions & 0 deletions bigcode_eval/tasks/shadereval.py
@@ -133,6 +133,8 @@ def __init__(self, prompt="minimal"):
             requires_execution=True,  # we run shader code - could that be harmful? (it all happens in the metric)
         )
         self.prompt = prompt  # "minimal" or "full": "minimal" is the function header and the comments before/after it, "full" is the whole code up until the function declaration ends
+        # while we develop this dataset, we use a private dataset, so we overwrite the init call here; this fails silently (the warning doesn't show up)
+        self.dataset = datasets.load_dataset(path=self.DATASET_PATH, name=self.DATASET_NAME, trust_remote_code=True, use_auth_token=True)

     def get_dataset(self):
         # TODO: replace with the subset once that is set up
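Aside: newer releases of the `datasets` library deprecate `use_auth_token` in favour of `token`, so the same private-dataset load would look like the minimal sketch below. `DATASET_PATH` and `DATASET_NAME` are placeholders for the class attributes used above:

    import datasets

    # minimal sketch, assuming a private/gated dataset and that
    # `huggingface-cli login` has stored a token; on datasets >= 2.14
    # `token=True` replaces the deprecated `use_auth_token=True`
    ds = datasets.load_dataset(
        path=DATASET_PATH,        # placeholder for self.DATASET_PATH
        name=DATASET_NAME,        # placeholder for self.DATASET_NAME
        trust_remote_code=True,
        token=True,
    )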
@@ -160,6 +162,13 @@ def get_prompt(self, doc):
         # only have one alternative, but there could be more?
         model_context += doc["model_ctx"]
         return model_context
+
+    def get_prompt_encoder(self, doc):
+        """
+        This is needed for seq2seq models, but it is not available by default?
+        """
+        enc_prompt = doc["model_ctx"] + "<extra_id_0>"  # magic token to trigger generation for CodeT5p?
+        return enc_prompt

     def get_reference(self, doc):
         # TODO: get the reference solution from a sample `doc` of the dataset
@@ -213,6 +222,7 @@ def postprocess_generation(self, generation, idx):
         # from: https://huggingface.co/spaces/Vipitis/ShaderCoder/blob/main/utils/tree_utils.py#L45
         # generation = ShaderCoder.utils.parse_functions(generation)[0].text.decode()  # not easily imported...

+        print(generation)  # debug: inspect the raw generation

         # assemble the full code with just the function replaced
         ref = self.dataset["test"][idx]
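A hedged sketch (not part of this commit) of how the new `get_prompt_encoder` output could drive a T5-style seq2seq checkpoint; `Salesforce/codet5p-220m` and the `doc` sample are assumptions, and `<extra_id_0>` is the first T5 sentinel token, so appending it asks the model to infill at that position:

    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

    checkpoint = "Salesforce/codet5p-220m"  # assumed T5-style checkpoint
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)

    doc = {"model_ctx": "float opUnion(float d1, float d2) {"}  # hypothetical sample
    enc_prompt = doc["model_ctx"] + "<extra_id_0>"  # mirrors get_prompt_encoder above

    inputs = tokenizer(enc_prompt, return_tensors="pt")
    out = model.generate(**inputs, max_new_tokens=64)
    print(tokenizer.decode(out[0], skip_special_tokens=True))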
6 changes: 3 additions & 3 deletions bigcode_eval/utils.py
@@ -281,11 +281,11 @@ def complete_code(
                 )
             else:
                 generated_tokens = model.generate(
-                    decoder_input_ids=inputs,
+                    # decoder_input_ids=inputs,
                     input_ids=batch["ids_encoder"][:, : batch["input_len_encoder"]],
                     num_return_sequences=batch_size,
-                    decoder_start_token_id=tokenizer.pad_token_id,
-                    eos_token_id=tokenizer.eos_token_id,
+                    # decoder_start_token_id=tokenizer.pad_token_id,
+                    # eos_token_id=tokenizer.eos_token_id,
                     **gen_kwargs,
                 )
         else:
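Net effect of the hunk above: for encoder-decoder models only the encoder `input_ids` are now passed, so the checkpoint's own generation config supplies the decoder start and EOS tokens. The call that remains, with `batch`, `batch_size`, and `gen_kwargs` standing for the harness variables of the same names:

    generated_tokens = model.generate(
        input_ids=batch["ids_encoder"][:, : batch["input_len_encoder"]],
        num_return_sequences=batch_size,
        **gen_kwargs,
    )

Hard-coding `decoder_start_token_id=tokenizer.pad_token_id` (or passing `decoder_input_ids`) can clash with checkpoints whose generation config already defines these, which is presumably why the commit comments the arguments out rather than fixing them.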
