diff --git a/src/alignment/configs.py b/src/alignment/configs.py index d785e169..9097d94d 100644 --- a/src/alignment/configs.py +++ b/src/alignment/configs.py @@ -1,5 +1,4 @@ # coding=utf-8 -# coding=utf-8 # Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/src/alignment/data.py b/src/alignment/data.py index 838169a1..da4b979e 100644 --- a/src/alignment/data.py +++ b/src/alignment/data.py @@ -1,5 +1,4 @@ # coding=utf-8 -# coding=utf-8 # Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -72,12 +71,15 @@ def _strip_prefix(s, pattern): example["text_prompt"] = tokenizer.apply_chat_template( prompt_messages, tokenize=False, add_generation_prompt=True ) - - example["text_chosen"] = _strip_prefix(example["text_chosen"], assistant_prefix) - example["text_rejected"] = _strip_prefix(example["text_rejected"], assistant_prefix) + example["text_chosen"] = _strip_prefix(example["text_chosen"], assistant_prefix) + example["text_rejected"] = _strip_prefix(example["text_rejected"], assistant_prefix) + else: + raise ValueError( + f"Could not format example as dialogue for `dpo` task! Require `[chosen, rejected]` keys but found {list(example.keys())}" + ) else: raise ValueError( - f"Could not format example as dialogue for `dpo` task! Require `[chosen, rejected]` keys but found {list(example.keys())}" + f"Task {task} not supported, please ensure that the provided task is one of {['sft', 'generation', 'rm', 'dpo']}" ) return example diff --git a/src/alignment/model_utils.py b/src/alignment/model_utils.py index 9463f2e2..b9d2315a 100644 --- a/src/alignment/model_utils.py +++ b/src/alignment/model_utils.py @@ -1,5 +1,4 @@ # coding=utf-8 -# coding=utf-8 # Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License");