From b5915a84f611e5a4ff80fcbc8e4f69e7c649b0e8 Mon Sep 17 00:00:00 2001 From: Ryo Igarashi Date: Thu, 4 Jul 2024 11:49:05 +0900 Subject: [PATCH] =?UTF-8?q?hf-dataset-repo=E3=82=92job=5Fspec=E3=81=AE?= =?UTF-8?q?=E5=AE=9A=E7=BE=A9=E3=81=AB=E3=82=82=E4=BD=BF=E3=81=86=E3=82=88?= =?UTF-8?q?=E3=81=86=E3=81=AB=E3=81=97=E3=81=9F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/ainu_lm_pipeline/components/get_mt5_training_job_spec.py | 3 ++- src/ainu_lm_pipeline/pipelines/ainu_mt5_pipeline.py | 1 + 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/ainu_lm_pipeline/components/get_mt5_training_job_spec.py b/src/ainu_lm_pipeline/components/get_mt5_training_job_spec.py index 2c9f969..fef8838 100644 --- a/src/ainu_lm_pipeline/components/get_mt5_training_job_spec.py +++ b/src/ainu_lm_pipeline/components/get_mt5_training_job_spec.py @@ -5,6 +5,7 @@ def get_mt5_training_job_spec( train_image_uri: str, push_to_hub: bool, + dataset_name: str, dataset_revision: str, ) -> list: worker_pool_specs = [ @@ -13,7 +14,7 @@ def get_mt5_training_job_spec( "image_uri": train_image_uri, "args": [ "mt5", - "--dataset-name=aynumosir/ainu-corpora-normalized", + f"--dataset-name={dataset_name}", "--dataset-split=train", f"--dataset-revision={dataset_revision}", f"--push-to-hub={push_to_hub}", diff --git a/src/ainu_lm_pipeline/pipelines/ainu_mt5_pipeline.py b/src/ainu_lm_pipeline/pipelines/ainu_mt5_pipeline.py index 2fcca5f..9280314 100644 --- a/src/ainu_lm_pipeline/pipelines/ainu_mt5_pipeline.py +++ b/src/ainu_lm_pipeline/pipelines/ainu_mt5_pipeline.py @@ -97,6 +97,7 @@ def ainu_mt5_pipeline( get_mt5_training_job_spec( train_image_uri=train_image_uri, push_to_hub=push_to_hub, + dataset_name=hf_dataset_repo, dataset_revision=get_dataset_revision_op.output, ) .after(build_custom_train_image_op)