wip
neet committed Jul 1, 2024
1 parent b073350 commit 43d80ff
Showing 8 changed files with 30 additions and 22 deletions.

src/ainu_lm_trainer/services/byte_level_bpe/byte_level_bpe_tokenizer_trainer.py

@@ -20,7 +20,7 @@ def __init__(
     # https://huggingface.co/docs/tokenizers/en/training_from_memory#using-the-datasets-library
     def __batch_iterator(self, batch_size: int = 1_000) -> Iterator[Iterator[str]]:
         for i in range(0, len(self.__dataset), batch_size):
-            yield self.__dataset[i : i + batch_size]["sentence"]
+            yield self.__dataset[i : i + batch_size]["text"]

     def __prepare(self) -> None:
         self.__workspace_config.model_dir.mkdir(parents=True, exist_ok=True)
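
The hunk above only renames the column the iterator reads. For context, here is a minimal, self-contained sketch of how such a batch iterator is typically consumed, following the tokenizers doc linked in the comment; the sample sentences and vocab_size are placeholders, not values from this repo:

    from datasets import Dataset
    from tokenizers import ByteLevelBPETokenizer

    # Placeholder data; the real trainer loads its dataset elsewhere.
    dataset = Dataset.from_dict({"text": ["irankarapte", "tanto sirpirka"]})

    def batch_iterator(batch_size: int = 1_000):
        for i in range(0, len(dataset), batch_size):
            # each yield is a plain list of strings from the "text" column
            yield dataset[i : i + batch_size]["text"]

    tokenizer = ByteLevelBPETokenizer()
    tokenizer.train_from_iterator(
        batch_iterator(),
        vocab_size=30_000,    # illustrative setting, not taken from the repo
        length=len(dataset),  # total row count, used for progress reporting
    )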

@@ -10,7 +10,7 @@
 def test_train() -> None:
     dataset = Dataset.from_dict(
         {
-            "sentence": [
+            "text": [
                 "This is a sentence.",
                 "This is another sentence.",
             ]

src/ainu_lm_trainer/services/gpt2/gpt2_trainer.py (2 changes: 1 addition & 1 deletion)

@@ -59,7 +59,7 @@ def train(self) -> None:
         dataset = self.__dataset_config.load()
         dataset = dataset.map(
             lambda examples: tokenizer(
-                examples["sentence"],
+                examples["text"],
                 truncation=True,
                 max_length=self.__context_length,
                 padding="max_length",
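
The map above is the standard batched-tokenization pattern from datasets. A self-contained sketch of what it does, assuming batched=True (implied by the plural `examples`); the gpt2 checkpoint and sample sentences are stand-ins, and 128 mirrors __context_length from the diff:

    from datasets import Dataset
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # stand-in checkpoint
    tokenizer.pad_token = tokenizer.eos_token          # gpt2 ships without a pad token

    dataset = Dataset.from_dict({"text": ["irankarapte", "tanto sirpirka"]})
    dataset = dataset.map(
        lambda examples: tokenizer(
            examples["text"],
            truncation=True,
            max_length=128,
            padding="max_length",
        ),
        batched=True,
    )
    print(len(dataset[0]["input_ids"]))  # 128: every row is padded to max_length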

src/ainu_lm_trainer/services/mt5/mt5_trainer.py (24 changes: 12 additions & 12 deletions)

@@ -24,7 +24,6 @@
 class Mt5Trainer:
     __context_length = 128
     __model_name = "aynumosir/mt5-small-ainu"
-    __task_prefix = "translate: Japanese to Ainu: "

     __dataset_config: DatasetConfig
     __fine_tuning_config: FineTuningConfig

@@ -51,22 +50,17 @@ def train(self) -> None:

         dataset = self.__dataset_config.load()
         dataset = dataset.filter(
-            lambda example: len(example["sentence"]) > 0
-            and len(example["translation"]) > 0
-        )
-        # A source where the colloquial form is given in parentheses within the literary text. Exclude it.
-        dataset = dataset.filter(
-            lambda example: not (
-                example["book"] == "鍋沢元蔵筆録ノート"
-                and example["title"] == "kamuyyukar-2"
-            )
+            lambda example: len(example["text"]) > 0 and len(example["translation"]) > 0
         )

         # https://huggingface.co/docs/transformers/en/tasks/summarization#preprocess
         def preprocess(examples: dict) -> dict:
             inputs = tokenizer(
-                [self.__task_prefix + text for text in examples["translation"]],
-                text_target=examples["sentence"],
+                [
+                    self.__get_task_prefix(example) + example["translation"]
+                    for example in examples
+                ],
+                text_target=examples["text"],
                 max_length=self.__context_length,
                 truncation=True,
             )

@@ -119,6 +113,12 @@ def preprocess(examples: dict) -> dict:
         model.push_to_hub(self.__model_name)
         tokenizer.push_to_hub(self.__model_name)

+    def __get_task_prefix(self, example: dict) -> str:
+        if "dialect" in example:
+            return f"translate Japanese to Ainu ({example['dialect']}, {example['pronoun']}): "
+        else:
+            return f"translate Japanese to Ainu (沙流, {example['pronoun']}): "
+
     def __compute_metrics(
         self, tokenizer: MT5Tokenizer, eval_preds: EvalPrediction
     ) -> dict:
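
The new __get_task_prefix replaces the single fixed __task_prefix with a per-example prefix that encodes dialect and pronoun style. A standalone rendering of the same branching, with made-up field values (the dataset's real dialect and pronoun labels are not shown in this diff):

    def get_task_prefix(example: dict) -> str:
        # 沙流 (Saru) is the fallback dialect when no "dialect" field is present.
        if "dialect" in example:
            return f"translate Japanese to Ainu ({example['dialect']}, {example['pronoun']}): "
        return f"translate Japanese to Ainu (沙流, {example['pronoun']}): "

    print(get_task_prefix({"dialect": "十勝", "pronoun": "ku"}))  # made-up values
    # translate Japanese to Ainu (十勝, ku):
    print(get_task_prefix({"pronoun": "ku"}))
    # translate Japanese to Ainu (沙流, ku):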

src/ainu_lm_trainer/services/mt5_affix/mt5_affix_trainer.py (6 changes: 3 additions & 3 deletions)

@@ -49,11 +49,11 @@ def train(self) -> None:
         model = model.to("cuda") if torch.cuda.is_available() else model

         dataset = self.__dataset_config.load()
-        dataset = dataset.filter(lambda example: len(example["sentence"]) > 0)
+        dataset = dataset.filter(lambda example: len(example["text"]) > 0)
         dataset = dataset.map(
             lambda example: {
-                "text": example["sentence"].replace("=", ""),
-                "target": example["sentence"],
+                "text": example["text"].replace("=", ""),
+                "target": example["text"],
             },
             remove_columns=dataset.column_names,
         )
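
This mapping builds the affix-restoration pairs: the model input ("text") has the "=" clitic boundaries stripped, while the target keeps them. A self-contained illustration with one made-up Ainu sentence:

    from datasets import Dataset

    dataset = Dataset.from_dict({"text": ["ku=itak rusuy"]})  # made-up example
    dataset = dataset.map(
        lambda example: {
            "text": example["text"].replace("=", ""),  # input, boundaries removed
            "target": example["text"],                 # label keeps the "=" markers
        },
        remove_columns=dataset.column_names,
    )
    print(dataset[0])  # text: 'kuitak rusuy', target: 'ku=itak rusuy'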

src/ainu_lm_trainer/services/mt5_gec/mt5_gec_trainer.py (12 changes: 10 additions & 2 deletions)

@@ -19,7 +19,6 @@
 class Mt5GecTrainer:
     __context_length = 128
     __model_name = "aynumosir/mt5-small-ainu-gec"
-    __task_prefix = "fix Ainu sentence: "

     __fine_tuning_config: FineTuningConfig
     __training_config: TrainingConfig

@@ -53,7 +52,10 @@ def train(self) -> None:
         # https://huggingface.co/docs/transformers/en/tasks/summarization#preprocess
         dataset = dataset.map(
             lambda examples: tokenizer(
-                [self.__task_prefix + sentence for sentence in examples["text"]],
+                [
+                    self.__get_task_prefix(example) + example["text"]
+                    for example in examples
+                ],
                 text_target=[text for text in examples["target"]],
                 max_length=self.__context_length,
                 padding="max_length",

@@ -100,3 +102,9 @@ def train(self) -> None:
         if self.__training_config.push_to_hub:
             model.push_to_hub(self.__model_name)
             tokenizer.push_to_hub(self.__model_name)
+
+    def __get_task_prefix(self, example: dict) -> str:
+        if "dialect" in example:
+            return f"fix Ainu sentence ({example['dialect']}, {example['pronoun']}): "
+        else:
+            return f"fix Ainu sentence (沙流, {example['pronoun']}): "

src/ainu_lm_trainer/services/roberta/roberta_trainer.py (2 changes: 1 addition & 1 deletion)

@@ -54,7 +54,7 @@ def train(self) -> None:
         dataset = self.__dataset_config.load()
         dataset = dataset.map(
             lambda examples: tokenizer(
-                examples["sentence"],
+                examples["text"],
                 truncation=True,
                 max_length=self.__context_length,
                 padding="max_length",

@@ -14,7 +14,7 @@
 def test_compact_dataset() -> None:
     dataset = Dataset.from_dict(
         {
-            "sentence": [
+            "text": [
                 "this is a 1st test sentence",
                 "this is a 2nd test sentence",
                 "this is a 3rd test sentence",
