Skip to content

Commit

Permalink
wip
Browse files (browse the repository at this point in the history)
  • Loading branch information
neet committed Jul 3, 2024
1 parent fffd571 commit 40a65c6
Showing 1 changed file with 15 additions and 10 deletions.
25 changes: 15 additions & 10 deletions src/ainu_lm_trainer/services/mt5/mt5_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,18 +53,23 @@ def train(self) -> None:
lambda example: len(example["text"]) > 0 and len(example["translation"]) > 0
)

# Set task prefix
dataset = dataset.map(

Check warning on line 57 in src/ainu_lm_trainer/services/mt5/mt5_trainer.py

View check run for this annotation

Codecov / codecov/patch

src/ainu_lm_trainer/services/mt5/mt5_trainer.py#L57

Added line #L57 was not covered by tests
lambda example: {
"text": self.__get_task_prefix(example) + example["translation"],
"target": example["text"],
},
remove_columns=dataset.column_names,
)

# https://huggingface.co/docs/transformers/en/tasks/summarization#preprocess
def preprocess(examples: dict) -> dict:
inputs = tokenizer(
[
self.__get_task_prefix(example) + example["translation"]
for example in examples
],
text_target=examples["text"],
return tokenizer(

Check warning on line 67 in src/ainu_lm_trainer/services/mt5/mt5_trainer.py

View check run for this annotation

Codecov / codecov/patch

src/ainu_lm_trainer/services/mt5/mt5_trainer.py#L67

Added line #L67 was not covered by tests
examples["text"],
text_target=examples["target"],
max_length=self.__context_length,
truncation=True,
)
return inputs

dataset = dataset.map(
preprocess,
Expand Down Expand Up @@ -114,10 +119,10 @@ def preprocess(examples: dict) -> dict:
tokenizer.push_to_hub(self.__model_name)

def __get_task_prefix(self, example: dict) -> str:
if example["dialect"] is not None:
return f"translate: Japanese to Ainu ({example['dialect']}, {example['pronoun']}): "
if example["dialect"] is None:
return f"translate: Japanese to Ainu: (沙流, {example['pronoun']}): "

Check warning on line 123 in src/ainu_lm_trainer/services/mt5/mt5_trainer.py

View check run for this annotation

Codecov / codecov/patch

src/ainu_lm_trainer/services/mt5/mt5_trainer.py#L122-L123

Added lines #L122 - L123 were not covered by tests
else:
return f"translate: Japanese to Ainu (沙流, {example['pronoun']}): "
return f"translate: Japanese to Ainu: ({example['dialect']}, {example['pronoun']}): "

Check warning on line 125 in src/ainu_lm_trainer/services/mt5/mt5_trainer.py

View check run for this annotation

Codecov / codecov/patch

src/ainu_lm_trainer/services/mt5/mt5_trainer.py#L125

Added line #L125 was not covered by tests

def __compute_metrics(
self, tokenizer: MT5Tokenizer, eval_preds: EvalPrediction
Expand Down

0 comments on commit 40a65c6

Please sign in to comment.