diff --git a/.gitignore b/.gitignore
index b2bad75b..b090496d 100644
--- a/.gitignore
+++ b/.gitignore
@@ -150,6 +150,7 @@ dmypy.json
 # and standard hidden files ignore. Including
 # example files generated via notebook tutorials
 .*
+scratch/
 model/
 dataset/
 datapath/
diff --git a/RWKV-v5/config-example.yaml b/RWKV-v5/config-example.yaml
index 28577f14..697b71d5 100644
--- a/RWKV-v5/config-example.yaml
+++ b/RWKV-v5/config-example.yaml
@@ -341,6 +341,23 @@ data:
   # If using relative path, this should be relative to the trainer script path
   data_path: /path/to/store/your/data_path/
 
+  # Data path storage options, this is used to support cloud storage
+  # via the huggingface dataset API. See:
+  # https://huggingface.co/docs/datasets/v2.16.1/en/filesystems#amazon-s3
+  #
+  # Note: As of Jan 2023, these options have only been tested with AWS S3 and Backblaze. YMMV
+  # For S3 bucket support you will also need to install s3fs: `python3 -m pip install s3fs`
+  #
+  # If you want to reduce the risk of accidental key/secret commits, you can use the
+  # `AWS_ACCESS_KEY_ID` and `AWS_SECRET_ACCESS_KEY` environment variables instead
+  #
+  # For data_path, it should use the `s3://bucket-name/subpath` format
+  # ---
+  # data_path_storage_options:
+  #   key:
+  #   secret:
+  #   endpoint_url:
+
   # Other wise provide the source path, which is used as huggingface dataset path
   # this will be used to populate the dataset_path
   #
@@ -426,7 +443,6 @@ data:
   #     multi_column_train_mask: [true, false, true]
   #     multi_column_separator: "\n\n"
 
-
   # Conversation merging process
   # useful for merging full conversational datasets, into single documents
   # default is off, (or set conversation_key to [])
diff --git a/RWKV-v5/src/data.py b/RWKV-v5/src/data.py
index 7c75d65c..03e8b389 100644
--- a/RWKV-v5/src/data.py
+++ b/RWKV-v5/src/data.py
@@ -31,7 +31,7 @@ def prepare_data_static(**kargs):
     # Check if skip_datapath_setup is enabled
     # useful for extra large datasets
     if kargs["skip_datapath_setup"] == True:
-        return
+        return None
 
     # Special handling of world_add_endoftext_token (if enabled)
     if kargs["world_add_endoftext_token"]:
@@ -804,8 +804,49 @@ def reverse_dataset(x, idx):
             return train_dataset[train_dataset.num_rows - idx - 1]
         src_dataset["train"] = src_dataset["train"].map(reverse_dataset, with_indices=True, num_proc=num_cpus)
 
-    # Save the dataset to disk
-    src_dataset.save_to_disk(kargs["data_path"])
+    # # Convert to iterable datasets (does not support saving to disk???)
+    # src_dataset["train"] = src_dataset["train"].to_iterable_dataset()
+    # src_dataset["test"] = src_dataset["test"].to_iterable_dataset()
+
+    # Save the dataset to disk (if enabled)
+    # For the skip datapath saving string
+    # We intentionally used several filesystem illegal characters, to ensure it
+    # is not accidentally used by the user for a real file
+    if kargs["data_path"] != ".//<#|=@%!$skip_datapath$!%@=|#>//.":
+        if kargs["data_path_storage_options"]:
+
+            # import s3fs
+            # fs = s3fs.S3FileSystem(
+            #     key=kargs["data_path_storage_options"]["key"],
+            #     secret=kargs["data_path_storage_options"]["secret"],
+            #     endpoint_url=kargs["data_path_storage_options"]["endpoint_url"],
+            #     client_kwargs={
+            #         'region_name': 'sfo3'
+            #     },
+            #     # asynchronous=True,
+            #     config_kwargs={
+            #         'signature_version': 's3v4',
+            #         's3': {
+            #             'addressing_style': 'virtual'
+            #         }
+            #     }
+            # )
+            # print("fs.ls", fs.ls(""))
+
+            src_dataset.save_to_disk(
+                kargs["data_path"],
+                storage_options=kargs["data_path_storage_options"]
+            )
+        else:
+            src_dataset.save_to_disk(
+                kargs["data_path"]
+            )
+
+        # Return the dataset object itself
+        return src_dataset
+    else:
+        # there is nothing, return none
+        return None
 
 # Dataloader collator for merging multiple dataset records together
 # we use token 0 for padding, with a learning mask value of 0
@@ -855,6 +896,11 @@ def __init__(
         self,
         # load_from_disk(dataset_path) param
         data_path: str,
+        # Data path storage options, this is used to support cloud storage
+        # via the huggingface dataset API. See:
+        # https://huggingface.co/docs/datasets/v2.16.1/en/filesystems#amazon-s3
+        # Note: As of Jan 2023, these options seem very buggy, YMMV
+        data_path_storage_options: dict = None,
         # load_dataset(path) param
         source: str = None,
         # load_dataset(data_dir) param
@@ -871,7 +917,7 @@ def __init__(
         # ---
         # Tokenizer settings
         # ---
-        tokenizer: str = "neox",
+        tokenizer: str = "world",
         autoTokenizer = None,
 
         # Add <|endoftext|> string token to the world tokenizer, at index 0
@@ -892,9 +938,6 @@ def __init__(
         sort_by_length: bool = False,
         sort_asc: bool = True,
 
-        # Dataloader shuffling, disabled if "sort_by_length" is enabled
-        training_dataloader_shuffle_auto: bool = True,
-
         # Dataset offset and limit controls
         dataset_offset: float = -1,
         dataset_length: float = -1,
@@ -981,13 +1024,23 @@ def __init__(
         # System tweaks
         # ----------------------------
 
+        # Skip database setup checks if datapath exists, ignored if using preload_datapath.py
+        skip_datapath_setup: bool = False,
+
         # Batch size scanning range, used for deciding the max number of documents
         # to process simultaneously at a time. This is used to prevent OOM errors
         # while rearranging the dataset, etc. Used for both packing / sorting operations
-        processing_max_batch_size: int = 100000,
+        processing_max_batch_size: int = 100 * 1000,
 
-        # Skip database setup checks if datapath exists, ignored if using preload_datapath.py
-        skip_datapath_setup: bool = False
+        # Dataloader shuffling, disabled if "sort_by_length" is enabled
+        dataloader_shuffle_training: bool = False,
+
+        # With a total of 4 batches prefetched into memory
+        dataloader_prefetch_factor: int = 4,
+
+        # Pin the preloaded documents into GPU memory in advance
+        # very small overhead, slight speed bump, disable if you're desperate for VRAM
+        dataloader_pin_memory: bool = True,
     ):
         # Capture the init parameters
         self._init_locals = locals()
@@ -996,9 +1049,13 @@ def __init__(
         super().__init__()
 
         self.data_path = data_path
-        self._loaded_dataset = None
+        self.data_path_storage_options = data_path_storage_options
+        self.dataloader_prefetch_factor = dataloader_prefetch_factor
+        self.dataloader_pin_memory = dataloader_pin_memory
+        self.dataloader_shuffle_training = dataloader_shuffle_training
         self.sort_by_length = sort_by_length
-        self.training_dataloader_shuffle_auto = training_dataloader_shuffle_auto
+
+        self._loaded_dataset = None
 
         # Log to wandb
         if wandb.run is not None:
@@ -1011,7 +1068,10 @@ def prepare_data(self):
     # Setup process that is universal
     def _internal_setup(self):
        if self._loaded_dataset is None:
-            self._loaded_dataset = load_from_disk(self.data_path).with_format('torch')
+            if self.data_path_storage_options:
+                self._loaded_dataset = load_from_disk(self.data_path, storage_options=self.data_path_storage_options).with_format('torch')
+            else:
+                self._loaded_dataset = load_from_disk(self.data_path).with_format('torch')
 
     # Called once for every process in DDP
     def setup(self, stage):
@@ -1023,7 +1083,7 @@ def train_dataloader(self):
         dataset = self._loaded_dataset['train'];
         sampler = DistributedSampler(
             dataset,
-            shuffle=self.training_dataloader_shuffle_auto and not self.sort_by_length,
+            shuffle=self.dataloader_shuffle_training and not self.sort_by_length,
             num_replicas=self.trainer.world_size,
             rank=self.trainer.global_rank,
         )
@@ -1038,14 +1098,14 @@ def train_dataloader(self):
             shuffle=False,
             # 4 prefetch workers per GPU
             num_workers=4,
-            # Prefetching 8 batches
-            prefetch_factor=8,
-            # Of batch size 1 datasets
+            # Prefetching X batches
+            prefetch_factor=self.dataloader_prefetch_factor,
+            # Of batch sized datasets
             batch_size=microbatch_size,
             # The collation function
             collate_fn=dataloader_collator_fn,
             # Pinned in GPU memory
-            pin_memory=True
+            pin_memory=self.dataloader_pin_memory
         )
 
     # Return the validation dataloader
@@ -1058,6 +1118,11 @@ def val_dataloader(self):
             num_replicas=self.trainer.world_size,
             rank=self.trainer.global_rank,
         )
+
+        microbatch_size = 1
+        if hasattr(self, "trainer") and hasattr(self.trainer, "microbatch_size"):
+            microbatch_size = self.trainer.microbatch_size
+
         return DataLoader(
             dataset,
             sampler=sampler,
@@ -1065,9 +1130,11 @@ def val_dataloader(self):
             shuffle=False,
             # 4 prefetch workers per GPU
             num_workers=4,
             # Prefetching 8 batches
-            prefetch_factor=8,
-            # Of batch size 1 datasets
-            batch_size=1,
+            prefetch_factor=self.dataloader_prefetch_factor,
+            # Of batch sized datasets
+            batch_size=microbatch_size,
+            # The collation function
+            collate_fn=dataloader_collator_fn,
             # Pinned in GPU memory
-            pin_memory=True
+            pin_memory=self.dataloader_pin_memory
         )
\ No newline at end of file
diff --git a/RWKV-v5/src/model.py b/RWKV-v5/src/model.py
index e16dbfd5..f71f46c0 100644
--- a/RWKV-v5/src/model.py
+++ b/RWKV-v5/src/model.py
@@ -183,6 +183,9 @@ def __init__(self,
                  lr_final: float = -1.0,
                  lr_period: int = -1,
                  lr_period_type: str = 'epoch',
+                 # Use either "cosine" or "linear"
+                 lr_type: str = 'cosine',
+
                  # Dropout rate
                  dropout: float = 0.0,
                  # Adam optimizer settings
@@ -271,6 +274,7 @@ def __init__(self,
         self.lr_final = lr_final
         self.lr_period = lr_period
         self.lr_period_type = lr_period_type
+        self.lr_type = lr_type
         self.dropout = dropout
         self.warmup_steps = warmup_steps
         self.beta1 = beta1
@@ -516,17 +520,26 @@ def configure_optimizers(self):
         if self.lr_period_type == "step":
             lr_total_step = self.lr_period
         elif self.lr_period_type == "epoch":
-            lr_total_step = self.lr_period * self.num_step_per_epoch() * self.trainer.num_devices # * self.trainer.microbatch_size
+            lr_total_step = self.lr_period * self.num_step_per_epoch() # * self.trainer.microbatch_size
         else:
             raise ValueError(f"lr_period_type {self.lr_period_type} not supported.")
 
         # Lets initialize the lr_scheduler
-        lr_scheduler = torch.optim.lr_scheduler.LinearLR(
-            optimizer,
-            start_factor=1.0,
-            end_factor= lr_final / lr_init,
-            total_iters=lr_total_step
-        )
+        if self.lr_type == "cosine":
+            lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
+                optimizer,
+                T_max=lr_total_step,
+                eta_min=lr_final
+            )
+        elif self.lr_type == "linear":
+            lr_scheduler = torch.optim.lr_scheduler.LinearLR(
+                optimizer,
+                start_factor=1.0,
+                end_factor= lr_final / lr_init,
+                total_iters=lr_total_step
+            )
+        else:
+            raise ValueError(f"lr_type {self.lr_type} not supported.")
 
         return {
             'optimizer': optimizer,
@@ -566,7 +579,8 @@ def num_step_per_epoch(self) -> int:
             dataset_size = len(train_dataloader)
 
         num_devices = max(1, self.trainer.num_devices)
-        num_steps = dataset_size // (self.trainer.accumulate_grad_batches * num_devices)
+        num_nodes = max(1, self.trainer.num_nodes)
+        num_steps = dataset_size // (self.trainer.accumulate_grad_batches * num_devices * num_nodes)
         return num_steps
 
     @property
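
Below is a minimal sketch of what the new cloud-storage path looks like from a user script, mirroring the `save_to_disk(..., storage_options=...)` and `load_from_disk(..., storage_options=...)` calls this diff adds to `data.py`. The bucket name, the `sample.txt` file, and the `S3_ENDPOINT_URL` variable are hypothetical placeholders; it assumes `s3fs` is installed and that credentials are supplied via the `AWS_ACCESS_KEY_ID` / `AWS_SECRET_ACCESS_KEY` environment variables, as the config comments recommend.

```python
import os
from datasets import load_dataset, load_from_disk

# Hypothetical bucket/prefix, in the s3://bucket-name/subpath format
# described in config-example.yaml
DATA_PATH = "s3://my-example-bucket/rwkv-datapath/"

# Credentials come from the environment, so nothing secret is committed.
# endpoint_url is only needed for S3-compatible providers (e.g. Backblaze B2).
storage_options = {
    "key": os.environ["AWS_ACCESS_KEY_ID"],
    "secret": os.environ["AWS_SECRET_ACCESS_KEY"],
    "endpoint_url": os.environ.get("S3_ENDPOINT_URL"),
}

# Build a small dataset locally, then persist it to the bucket,
# the same call pattern prepare_data_static() now uses when
# data_path_storage_options is set.
src_dataset = load_dataset("text", data_files={"train": "sample.txt"})
src_dataset.save_to_disk(DATA_PATH, storage_options=storage_options)

# Later (e.g. inside _internal_setup), it is reloaded the same way.
reloaded = load_from_disk(DATA_PATH, storage_options=storage_options)
print(reloaded)
```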
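The new `lr_type` option selects between the two schedulers in the `configure_optimizers()` hunk. A self-contained sketch, with a toy optimizer and made-up `lr_init` / `lr_final` / step-count values, showing how the cosine and linear schedules decay over the same number of steps:

```python
import torch

# Stand-ins for the values configure_optimizers() derives from the config
lr_init, lr_final, lr_total_step = 1e-3, 1e-5, 1000

def make_optimizer():
    # Single throwaway parameter, just to drive a scheduler
    return torch.optim.Adam([torch.nn.Parameter(torch.zeros(1))], lr=lr_init)

opt_cos = make_optimizer()
opt_lin = make_optimizer()

# lr_type == "cosine": smooth decay from lr_init down to eta_min=lr_final
cosine = torch.optim.lr_scheduler.CosineAnnealingLR(
    opt_cos, T_max=lr_total_step, eta_min=lr_final)

# lr_type == "linear": straight-line decay over the same number of steps
linear = torch.optim.lr_scheduler.LinearLR(
    opt_lin, start_factor=1.0, end_factor=lr_final / lr_init,
    total_iters=lr_total_step)

for step in range(lr_total_step):
    opt_cos.step()
    opt_lin.step()
    cosine.step()
    linear.step()
    if step in (0, lr_total_step // 2, lr_total_step - 1):
        print(step, cosine.get_last_lr()[0], linear.get_last_lr()[0])
```

Since the diff defaults to `lr_type: 'cosine'`, runs that relied on the previous behaviour can keep the straight-line decay by setting `lr_type: linear` in the config.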
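The `num_step_per_epoch()` hunk also starts dividing by `num_nodes`, so the accumulated-gradient step count reflects every rank in a multi-node run. A quick check of the updated formula with made-up numbers:

```python
# Made-up numbers purely to illustrate the updated formula
dataset_size = 100_000          # len(train_dataloader)
accumulate_grad_batches = 8
num_devices = 8                 # GPUs per node
num_nodes = 2                   # newly included by this diff

num_steps = dataset_size // (accumulate_grad_batches * num_devices * num_nodes)
print(num_steps)  # 100_000 // 128 -> 781
```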