Commit: Experiment T5
neet committed Aug 26, 2024
1 parent 65cbf9b commit 5d9f857
Showing 11 changed files with 155 additions and 9 deletions.
5 changes: 2 additions & 3 deletions src/ainu_lm_pipeline/components/get_mt_training_job_spec.py
@@ -15,8 +15,8 @@ def get_mt_training_job_spec(
         "image_uri": train_image_uri,
         "args": [
             "train",
-            "mt",
-            "--base-model=google/mt5-small",
+            "t5",
+            "--base-tokenizer=aynumosir/sentencepiece-ainu",
             f"--dataset-name={dataset_name}",
             f"--dataset-revision={dataset_revision}",
             "--num-train-epochs=20",
@@ -26,7 +26,6 @@ def get_mt_training_job_spec(
             "--learning-rate=5e-4",
             "--warmup-ratio=0.06",
             "--weight-decay=0.01",
-            "--experiment-task-prefix=all",
             f"--hub-model-id={hub_model_id}",
             f"--push-to-hub={'yes' if push_to_hub else 'no'}",
         ],
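With this change the training container is launched with the new `t5` subcommand and the project's SentencePiece tokenizer instead of fine-tuning `google/mt5-small`. A rough sketch of the args fragment this component now assembles (the surrounding function body and the placeholder values are assumptions; only `image_uri` and the `args` entries appear in the diff):

```python
# Hypothetical illustration only: mirrors the "args" list assembled above.
def sketch_job_spec(train_image_uri: str, dataset_name: str, dataset_revision: str, hub_model_id: str) -> dict:
    return {
        "image_uri": train_image_uri,
        "args": [
            "train",
            "t5",  # previously "mt" with "--base-model=google/mt5-small"
            "--base-tokenizer=aynumosir/sentencepiece-ainu",
            f"--dataset-name={dataset_name}",
            f"--dataset-revision={dataset_revision}",
            "--num-train-epochs=20",
            f"--hub-model-id={hub_model_id}",
            "--push-to-hub=yes",
        ],
    }


print(sketch_job_spec("example-train-image:latest", "example/dataset", "main", "aynumosir/t5-base-ainu")["args"])
```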
2 changes: 1 addition & 1 deletion src/ainu_lm_pipeline/pipelines/ainu_mt_pipeline.py
@@ -85,7 +85,7 @@ def ainu_mt_pipeline(
         train_image_uri=train_image_uri,
         dataset_name=hf_dataset_repo,
         dataset_revision=get_dataset_revision_op.output,
-        hub_model_id="aynumosir/mt5-base-ainu",
+        hub_model_id="aynumosir/t5-base-ainu",
         push_to_hub=True,
     ).set_display_name("MTジョブの仕様を取得")
11 changes: 11 additions & 0 deletions src/ainu_lm_trainer/app/commands/train.py
@@ -68,6 +68,9 @@ def add_parser(parser: ArgumentParser) -> None:
     gpt2_parser = subparsers.add_parser("gpt2", parents=common)
     gpt2_parser.add_argument("--base-tokenizer", type=str)
 
+    t5_parser = subparsers.add_parser("t5", parents=common)
+    t5_parser.add_argument("--base-tokenizer", type=str)
+
     subparsers.add_parser("byte-level-bpe", parents=common)
     subparsers.add_parser("sentencepiece", parents=common)
 
@@ -151,6 +154,14 @@ def main(args: Namespace) -> None:
             config_workspace=config_workspace,
         )
 
+    if args.task == "t5":
+        pretraining.t5.train(
+            tokenizer_name=args.base_tokenizer,
+            config_dataset=config_dataset,
+            config_training=config_training,
+            config_workspace=config_workspace,
+        )
+
     if args.task == "byte-level-bpe":
         bpe_trainer = pretraining.ByteLevelBpeTokenizerTrainer(
             config_dataset=config_dataset,
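The `t5` subcommand follows the same pattern as the existing `gpt2` one: shared flags come from a parent parser, a task-specific `--base-tokenizer` flag is added, and `main` dispatches on `args.task`. A minimal, self-contained sketch of that pattern (the `common` flag set shown here is an assumption, not the project's exact list):

```python
from argparse import ArgumentParser, Namespace


def build_parser() -> ArgumentParser:
    parser = ArgumentParser(prog="train")

    # Flags shared by every task live on a parent parser.
    common = ArgumentParser(add_help=False)
    common.add_argument("--dataset-name", type=str)
    common.add_argument("--num-train-epochs", type=int, default=20)

    subparsers = parser.add_subparsers(dest="task")
    t5_parser = subparsers.add_parser("t5", parents=[common])
    t5_parser.add_argument("--base-tokenizer", type=str)
    return parser


def main(args: Namespace) -> None:
    if args.task == "t5":
        # The real trainer calls pretraining.t5.train(...) here.
        print(f"would pretrain T5 from scratch with tokenizer {args.base_tokenizer}")


if __name__ == "__main__":
    main(build_parser().parse_args(["t5", "--base-tokenizer", "aynumosir/sentencepiece-ainu"]))
```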
3 changes: 1 addition & 2 deletions src/ainu_lm_trainer/trainers/fine_tuning/mt/mt.py
@@ -19,8 +19,7 @@
     TrainingConfig,
     WorkspaceConfig,
 )
-from ....utils import HyperparameterTuningCallback
-from . import task_prefix
+from ....utils import HyperparameterTuningCallback, task_prefix
 
 sacrebleu = evaluate.load("sacrebleu")
2 changes: 1 addition & 1 deletion src/ainu_lm_trainer/trainers/pretraining/__init__.py
@@ -1,3 +1,3 @@
-from . import gpt2, roberta
+from . import gpt2, roberta, t5
 from .byte_level_bpe import ByteLevelBpeTokenizerTrainer
 from .sentencepiece import SentencepieceTokenizerTrainer
src/ainu_lm_trainer/trainers/pretraining/sentencepiece/sentencepiece.py
@@ -26,6 +26,11 @@ def __batch_iterator(self, batch_size: int = 1_000) -> Iterator[Iterator[str]]:
             yield self.__dataset[i : i + batch_size]["text"]
             yield self.__dataset[i : i + batch_size]["translation"]
 
+            # Make sure dialect names get tokenized properly
+            for dialect in self.__dataset[i : i + batch_size]["dialect"]:
+                if dialect:
+                    yield dialect
+
     def __prepare(self) -> None:
         self.__config_workspace.model_dir.mkdir(parents=True, exist_ok=True)
 
@@ -36,6 +41,7 @@ def train(self) -> SentencePieceUnigramTokenizer:
 
         tokenizer.train_from_iterator(
             iterator=self.__batch_iterator(),
+            unk_token="<unk>",
             special_tokens=[
                 "<unk>",
                 "<pad>",
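The batch iterator now also yields bare dialect names, so the SentencePiece model sees them as standalone strings and is less likely to shred them into unknown pieces, and `unk_token` is passed explicitly at training time. A minimal sketch of the same idea with the `tokenizers` library (the toy rows, vocabulary size, and exact special-token list are assumptions):

```python
from tokenizers.implementations import SentencePieceUnigramTokenizer

# Toy rows mirroring the dataset columns used by the batch iterator above.
rows = [
    {"text": "irankarapte", "translation": "こんにちは", "dialect": "沙流"},
    {"text": "iyairaykere", "translation": "ありがとう", "dialect": None},
]


def batch_iterator():
    for row in rows:
        yield row["text"]
        yield row["translation"]
        if row["dialect"]:
            # Emit the dialect name on its own line of training text.
            yield row["dialect"]


tokenizer = SentencePieceUnigramTokenizer()
tokenizer.train_from_iterator(
    iterator=batch_iterator(),
    vocab_size=100,
    unk_token="<unk>",
    special_tokens=["<unk>", "<pad>", "<s>", "</s>"],
)
print(tokenizer.encode("irankarapte").tokens)
```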
1 change: 1 addition & 0 deletions src/ainu_lm_trainer/trainers/pretraining/t5/__init__.py
@@ -0,0 +1 @@
+from .mt import train
129 changes: 129 additions & 0 deletions src/ainu_lm_trainer/trainers/pretraining/t5/mt.py
@@ -0,0 +1,129 @@
import evaluate
import numpy as np
import torch
from datasets import DatasetDict, interleave_datasets
from transformers import (
    DataCollatorForSeq2Seq,
    EarlyStoppingCallback,
    EvalPrediction,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    T5Config,
    T5ForConditionalGeneration,
    T5TokenizerFast,
)

from ....config import DatasetConfig, TrainingConfig, WorkspaceConfig
from ....utils import task_prefix

sacrebleu = evaluate.load("sacrebleu")


def compute_metrics(tokenizer: T5TokenizerFast, eval_preds: EvalPrediction) -> dict:
    predictions, labels = eval_preds

    predictions = np.where(predictions != -100, predictions, tokenizer.pad_token_id)
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)

    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    bleu = sacrebleu.compute(predictions=decoded_preds, references=decoded_labels)
    return {"bleu": bleu["score"]}


def train(
    config_dataset: DatasetConfig,
    config_training: TrainingConfig,
    config_workspace: WorkspaceConfig,
    tokenizer_name: str,
    context_length: int = 128,
) -> None:
    tokenizer = T5TokenizerFast.from_pretrained(tokenizer_name)

    config = T5Config.from_pretrained("google-t5/t5-base")

    model = T5ForConditionalGeneration(config)
    model = model.to("cuda") if torch.cuda.is_available() else model

    dataset_dict = config_dataset.load()
    dataset_dict = dataset_dict.filter(
        lambda example: len(example["text"]) > 0 and len(example["translation"]) > 0
    )

    # --------------------------------------
    # Prepare Japanese to Ainu
    # --------------------------------------
    dataset_dict_ja2ain = dataset_dict.map(
        lambda example: {
            "text": task_prefix.ja2ain(example) + example["translation"],
            "text_target": example["text"],
        },
        remove_columns=dataset_dict.column_names["train"],
    )

    # --------------------------------------
    # Prepare Ainu to Japanese
    # --------------------------------------
    dataset_dict_ain2ja = dataset_dict.map(
        lambda example: {
            "text": task_prefix.ain2ja(example) + example["text"],
            "text_target": example["translation"],
        },
        remove_columns=dataset_dict.column_names["train"],
    )

    dataset_dict = DatasetDict(
        {
            "train": interleave_datasets(
                [dataset_dict_ain2ja["train"], dataset_dict_ja2ain["train"]],
                stopping_strategy="all_exhausted",
            ),
            "test": interleave_datasets(
                [dataset_dict_ain2ja["test"], dataset_dict_ja2ain["test"]],
                stopping_strategy="all_exhausted",
            ),
        }
    )

    dataset_dict = dataset_dict.map(
        lambda examples: tokenizer(
            examples["text"],
            text_target=examples["text_target"],
            max_length=context_length,
            truncation=True,
        ),
        batched=True,
    )

    data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)
    training_args = Seq2SeqTrainingArguments(
        save_strategy="epoch",
        evaluation_strategy="epoch",
        output_dir=str(config_workspace.checkpoint_dir),
        generation_max_length=context_length,
        predict_with_generate=True,
        load_best_model_at_end=True,
        metric_for_best_model="bleu",
        greater_is_better=True,
        logging_dir=str(config_workspace.logging_dir),
        report_to=["tensorboard"],
    )
    training_args = config_training.extend(training_args)

    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        data_collator=data_collator,
        train_dataset=dataset_dict["train"],
        eval_dataset=dataset_dict["test"],
        callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
        compute_metrics=lambda eval_preds: compute_metrics(tokenizer, eval_preds),
    )

    trainer.train()
    trainer.save_model(str(config_workspace.model_dir))

    if config_training.push_to_hub:
        model.push_to_hub(training_args.hub_model_id)
        tokenizer.push_to_hub(training_args.hub_model_id)
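The data preparation turns every parallel sentence pair into two supervised examples, one per translation direction, each prefixed via `task_prefix.ja2ain` or `task_prefix.ain2ja`, and then interleaves the two streams. A toy sketch of that step with `datasets` (the literal prefix strings below stand in for whatever the task_prefix helpers actually return and are assumptions):

```python
from datasets import Dataset, interleave_datasets

raw = Dataset.from_list(
    [
        {"text": "irankarapte", "translation": "こんにちは", "dialect": "沙流"},
        {"text": "iyairaykere", "translation": "ありがとう", "dialect": None},
    ]
)

# Stand-ins for task_prefix.ja2ain / task_prefix.ain2ja; the real prefixes may
# also encode the dialect, so treat these strings as placeholders.
ja2ain = raw.map(
    lambda ex: {"text": "translate Japanese to Ainu: " + ex["translation"], "text_target": ex["text"]},
    remove_columns=raw.column_names,
)
ain2ja = raw.map(
    lambda ex: {"text": "translate Ainu to Japanese: " + ex["text"], "text_target": ex["translation"]},
    remove_columns=raw.column_names,
)

# Alternate between the two directions; "all_exhausted" keeps going until both
# datasets have been fully consumed, so neither direction is under-represented.
both = interleave_datasets([ain2ja, ja2ain], stopping_strategy="all_exhausted")
print(both["text"])
```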
1 change: 1 addition & 0 deletions src/ainu_lm_trainer/utils/__init__.py
@@ -1,2 +1,3 @@
+from . import task_prefix
 from .get_path_from_uri import get_path_from_uri, get_path_str_from_uri
 from .hyperparameter_tuning_callback import HyperparameterTuningCallback
(file header not shown in this view)
@@ -1,4 +1,4 @@
-from ....config import TaskPrefixType
+from ..config import TaskPrefixType
 
 
 def __make_ainu_language_identifier(
(file header not shown in this view)
@@ -1,4 +1,4 @@
-from ....config import TaskPrefixType
+from ..config import TaskPrefixType
 from .task_prefix import ain2ja
 
 sentence = {
