From 76a69cca9a7f99f4bbde0545b89da95909129cf3 Mon Sep 17 00:00:00 2001 From: Ryo Igarashi Date: Mon, 25 Mar 2024 01:04:42 +0900 Subject: [PATCH] Use tensorboard on demand --- cspell.json | 6 +- poetry.lock | 88 +++++++++++++++++-- pyproject.toml | 3 +- .../components/get_worker_pool_specs.py | 4 + src/ainu_lm_pipeline/pipeline.py | 3 + src/ainu_lm_pipeline/submit.py | 10 +-- src/ainu_lm_trainer/app/argument_parser.py | 14 +-- .../app/argument_parser_test.py | 2 - src/ainu_lm_trainer/app/main.py | 9 +- .../app/task_language_model.py | 28 ++++-- src/ainu_lm_trainer/trainers/__init__.py | 1 + .../trainers/roberta_trainer.py | 45 +++++++--- .../trainers/roberta_trainer_callback.py | 36 ++++++++ .../trainers/roberta_trainer_config.py | 22 +++++ .../trainers/roberta_trainer_config_test.py | 29 ++++++ .../trainers/roberta_trainer_test.py | 10 ++- test.py | 14 --- 17 files changed, 266 insertions(+), 58 deletions(-) create mode 100644 src/ainu_lm_trainer/trainers/roberta_trainer_callback.py create mode 100644 src/ainu_lm_trainer/trainers/roberta_trainer_config.py create mode 100644 src/ainu_lm_trainer/trainers/roberta_trainer_config_test.py delete mode 100644 test.py diff --git a/cspell.json b/cspell.json index c81b34e..96c8b65 100644 --- a/cspell.json +++ b/cspell.json @@ -13,10 +13,14 @@ "gapic", "googleapis", "huggingface", + "hypertune", + "hypertuner", "idxmin", "kaniko", + "logdir", "neetlab", - "protobuf" + "protobuf", + "tensorboard" ], "ignoreWords": [], "import": [] diff --git a/poetry.lock b/poetry.lock index 9fa760f..4adc72a 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,5 +1,16 @@ # This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. +[[package]] +name = "absl-py" +version = "2.1.0" +description = "Abseil Python Common Libraries, see https://github.com/abseil/abseil-py." 
+optional = false +python-versions = ">=3.7" +files = [ + {file = "absl-py-2.1.0.tar.gz", hash = "sha256:7820790efbb316739cde8b4e19357243fc3608a152024288513dd968d7d959ff"}, + {file = "absl_py-2.1.0-py3-none-any.whl", hash = "sha256:526a04eadab8b4ee719ce68f204172ead1027549089702d99b9059f129ff1308"}, +] + [[package]] name = "accelerate" version = "0.28.0" @@ -632,12 +643,12 @@ files = [ google-auth = ">=2.14.1,<3.0.dev0" googleapis-common-protos = ">=1.56.2,<2.0.dev0" grpcio = [ - {version = ">=1.33.2,<2.0dev", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""}, {version = ">=1.49.1,<2.0dev", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""}, + {version = ">=1.33.2,<2.0dev", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""}, ] grpcio-status = [ - {version = ">=1.33.2,<2.0.dev0", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""}, {version = ">=1.49.1,<2.0.dev0", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""}, + {version = ">=1.33.2,<2.0.dev0", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""}, ] proto-plus = ">=1.22.3,<2.0.0dev" protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<5.0.0.dev0" @@ -1129,7 +1140,6 @@ protobuf = ">=4.21.1,<5" PyYAML = ">=5.3,<7" requests-toolbelt = ">=0.8.0,<1" tabulate = ">=0.8.6,<1" -typing-extensions = {version = ">=3.7.4,<5", markers = "python_version < \"3.9\""} urllib3 = "<2.0.0" [package.extras] @@ -1191,6 +1201,21 @@ websocket-client = ">=0.32.0,<0.40.0 || >0.40.0,<0.41.dev0 || >=0.43.dev0" [package.extras] adal = ["adal (>=1.0.2)"] +[[package]] +name = "markdown" +version = "3.6" +description = "Python implementation of John Gruber's Markdown." 
+optional = false +python-versions = ">=3.8" +files = [ + {file = "Markdown-3.6-py3-none-any.whl", hash = "sha256:48f276f4d8cfb8ce6527c8f79e2ee29708508bf4d40aa410fbc3b4ee832c850f"}, + {file = "Markdown-3.6.tar.gz", hash = "sha256:ed4f41f6daecbeeb96e576ce414c41d2d876daa9a16cb35fa8ed8c2ddfad0224"}, +] + +[package.extras] +docs = ["mdx-gh-links (>=0.2)", "mkdocs (>=1.5)", "mkdocs-gen-files", "mkdocs-literate-nav", "mkdocs-nature (>=0.6)", "mkdocs-section-index", "mkdocstrings[python]"] +testing = ["coverage", "pyyaml"] + [[package]] name = "markupsafe" version = "2.1.5" @@ -1718,9 +1743,8 @@ files = [ [package.dependencies] numpy = [ - {version = ">=1.20.3", markers = "python_version < \"3.10\""}, - {version = ">=1.21.0", markers = "python_version >= \"3.10\" and python_version < \"3.11\""}, {version = ">=1.23.2", markers = "python_version >= \"3.11\""}, + {version = ">=1.21.0", markers = "python_version >= \"3.10\" and python_version < \"3.11\""}, ] python-dateutil = ">=2.8.2" pytz = ">=2020.1" @@ -2538,6 +2562,39 @@ files = [ [package.extras] widechars = ["wcwidth"] +[[package]] +name = "tensorboard" +version = "2.16.2" +description = "TensorBoard lets you watch Tensors Flow" +optional = false +python-versions = ">=3.9" +files = [ + {file = "tensorboard-2.16.2-py3-none-any.whl", hash = "sha256:9f2b4e7dad86667615c0e5cd072f1ea8403fc032a299f0072d6f74855775cc45"}, +] + +[package.dependencies] +absl-py = ">=0.4" +grpcio = ">=1.48.2" +markdown = ">=2.6.8" +numpy = ">=1.12.0" +protobuf = ">=3.19.6,<4.24.0 || >4.24.0" +setuptools = ">=41.0.0" +six = ">1.9" +tensorboard-data-server = ">=0.7.0,<0.8.0" +werkzeug = ">=1.0.1" + +[[package]] +name = "tensorboard-data-server" +version = "0.7.2" +description = "Fast data loading for TensorBoard" +optional = false +python-versions = ">=3.7" +files = [ + {file = "tensorboard_data_server-0.7.2-py3-none-any.whl", hash = "sha256:7e0610d205889588983836ec05dc098e80f97b7e7bbff7e994ebb78f578d0ddb"}, + {file = "tensorboard_data_server-0.7.2-py3-none-macosx_10_9_x86_64.whl", hash = "sha256:9fe5d24221b29625dbc7328b0436ca7fc1c23de4acf4d272f1180856e32f9f60"}, + {file = "tensorboard_data_server-0.7.2-py3-none-manylinux_2_31_x86_64.whl", hash = "sha256:ef687163c24185ae9754ed5650eb5bc4d84ff257aabdc33f0cc6f74d8ba54530"}, +] + [[package]] name = "tokenizers" version = "0.15.2" @@ -2899,6 +2956,23 @@ docs = ["Sphinx (>=6.0)", "sphinx-rtd-theme (>=1.1.0)"] optional = ["python-socks", "wsaccel"] test = ["websockets"] +[[package]] +name = "werkzeug" +version = "3.0.1" +description = "The comprehensive WSGI web application library." 
+optional = false +python-versions = ">=3.8" +files = [ + {file = "werkzeug-3.0.1-py3-none-any.whl", hash = "sha256:90a285dc0e42ad56b34e696398b8122ee4c681833fb35b8334a095d82c56da10"}, + {file = "werkzeug-3.0.1.tar.gz", hash = "sha256:507e811ecea72b18a404947aded4b3390e1db8f826b494d76550ef45bb3b1dcc"}, +] + +[package.dependencies] +MarkupSafe = ">=2.1.1" + +[package.extras] +watchdog = ["watchdog (>=2.3)"] + [[package]] name = "xxhash" version = "3.4.1" @@ -3121,5 +3195,5 @@ multidict = ">=4.0" [metadata] lock-version = "2.0" -python-versions = ">=3.8,<3.12.0" -content-hash = "84e74dbd83bfacef188bd3c9885fe59b8438cb1f3fea90e3fcf0c6531c4eec22" +python-versions = ">=3.10,<3.12.0" +content-hash = "fa7e29eb696b167ec8ab607a7e01588eed07ea0777bdb7646f2a6dca1fda93bf" diff --git a/pyproject.toml b/pyproject.toml index e099cef..3829e18 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,7 @@ readme = "README.md" package-mode = false [tool.poetry.dependencies] -python = ">=3.8,<3.12.0" +python = ">=3.10,<3.12.0" datasets = "^2.18.0" transformers = "^4.38.2" sentencepiece = "^0.2.0" @@ -18,6 +18,7 @@ cloudml-hypertune = "^0.1.0.dev6" accelerate = "^0.28.0" google-cloud-aiplatform = "^1.44.0" google-cloud-pipeline-components = "^2.11.0" +tensorboard = "^2.16.2" [tool.poetry.group.dev.dependencies] diff --git a/src/ainu_lm_pipeline/components/get_worker_pool_specs.py b/src/ainu_lm_pipeline/components/get_worker_pool_specs.py index 46f556b..c49a6b9 100644 --- a/src/ainu_lm_pipeline/components/get_worker_pool_specs.py +++ b/src/ainu_lm_pipeline/components/get_worker_pool_specs.py @@ -8,6 +8,8 @@ def get_worker_pool_specs( train_image_uri: str, tokenizer_gcs_path: str, + tensorboard_id: str, + tensorboard_experiment_name: str, ) -> list: worker_pool_specs = [ { @@ -17,6 +19,8 @@ def get_worker_pool_specs( "language-model", "--hp-tune=True", f"--tokenizer={tokenizer_gcs_path}", + f"--tensorboard-id={tensorboard_id}", + f"--tensorboard-experiment-name={tensorboard_experiment_name}", ], }, "machine_spec": { diff --git a/src/ainu_lm_pipeline/pipeline.py b/src/ainu_lm_pipeline/pipeline.py index a9f2a46..48ce9cb 100644 --- a/src/ainu_lm_pipeline/pipeline.py +++ b/src/ainu_lm_pipeline/pipeline.py @@ -23,6 +23,7 @@ def ainu_lm_pipeline( pipeline_root: str, source_repo_name: str, source_commit_sha: str, + tensorboard_id: str, hf_repo: str, hf_token: str, ) -> None: @@ -85,6 +86,8 @@ def ainu_lm_pipeline( worker_pool_specs_task = ( get_worker_pool_specs( train_image_uri=cfg.TRAIN_IMAGE_URI, + tensorboard_id=tensorboard_id, + tensorboard_experiment_name=pipeline_job_id, tokenizer_gcs_path=tokenizer_training_job_details_task.outputs[ "model_artifacts" ], diff --git a/src/ainu_lm_pipeline/submit.py b/src/ainu_lm_pipeline/submit.py index f8ce0db..5cc45fd 100644 --- a/src/ainu_lm_pipeline/submit.py +++ b/src/ainu_lm_pipeline/submit.py @@ -1,15 +1,10 @@ import argparse import os -from datetime import datetime -import config as cfg from google.cloud import aiplatform from google.cloud.aiplatform.pipeline_jobs import PipelineJob - -def get_timestamp() -> str: - return datetime.now().strftime("%Y%m%d%H%M%S") - +from . 
import config as cfg
 
 parser = argparse.ArgumentParser()
 parser.add_argument("--commit-sha", type=str, required=True)
@@ -19,13 +14,14 @@ def get_timestamp() -> str:
     aiplatform.init(project=cfg.PROJECT_ID, location=cfg.REGION)
 
     args = parser.parse_args()
-    job_id = f"pipeline-{cfg.APP_NAME}-{get_timestamp()}"
+    job_id = f"pipeline-ainu-lm-{args.commit_sha}"
 
     pipeline_params = {
         "pipeline_job_id": job_id,
         "pipeline_root": cfg.PIPELINE_ROOT,
         "source_repo_name": "github_aynumosir_ainu-lm",
         "source_commit_sha": args.commit_sha,
+        "tensorboard_id": os.environ.get("TENSORBOARD_ID"),
         "hf_repo": "aynumosir/roberta-ainu-base",
         "hf_token": os.environ.get("HF_TOKEN"),
     }
diff --git a/src/ainu_lm_trainer/app/argument_parser.py b/src/ainu_lm_trainer/app/argument_parser.py
index 8802e5b..bb81e6a 100644
--- a/src/ainu_lm_trainer/app/argument_parser.py
+++ b/src/ainu_lm_trainer/app/argument_parser.py
@@ -30,12 +30,6 @@ def get_argument_parser() -> argparse.ArgumentParser:
         default=False,
         help="Whether to use hyperparameter tuning",
     )
-    language_model_parser.add_argument(
-        "--model-name",
-        type=str,
-        help="Model name to train (e.g. roberta-base-ainu)",
-        default=os.environ.get("MODEL_NAME"),
-    )
     language_model_parser.add_argument(
         "--num-train-epochs", type=int, help="Number of training epochs", default=10
     )
@@ -51,6 +45,14 @@
         help="Job directory. Use gs:/ to save to Google Cloud Storage",
         default=os.environ.get("AIP_MODEL_DIR"),
     )
+    language_model_parser.add_argument(
+        "--tensorboard-id",
+        help="Tensorboard ID",
+    )
+    language_model_parser.add_argument(
+        "--tensorboard-experiment-name",
+        help="Tensorboard experiment name",
+    )
 
     """
     Subparser for the cache
diff --git a/src/ainu_lm_trainer/app/argument_parser_test.py b/src/ainu_lm_trainer/app/argument_parser_test.py
index cda014c..a21d08c 100644
--- a/src/ainu_lm_trainer/app/argument_parser_test.py
+++ b/src/ainu_lm_trainer/app/argument_parser_test.py
@@ -14,7 +14,6 @@ def test_parsing_language_model_training() -> None:
         [
             "language_model",
             "--hp-tune=True",
-            "--model-name=test-model",
             "--num-train-epochs=20",
             "--tokenizer-dir=gs://test/tokenizer",
             "--job-dir=gs://test/job_dir",
@@ -22,7 +21,6 @@ def test_parsing_language_model_training() -> None:
     )
     assert args.task == "language_model"
     assert args.hp_tune == "True"
-    assert args.model_name == "test-model"
     assert args.num_train_epochs == 20
 
     assert args.tokenizer_dir.bucket.name == "test"
diff --git a/src/ainu_lm_trainer/app/main.py b/src/ainu_lm_trainer/app/main.py
index 779327f..5215251 100644
--- a/src/ainu_lm_trainer/app/main.py
+++ b/src/ainu_lm_trainer/app/main.py
@@ -11,7 +11,14 @@
         tokenizer(job_dir=args.job_dir)
 
     if args.task == "language_model":
-        language_model(job_dir=args.job_dir, tokenizer_blob=args.tokenizer_blob)
+        language_model(
+            job_dir=args.job_dir,
+            tokenizer_blob=args.tokenizer_dir,
+            num_train_epochs=args.num_train_epochs,
+            hypertune_enabled=(args.hp_tune == "True"),
+            tensorboard_id=args.tensorboard_id,
+            tensorboard_experiment_name=args.tensorboard_experiment_name,
+        )
 
     if args.task == "cache":
         cache()
diff --git a/src/ainu_lm_trainer/app/task_language_model.py b/src/ainu_lm_trainer/app/task_language_model.py
index 8aae00c..0d53674 100644
--- a/src/ainu_lm_trainer/app/task_language_model.py
+++ b/src/ainu_lm_trainer/app/task_language_model.py
@@ -1,15 +1,25 @@
 import os
 from pathlib import Path
+from typing import Optional
 
 from datasets import load_dataset
-from google.cloud import storage
+from google.cloud import 
aiplatform, storage from google.cloud.storage import Blob from ..models import JobDir -from ..trainers import RobertaTrainer +from ..trainers import RobertaTrainer, RobertaTrainerConfig -def language_model(job_dir: JobDir, tokenizer_blob: Blob) -> None: +def language_model( + job_dir: JobDir, + tokenizer_blob: Blob, + num_train_epochs: int, + hypertune_enabled: Optional[bool] = None, + tensorboard_id: Optional[str] = None, + tensorboard_experiment_name: Optional[str] = None, +) -> None: + aiplatform.init() + client = storage.Client() dataset = load_dataset("aynumosir/ainu-corpora", split="data") dataset = dataset.map(lambda example: {"text": example["sentence"]}) @@ -26,9 +36,17 @@ def language_model(job_dir: JobDir, tokenizer_blob: Blob) -> None: # Create output directory output_dir = Path("/tmp/ainu-lm-trainer/lm") output_dir.mkdir(parents=True, exist_ok=True) - trainer = RobertaTrainer( - dataset, tokenizer_name_or_dir=tokenizer_dir, output_dir=output_dir + + config = RobertaTrainerConfig( + num_train_epochs=num_train_epochs, + tokenizer_name_or_dir=tokenizer_dir, + output_dir=output_dir, + hypertune_enabled=hypertune_enabled, + tensorboard_id=tensorboard_id, + tensorboard_experiment_name=tensorboard_experiment_name, ) + + trainer = RobertaTrainer(dataset, config=config) trainer.train() paths = [ diff --git a/src/ainu_lm_trainer/trainers/__init__.py b/src/ainu_lm_trainer/trainers/__init__.py index 77d7344..4d3916b 100644 --- a/src/ainu_lm_trainer/trainers/__init__.py +++ b/src/ainu_lm_trainer/trainers/__init__.py @@ -1,2 +1,3 @@ from .byte_level_bpe_tokenizer_trainer import ByteLevelBPETokenizerTrainer from .roberta_trainer import RobertaTrainer +from .roberta_trainer_config import RobertaTrainerConfig diff --git a/src/ainu_lm_trainer/trainers/roberta_trainer.py b/src/ainu_lm_trainer/trainers/roberta_trainer.py index 212fef7..1e49fe5 100644 --- a/src/ainu_lm_trainer/trainers/roberta_trainer.py +++ b/src/ainu_lm_trainer/trainers/roberta_trainer.py @@ -2,6 +2,7 @@ import torch from datasets import Dataset +from google.cloud import aiplatform from transformers import ( DataCollatorForLanguageModeling, RobertaConfig, @@ -11,24 +12,29 @@ TrainingArguments, ) +from .roberta_trainer_callback import HPTuneCallback +from .roberta_trainer_config import RobertaTrainerConfig + class RobertaTrainer: dataset: Dataset - output_dir: Path - tokenizer_name_or_dir: Path | str + config: RobertaTrainerConfig + logging_dir: Path def __init__( - self, dataset: Dataset, tokenizer_name_or_dir: Path | str, output_dir: Path + self, + dataset: Dataset, + config: RobertaTrainerConfig, ) -> None: - self.output_dir = output_dir - self.tokenizer_name_or_dir = tokenizer_name_or_dir - if "text" not in dataset.column_names: raise ValueError('The dataset must have a column named "text"') else: self.dataset = dataset - def train(self, num_train_epochs: int = 10) -> None: + self.config = config + self.logging_dir = Path("./logs") + + def train(self) -> None: # FacebookAI/roberta-base よりも hidden_layers が少し小さい。エスペラントの記事を参考にした。 # https://colab.research.google.com/github/huggingface/blog/blob/main/notebooks/01_how_to_train.ipynb config = RobertaConfig( @@ -40,7 +46,7 @@ def train(self, num_train_epochs: int = 10) -> None: ) tokenizer = RobertaTokenizerFast.from_pretrained( - str(self.tokenizer_name_or_dir), + str(self.config.tokenizer_name_or_dir), max_length=512, padding="max_length", truncation=True, @@ -56,12 +62,14 @@ def train(self, num_train_epochs: int = 10) -> None: model = model.to("cuda") if torch.cuda.is_available() 
else model
 
         training_args = TrainingArguments(
-            output_dir=str(self.output_dir),
+            output_dir=str(self.config.output_dir),
             overwrite_output_dir=True,
-            num_train_epochs=num_train_epochs,
+            num_train_epochs=self.config.num_train_epochs,
             per_device_train_batch_size=64,
             save_steps=10_000,
             save_total_limit=2,
+            logging_dir=self.logging_dir,
+            report_to=["tensorboard"] if self.config.tensorboard_enabled else [],
         )
 
         data_collator = DataCollatorForLanguageModeling(
@@ -82,5 +90,20 @@ def train(self, num_train_epochs: int = 10) -> None:
             train_dataset=train_dataset,
         )
 
+        if self.config.hypertune_enabled:
+            trainer.add_callback(HPTuneCallback("loss", "eval_loss"))
+
+        if self.config.tensorboard_enabled:
+            aiplatform.start_upload_tb_log(
+                tensorboard_id=self.config.tensorboard_id,
+                tensorboard_experiment_name=self.config.tensorboard_experiment_name,
+                logdir=str(self.logging_dir),
+                run_name_prefix="roberta-base-ainu",
+            )
+
         trainer.train()
-        trainer.save_model(self.output_dir)
+
+        if self.config.tensorboard_enabled:
+            aiplatform.end_upload_tb_log()
+
+        trainer.save_model(self.config.output_dir)
diff --git a/src/ainu_lm_trainer/trainers/roberta_trainer_callback.py b/src/ainu_lm_trainer/trainers/roberta_trainer_callback.py
new file mode 100644
index 0000000..95c199a
--- /dev/null
+++ b/src/ainu_lm_trainer/trainers/roberta_trainer_callback.py
@@ -0,0 +1,36 @@
+from typing import Dict
+
+import hypertune
+from transformers import TrainerCallback
+from transformers.trainer_callback import (
+    TrainerControl,
+    TrainerState,
+    TrainingArguments,
+)
+
+
+class HPTuneCallback(TrainerCallback):
+    """
+    A custom callback class that reports a metric to the hypertuner
+    at the end of each epoch.
+    """
+
+    def __init__(self, metric_tag: str, metric_value: str) -> None:
+        super(HPTuneCallback, self).__init__()
+        self.metric_tag = metric_tag
+        self.metric_value = metric_value
+        self.hpt = hypertune.HyperTune()
+
+    def on_evaluate(
+        self,
+        args: TrainingArguments,
+        state: TrainerState,
+        control: TrainerControl,
+        **kwargs: Dict,
+    ) -> None:
+        print(f"HP metric {self.metric_tag}={kwargs['metrics'][self.metric_value]}")
+        self.hpt.report_hyperparameter_tuning_metric(
+            hyperparameter_metric_tag=self.metric_tag,
+            metric_value=kwargs["metrics"][self.metric_value],
+            global_step=state.epoch,
+        )
diff --git a/src/ainu_lm_trainer/trainers/roberta_trainer_config.py b/src/ainu_lm_trainer/trainers/roberta_trainer_config.py
new file mode 100644
index 0000000..e88a3b4
--- /dev/null
+++ b/src/ainu_lm_trainer/trainers/roberta_trainer_config.py
@@ -0,0 +1,22 @@
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Optional
+
+
+@dataclass
+class RobertaTrainerConfig:
+    num_train_epochs: int
+    tokenizer_name_or_dir: Path | str
+    output_dir: Path
+
+    hypertune_enabled: Optional[bool] = False
+
+    tensorboard_id: Optional[str] = None
+    tensorboard_experiment_name: Optional[str] = None
+
+    @property
+    def tensorboard_enabled(self) -> bool:
+        return (
+            self.tensorboard_id is not None
+            and self.tensorboard_experiment_name is not None
+        )
diff --git a/src/ainu_lm_trainer/trainers/roberta_trainer_config_test.py b/src/ainu_lm_trainer/trainers/roberta_trainer_config_test.py
new file mode 100644
index 0000000..8882193
--- /dev/null
+++ b/src/ainu_lm_trainer/trainers/roberta_trainer_config_test.py
@@ -0,0 +1,29 @@
+from pathlib import Path
+
+from .roberta_trainer_config import RobertaTrainerConfig
+
+
+def test_roberta_trainer_config() -> None:
+    config = RobertaTrainerConfig(
+        num_train_epochs=1,
+        tokenizer_name_or_dir="tokenizer",
+        output_dir=Path("output"),
+    )
+
+    assert config.num_train_epochs == 1
+    assert config.tokenizer_name_or_dir == "tokenizer"
+    assert str(config.output_dir) == "output"
+    assert config.hypertune_enabled is False
+    assert config.tensorboard_enabled is False
+
+
+def test_tensorboard_enabled() -> None:
+    config = RobertaTrainerConfig(
+        num_train_epochs=1,
+        tokenizer_name_or_dir="tokenizer",
+        output_dir=Path("output"),
+        tensorboard_id="id",
+        tensorboard_experiment_name="name",
+    )
+
+    assert config.tensorboard_enabled is True
diff --git a/src/ainu_lm_trainer/trainers/roberta_trainer_test.py b/src/ainu_lm_trainer/trainers/roberta_trainer_test.py
index 3937107..fd62b47 100644
--- a/src/ainu_lm_trainer/trainers/roberta_trainer_test.py
+++ b/src/ainu_lm_trainer/trainers/roberta_trainer_test.py
@@ -3,6 +3,7 @@
 from datasets import Dataset
 
 from .roberta_trainer import RobertaTrainer
+from .roberta_trainer_config import RobertaTrainerConfig
 
 
 def test_compact_dataset() -> None:
@@ -21,10 +22,13 @@ def test_compact_dataset() -> None:
 
     trainer = RobertaTrainer(
         dataset=dataset,
-        tokenizer_name_or_dir="roberta-base",
-        output_dir=output_dir,
+        config=RobertaTrainerConfig(
+            num_train_epochs=1,
+            tokenizer_name_or_dir="roberta-base",
+            output_dir=output_dir,
+        ),
     )
-    trainer.train(num_train_epochs=1)
+    trainer.train()
 
     assert (output_dir / "config.json").exists()
 
diff --git a/test.py b/test.py
deleted file mode 100644
index c4a06ec..0000000
--- a/test.py
+++ /dev/null
@@ -1,14 +0,0 @@
-from datasets import Dataset
-
-ds = Dataset.from_list(
-    [
-        {"text": "This is a test."},
-        {"text": "This is another test."},
-        {"text": "This is yet another test."},
-    ]
-)
-
-# convert "text" to "sentence"
-ds2 = ds.map(lambda example: {"sentence": example["text"]})
-
-print(type(ds))