Skip to content

Commit

Permalink
Merge pull request #23 from TianyiQ/main
Browse files Browse the repository at this point in the history
fix(abstractions): data manipulation for dialogue generation
  • Loading branch information
TianyiQ authored Dec 4, 2024
2 parents 139ad05 + a16ff9c commit fd24631
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 23 deletions.
32 changes: 21 additions & 11 deletions examples/abstractions/finetuning_datamanip.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,19 +119,29 @@ def dialogue_manipulation():
}
]
)
dialogue_data = llama8b_instruct.inference(
dialogue_data, "dialogue_data2", backend="sglang"
)
dialogue_data = dialogue_data.switch_role_to_user()
dialogue_data = llama8b_instruct.inference(
dialogue_data, "dialogue_data3", backend="sglang"
)
dialogue_data = dialogue_data.switch_role_to_assistant()

def converse():
nonlocal dialogue_data

dialogue_data = llama8b_instruct.inference(
dialogue_data, "dialogue_data2", backend="sglang"
)
dialogue_data = dialogue_data.switch_role_to_user()

dialogue_data = llama8b_instruct.inference(
dialogue_data, "dialogue_data3", backend="sglang"
)
dialogue_data = dialogue_data.switch_role_to_assistant()

for i in range(5):
converse()

print(list(dialogue_data.all_passages()))
print(list(dialogue_data.to_openai_format()))


if __name__ == "__main__":
# continue_pretrain()
# supervised_finetune()
# direct_preference_optimization()
continue_pretrain()
supervised_finetune()
direct_preference_optimization()
dialogue_manipulation()
26 changes: 15 additions & 11 deletions src/abstractions/backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,22 +288,23 @@ def vllm_process_batch(
temperature=temperature, top_p=0.95, max_tokens=max_tokens
)

if not os.environ.get("ALLOW_EMPTY_INPUT") or not eval(
os.environ.get("ALLOW_EMPTY_INPUT")
if not os.environ.get("ALLOW_EMPTY_INSTRUCTION") or not eval(
os.environ.get("ALLOW_EMPTY_INSTRUCTION")
):
found = 0
for dic in sample_dicts:
if not dic.get("input"):
if not dic.get("instruction"):
if not found:
warnings.warn(
'In at least one sample, "input" field is missing or empty. Content from the "instruction" field will be copied to the "input" field. This behavior can be disabled by ALLOW_EMPTY_INPUT=1.'
'In at least one sample, "instruction" field is missing or empty. Content from the "input" field will be moved to the "instruction" field. This behavior can be disabled by ALLOW_EMPTY_INSTRUCTION=1.'
)
found = 1
dic["input"] = dic["instruction"]
dic["instruction"] = dic["input"]
del dic["input"]

prompts = [
fill_in_QA_template(
dic["instruction"], dic["input"], model_repoid_or_path=template_type
dic.get("instruction"), dic.get("input"), model_repoid_or_path=template_type
)
for dic in sample_dicts
]
Expand Down Expand Up @@ -518,18 +519,19 @@ def sglang_process_batch(
"""
nonlocal purpose

if not os.environ.get("ALLOW_EMPTY_INPUT") or not eval(
os.environ.get("ALLOW_EMPTY_INPUT")
if not os.environ.get("ALLOW_EMPTY_INSTRUCTION") or not eval(
os.environ.get("ALLOW_EMPTY_INSTRUCTION")
):
found = 0
for dic in sample_dicts:
if not dic.get("input"):
if not dic.get("instruction"):
if not found:
warnings.warn(
'In at least one sample, "input" field is missing or empty. Content from the "instruction" field will be copied to the "input" field. This behavior can be disabled by ALLOW_EMPTY_INPUT=1.'
'In at least one sample, "instruction" field is missing or empty. Content from the "input" field will be moved to the "instruction" field. This behavior can be disabled by ALLOW_EMPTY_INSTRUCTION=1.'
)
found = 1
dic["input"] = dic["instruction"]
dic["instruction"] = dic["input"]
del dic["input"]

dialogues = dict_to_dialogue_list(sample_dicts, purpose)
options_lists = [(dic["predict"] if "predict" in dic and isinstance(dic["predict"], list) else []) for dic in sample_dicts]
Expand Down Expand Up @@ -675,6 +677,8 @@ def dict_to_dialogue_list(

if purpose == "logprobs" and "predict" in dic and isinstance(dic["predict"], str):
res.append({"role": "assistant", "content": dic["predict"]})
elif "output" in dic:
res.append({"role": "assistant", "content": dic["output"]})

return res

Expand Down
13 changes: 12 additions & 1 deletion src/abstractions/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,11 @@
import os
import json
import warnings
from functools import partial
import src.utils.text_utils as tu
from tqdm import tqdm
from src.abstractions.configs.templates_configs import *
from src.abstractions.backends import dict_to_dialogue_list


# helper function, escape spaces in paths
Expand Down Expand Up @@ -184,6 +186,15 @@ def copy(self, data_name: str = None) -> "Data":

cp.key_fields = self.key_fields.copy()
return cp

def to_openai_format(self) -> Iterable[List[Dict[str, str]]]:
    """
    Convert the data to OpenAI format, where each dialogue is a list of dictionaries with string keys and string values.
    Each dictionary represents a dialogue turn.
    """
    # Lazily convert each stored sample; "logprobs" makes the converter
    # also emit the assistant turn when a string "predict" field exists.
    for sample in self.all_passages():
        yield dict_to_dialogue_list(sample, purpose="logprobs")

def transform(
self,
Expand Down Expand Up @@ -357,7 +368,7 @@ def switch_role_to_user_fn(sample_dict: Dict) -> Dict:
all_histories = [h[i] for h in sample_dict.get("history", []) for i in range(2)]
all_histories = [dialogue_starter] + all_histories
assert len(all_histories) % 2 == 1
sample_dict["history"] = [[all_histories[i], all_histories[i + 1]] for i in range(len(all_histories)-1, 2)]
sample_dict["history"] = [[all_histories[i], all_histories[i + 1]] for i in range(0, len(all_histories)-1, 2)]
sample_dict["instruction"] = all_histories[-1]
sample_dict["system"] = user_system_prompt
return sample_dict
Expand Down

0 comments on commit fd24631

Please sign in to comment.