diff --git a/RWKV-v5/config-example.yaml b/RWKV-v5/config-example.yaml index 340fd731..9b191aec 100644 --- a/RWKV-v5/config-example.yaml +++ b/RWKV-v5/config-example.yaml @@ -156,6 +156,17 @@ trainer: # your GPU processing time %, and avoid idle time for the GPU between batches target_batch_size: 32 + # Microbatching chunks which we split our data by, this substentially increase vram usage + # for each GPU step, but increase throughput of the training process substentially. + # + # So if you have 16 datasample per batch per GPU. And microbatch_size of 2, you have 8 substep + # + # It is generally recommended to tune this to be the highest you can resonably support + # on your GPU as it has a direct impact on your overall tokens / second count. + # + # Typically you tune the microbatch_size first, before tuning the target_batch_size + microbatch_size: 1 + # You can alternatively set the accumulate_grad_batches per GPU directly # (not recommended) # @@ -409,6 +420,12 @@ data: # A minimum of 2 columns is required, with non empty data, for the merge to occur # If no match is found, this will fallback to the default prompt/completion or text column, # or throw an error if the default fallback is not found + # + # IMPORTANT NOTE: as newlines are commonly used for multi_column_suffix, etc. + # you should use single quotes to ensure such values dun get escaped. + # eg. multi_column_suffix: ['\n\n'] + # + # See: https://github.com/RWKV/RWKV-infctx-trainer/issues/34 # --- # multi_column_keys: ['instruction', 'input', 'output'] # multi_column_prefix: ['Instruction:\n', 'Input:\n', 'Output:\n'] diff --git a/RWKV-v5/src/data.py b/RWKV-v5/src/data.py index 38d9f4f0..8ba9cae1 100644 --- a/RWKV-v5/src/data.py +++ b/RWKV-v5/src/data.py @@ -1,5 +1,6 @@ from lightning import LightningDataModule +import torch from torch.utils.data import DataLoader from torch.utils.data import DistributedSampler @@ -293,10 +294,15 @@ def map_tokenizer(x): input_ids += column_encodings['input_ids'] token_type_ids += column_encodings['token_type_ids'] - # Override the training attention mask if masking is set to false - if len(multi_column_train_mask) < i and multi_column_train_mask[i] is False: + # Configure the attention masks accordingly + if i > len(multi_column_train_mask): + # If the corresponding `multi_column_train_mask` is not set, we will assume as valid training data + attention_mask += ([1] * len(column_encodings['input_ids'])) + elif multi_column_train_mask[i] is False: + # If the `multi_column_train_mask` is set, but configured as false, we should not pay attention to it attention_mask += ([0] * len(column_encodings['input_ids'])) - else: + else: # multi_column_train_mask[i] is True + # This means it is true, lets pay attention once again attention_mask += ([1] * len(column_encodings['input_ids'])) # Add the suffix @@ -494,6 +500,48 @@ def add_length(example): # Save the dataset to disk src_dataset.save_to_disk(kargs["data_path"]) +# Dataloader collator for merging multiple dataset records together +# we use token 0 for padding, with a learning mask value of 0 +def dataloader_collator_fn(records): + # Get the maximum number of records + # (aka the batch size) + records_len = len(records) + + # Compute the total length of the records + input_ids_len = 0 + token_type_ids_len = 0 + attention_mask_len = 0 + + # Loop through the records and compute the max length + for i in range(records_len): + input_ids_len = max(input_ids_len, len(records[i]["input_ids"])) + token_type_ids_len = max(token_type_ids_len, 
len(records[i]["token_type_ids"])) + attention_mask_len = max(attention_mask_len, len(records[i]["attention_mask"])) + + # First row of the records + first_row = records[0] + + # Create the output arrays, with the default 0 values (no learning mask) + out_input_ids = torch.zeros((records_len, input_ids_len), dtype=first_row["input_ids"].dtype) + out_token_type_ids = torch.zeros((records_len, token_type_ids_len), dtype=first_row["token_type_ids"].dtype) + out_attention_mask = torch.zeros((records_len, attention_mask_len), dtype=first_row["attention_mask"].dtype) + out_data_ctx_len = torch.zeros((records_len), dtype=torch.int32) + + # Loop through the records and copy the values to the output arrays + for i in range(records_len): + out_input_ids[i][:len(records[i]["input_ids"])] = records[i]["input_ids"] + out_token_type_ids[i][:len(records[i]["token_type_ids"])] = records[i]["token_type_ids"] + out_attention_mask[i][:len(records[i]["attention_mask"])] = records[i]["attention_mask"] + out_data_ctx_len[i] = len(records[i]["input_ids"]) + + # Build & return the output object + out = { + 'input_ids': out_input_ids, + 'token_type_ids': out_token_type_ids, + 'attention_mask': out_attention_mask, + 'data_ctx_len': out_data_ctx_len + } + return out class RWKVDataModule(LightningDataModule): def __init__( @@ -595,6 +643,11 @@ def train_dataloader(self): num_replicas=self.trainer.world_size, rank=self.trainer.global_rank, ) + + microbatch_size = 1 + if hasattr(self, "trainer") and hasattr(self.trainer, "microbatch_size"): + microbatch_size = self.trainer.microbatch_size + return DataLoader( dataset, sampler=sampler, @@ -604,7 +657,9 @@ def train_dataloader(self): # Prefetching 8 batches prefetch_factor=8, # Of batch size 1 datasets - batch_size=1, + batch_size=microbatch_size, + # The collation function + collate_fn=dataloader_collator_fn, # Pinned in GPU memory pin_memory=True ) diff --git a/RWKV-v5/src/model.py b/RWKV-v5/src/model.py index 4d62b281..0773b6d1 100644 --- a/RWKV-v5/src/model.py +++ b/RWKV-v5/src/model.py @@ -135,21 +135,21 @@ def forward(ctx, loss, y, token_amount, currentMask): # # See also: # - checkpointed_step - ctx.save_for_backward(y) - ctx.token_amount = token_amount - ctx.currentMask = currentMask + ctx.save_for_backward(y, token_amount, currentMask) return loss @staticmethod def backward(ctx, grad_output): - y, = ctx.saved_tensors - token_amount = ctx.token_amount + y, token_amount, currentMask = ctx.saved_tensors + # to encourage the logits to be close to 0 factor = 1e-4 / token_amount maxx, ids = torch.max(y, -1, keepdim=True) gy = torch.zeros_like(y) gy.scatter_(-1, ids, maxx * factor) - gy = gy * ctx.currentMask[:, None][None, :] + + # We ensure the mask is reshaped accordingly, and apply it against gy + gy = gy * currentMask.reshape(gy.shape[0],gy.shape[1],1) # currentMask[:, None][None, :] return (grad_output, gy, None, None) ### --- @@ -200,7 +200,7 @@ def __init__(self, grad_cp: bool = True, bptt_learning: bool = True, bptt_learning_range: int = -1, - bptt_truncated_learning: bool = False, + bptt_truncated_learning: bool = True, layerwise_lr: bool = True, dim_att: Optional[int] = None, dim_ffn: Optional[int] = None, @@ -784,7 +784,8 @@ def compute_loss(self, batch, batch_idx, is_training_run: bool): self._counting_tokens = 0 if self._counting_time_start is None or batch_idx == 0: self._counting_time_start = time.time() - + + # Get the input sequence, and attention mask seq = batch['input_ids'] assert isinstance(seq, torch.Tensor) and seq.ndim == 2 ori_seq_mask = 
batch['attention_mask'] @@ -793,17 +794,30 @@ def compute_loss(self, batch, batch_idx, is_training_run: bool): if ori_seq_mask is None or ori_seq_mask.ndim != 2: ori_seq_mask = torch.ones_like(seq[:, 1:]) + # Initialize the total_mask_sum (but not compute it) + total_mask_sum = 0 + + # Number of GPUs used in training, note that if it is > 1 + # it is requried that all operations here are in sync with + # all other GPUs, as such "quick return" on this function + # should not be allowed + num_devices = self.trainer.num_devices + + ### --- + ### Positional loss bias handling + ### --- + # Get the starting and ending loss bias loss_bias_start = self.position_loss_bias loss_bias_end = 2.0 - loss_bias_start - # total_mask_sum - total_mask_sum = torch.sum(ori_seq_mask) - # Skip loss bias calculation, if loss_bias_start is 1.0 if loss_bias_start == 1.0 or (is_training_run == False and self.position_loss_bias_in_validation == False): seq_mask = ori_seq_mask else: + # Lets get the torch mask sum + total_mask_sum = torch.sum(ori_seq_mask) + # Lets get a linear multiplier for the loss bias # seq_mask_sum = torch.sum(ori_seq_mask) bias_mask = torch.linspace(loss_bias_start, loss_bias_end, int(total_mask_sum.item()), device=ori_seq_mask.device) @@ -818,12 +832,18 @@ def compute_loss(self, batch, batch_idx, is_training_run: bool): # And save it as seq_mask seq_mask = final_mask.unsqueeze(0) + ### --- + ### Training cutoff logic handling + ### --- + # Perform cutoff for training run if is_training_run: prev_step = 0 # Avoid using the zip operation, as torch.compile throws an exception on it # with `zip not reconized as a valid function` + # + # This skip if ctx_len_warmup_steps/ctx_len_cutoffs is not set # --- # for step, len_cut in zip(self.ctx_len_warmup_steps, # self.ctx_len_cutoffs): @@ -846,23 +866,35 @@ def compute_loss(self, batch, batch_idx, is_training_run: bool): seq_mask[:, :pos] = 0 break prev_step = step - + + ### --- + ### Various size checking, and implementing the core checkpoint_step + ### --- + + # BPTT, and training steps, and various size fetching do_bptt_learning = self.bptt_learning and is_training_run idx, targets = seq[:, :-1], seq[:, 1:] - B, T = idx.shape C = self.n_embd - total_mask_sum = torch.sum(seq_mask) # If total_mask_sum, we skip, as there is no tokens of value to learn from anyway - if total_mask_sum == 0: + total_mask_sum = torch.sum(seq_mask) + # Do a quick return, if there is no tokens of value to learn from due to full masking + if num_devices > 1 and total_mask_sum == 0: return 0 + # Checkpoint steps def checkpointed_step(idx, targets, mask, prev_loss, last_shift_states, last_wkv_states, prev_steps): logits, new_shift_states, new_wkv_states = self( idx, last_shift_states, last_wkv_states) + # Ensure logits, targets, and mask are contiguous + # this is required to avoid view is not compatible with size and stride error + logits = logits.contiguous() + targets = targets.contiguous() + mask = mask.contiguous() + loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), reduction="none") @@ -876,14 +908,17 @@ def checkpointed_step(idx, targets, mask, prev_loss, last_shift_states, new_loss = prev_loss + loss return new_loss, new_shift_states, new_wkv_states, new_steps - total_loss = torch.tensor( - 0, dtype=self.emb.weight.dtype).requires_grad_() + total_loss = torch.tensor(0, dtype=self.emb.weight.dtype).requires_grad_() steps = 0 states = BlockStateList.create(self.n_layer, B, C, self.n_head, self.head_size, seq.device, self.emb.weight.dtype) 
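For reference, a minimal standalone sketch (plain PyTorch only; shapes and names here are illustrative, not part of the patch) of the stride error that the `.contiguous()` calls in `checkpointed_step` work around, and of the masked per-token loss pattern used there:

import torch
import torch.nn.functional as F

seq = torch.randint(0, 100, (2, 8))      # dummy token ids: batch of 2, length 8
targets = seq[:, 1:]                     # shifted targets form a non-contiguous view

try:
    targets.view(-1)                     # fails: the sliced view's strides cannot be flattened
except RuntimeError as err:
    print("view() failed:", err)

targets = targets.contiguous()           # copy into contiguous memory first, as the patch does
logits = torch.randn(2, 7, 100)          # dummy logits for the 7 target positions

token_loss = F.cross_entropy(
    logits.view(-1, logits.size(-1)),    # [B*T, vocab]
    targets.view(-1),                    # [B*T], valid now that targets is contiguous
    reduction="none",                    # keep per-token losses so the mask can be applied
)
mask = torch.ones(2, 7)                  # stand-in for the attention / training mask
loss = torch.sum(token_loss * mask.view(-1)) / torch.clamp(mask.sum(), min=1)
print(loss)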
segment_count = math.ceil(T / self.ctx_len) + ### --- + ### Learning process logic (BPTT or not) + ### --- + # # BPTT learning, we split the sequence into segments # and perform a backward pass for each segment, on its own. @@ -919,12 +954,12 @@ def checkpointed_step(idx, targets, mask, prev_loss, last_shift_states, # it also helps ensure the segment cutoff points are more varied, across mixed dataset sizes # and avoid potentially undesired training behaviour at fixed cutoff points # (this only applies for segmented learning) - segment_size = min(math.ceil(T / segment_count), self.ctx_len) + segment_size = min(math.ceil(T / segment_count)+1, self.ctx_len) - # Dummy 2D tenros of shape [1,1], are used to do "dummy checkpoint/forward/backprop" to keep everything in sync + # Dummy 2D tensor of shape [1,1], are used to do "dummy checkpoint/forward/backprop" to keep everything in sync dummy_2d_zero = torch.tensor([[0]], dtype=torch.long, device=cur_device) - # Get the max segment count across all GPUs, in the current batch, which is used to keep all devices are in sync + # Get the max segment count across all GPUs, in the current substep, which is used to keep all devices are in sync # Once a thread has completed all its segments, it will do dummy checkpoint/forward/backprop with one token, # and stay in sync with the thread that are still working on their segments # @@ -1118,13 +1153,20 @@ def checkpointed_step(idx, targets, mask, prev_loss, last_shift_states, global_rank = self.global_rank global_device_count = self.trainer.num_devices * self.trainer.num_nodes + # Get the total dataset context length + batch_ctx_len = 0 + if "data_ctx_len" in batch: + batch_ctx_len = torch.sum(batch["data_ctx_len"]).item() + else: + batch_ctx_len = T * self.trainer.microbatch_size + # Increment the counting tokens, and log it accordingly - self._counting_tokens += T + self._counting_tokens += batch_ctx_len # Log the line values wandb.log({ 'global_rank': global_rank, - 'real_ctx_len': T, + 'data_ctx_len': batch_ctx_len / self.trainer.microbatch_size, 'train/loss': total_loss, f'perf/tokens_total.gpu.{global_rank}': self._counting_tokens, f'perf/tokens_per_sec.gpu.{global_rank}': self._counting_tokens / max(time.time() - self._counting_time_start, 1), @@ -1138,8 +1180,15 @@ def checkpointed_step(idx, targets, mask, prev_loss, last_shift_states, assert not torch.isnan(total_loss), "total_loss is NaN" return total_loss + # + # Training and validation steps + # @TCompileBaseline def training_step(self, batch, batch_idx): + + # print("=== BATCH ID SHAPE ===", batch["input_ids"].shape) + # print("=== BATCH AM SHAPE ===", batch["attention_mask"].shape) + total_loss = self.compute_loss(batch, batch_idx, True) self.log('train/loss', total_loss, prog_bar=True) diff --git a/RWKV-v5/src/trainer.py b/RWKV-v5/src/trainer.py index f54958b7..478b58d6 100644 --- a/RWKV-v5/src/trainer.py +++ b/RWKV-v5/src/trainer.py @@ -18,8 +18,15 @@ def __init__( *args, # Replaces the accumulate_grad_batches, if set # automatically compute the accumulate_grad_batches - # according to the num_nodes, and num_devices configured + # + # According to the microbatch_size, num_nodes, + # and num_devices configured target_batch_size=-1, + # Microbatch sizing, to be used with + # each training step per GPU. + # + # This is the same as pytorch dataset batch size. 
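+        #
+        # eg. with target_batch_size=32, num_nodes=1, num_devices=2, microbatch_size=4,
+        #     accumulate_grad_batches = 32 / (1 * 2 * 4) = 4, and effective_batch_size = 32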
+ microbatch_size=1, # Handle the rest of args, as per normal **kwargs, ): @@ -32,6 +39,10 @@ def __init__( # target batch size logging target_batch_size_log_msg = "" + # Compute the microbatch_size + self.microbatch_size = microbatch_size + assert microbatch_size > 0, "microbatch_size must be greater than 0" + # Compute the accumulate_grad_batches, using the target_batch_size self.target_batch_size = target_batch_size if target_batch_size > 0: @@ -56,9 +67,9 @@ def __init__( raise ValueError(f"Unsupported devices config '{devices}', unable to compute device count for 'target_batch_size'") # Compute the accumulate_grad_batches - accumulate_grad_batches = max( 1, math.floor(target_batch_size / (num_nodes * num_devices)) ) + accumulate_grad_batches = max( 1, math.floor(target_batch_size / (num_nodes * num_devices * microbatch_size)) ) kwargs["accumulate_grad_batches"] = accumulate_grad_batches - effective_batch_size = accumulate_grad_batches * num_nodes * num_devices + effective_batch_size = accumulate_grad_batches * num_nodes * num_devices * microbatch_size # Log the applied accumulate_grad_batches trainer_config["__accumulate_grad_batches"] = accumulate_grad_batches @@ -71,6 +82,7 @@ def __init__( f" - target_batch_size: {target_batch_size}\n"+ f" - num_nodes: {num_nodes}\n"+ f" - num_devices: {num_devices}\n"+ + f" - microbatch_size: {microbatch_size}\n"+ f" - accumulate_grad_batches: {accumulate_grad_batches}\n" f" - effective_batch_size: {effective_batch_size}\n") @@ -99,7 +111,7 @@ def __init__( # if local rank is 0 if target_batch_size_log_msg != "" and self.local_rank == 0: print(target_batch_size_log_msg) - + # Fabric instance, useful for coordinating between processes # when `self.trainer.strategy.reduce` is not possible def getFabric(self): diff --git a/notebook/dataset-config/example-hf-multi-column-keys.yaml b/notebook/dataset-config/example-hf-multi-column-keys.yaml index f8ba1abf..95034186 100644 --- a/notebook/dataset-config/example-hf-multi-column-keys.yaml +++ b/notebook/dataset-config/example-hf-multi-column-keys.yaml @@ -102,6 +102,12 @@ data: # A minimum of 2 columns is required, with non empty data, for the merge to occur # If no match is found, this will fallback to the default prompt/completion or text column, # or throw an error if the default fallback is not found + # + # IMPORTANT NOTE: as newlines are commonly used for multi_column_suffix, etc. + # you should use single quotes to ensure such values dun get escaped. + # eg. multi_column_suffix: ['\n\n'] + # + # See: https://github.com/RWKV/RWKV-infctx-trainer/issues/34 # --- multi_column_keys: ['instruction', 'input', 'output'] multi_column_prefix: ['Instruction:\n', 'Input:\n', 'Output:\n'] diff --git a/notebook/trainer-v5-unit-test/config/enwiki_10k-world-4x1024.yaml b/notebook/trainer-v5-unit-test/config/enwiki_10k-world-4x1024.yaml new file mode 100644 index 00000000..d7672172 --- /dev/null +++ b/notebook/trainer-v5-unit-test/config/enwiki_10k-world-4x1024.yaml @@ -0,0 +1,271 @@ +# lightning.pytorch==2.0.2 +seed_everything: 3941088705 +trainer: + + # + # Configure the deepspeed strategy, we recommend you start with `deepspeed_stage_2_offload` + # and adjust from there according to your training needs. `deepspeed_stage_3_offload` is useful + # for training LoRA on large models on a single GPU. 
+ # + # In general you would want to use the following: + # + # - deepspeed_stage_1 : Each of your GPU has too much vram, and you do not know what to do + # + # - deepspeed_stage_2 : Optimal distributed training strategy, across multiple gpu each with sufficient vram + # - deepspeed_stage_2_offload : Reduce vram usage by offloading the optimizer state and work to cpu + # + # - deepspeed_stage_3 : Split up the model across multiple gpu, useful for large models, at a performance cost + # - deepspeed_stage_3_offload : Additional offloading, for even greater performance cost + # + # For more details see: + # https://lightning.ai/docs/pytorch/stable/advanced/model_parallel.html#deepspeed-zero-stage-2 + # + strategy: deepspeed_stage_2_offload + + # Logger setting for wandb, if you want to enable wandb, uncomment the whole logger section + # --- + logger: + class_path: lightning.pytorch.loggers.WandbLogger + init_args: + name: 'infctx-v5-unit-test-baseline (train-ctx=1024, data-ctx=4096)' + project: 'RWKV-infctx-unit-test' + tags: ['RWKV', 'infctx'] + + # Checkpoint settings for the training process + callbacks: + class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + # Configure this to the path you want to save your checkpoints to + # note that a subdir will be created with the name `epoch=x-step=y.ckpt` + # + # to convert a checkpoint to a model, you can use the + # `python3 export_checkpoint.py ` script, + # which will create a `rwkv_model.pth` in the checkpoint directory. + # + # Do not use the `zero_to_fp32.py` script as that will have export format issues + dirpath: ../checkpoint/infctx-v5-unit-test-baseline-4x1024 + filename: null + + # Save the top/last K checkpoints + save_top_k: 3 + # Choose by the most recent checkpoints (step based) + monitor: 'step' + mode: max + + # If enabled (true), save a copy of the latest checkpoint to 'last.ckpt' + # useful to simply checkpoint resume scripts, at a price of disk performance + save_last: true + + # DO NOT set this as true, as the model weight exported will have format issues + # expert as checkpoint, and use the `export_checkpoint.py` script to convert to model instead + save_weights_only: false + + # How frequent you want to save a checkpoint for every step. 
+ # This will happen for every X data sample, where X = every_n_train_steps * accumulate_grad_batches + # + # In general you will want to avoid putting a low number (expecially if accumulate_grad_batches <= 100) + # as the checkpoint process, will pause all the gpu training for some time, slowing down the overall process + # However you do not want to configure too high of a number, where you will lose too much progress if the training crashes + every_n_train_steps: 10 + every_n_epochs: null + save_on_train_epoch_end: true + train_time_interval: null + + # Other settings, you can probably leave alone + verbose: false + auto_insert_metric_name: true + + ######################################## + ## Training run parameter settings + ######################################## + + # Generally what you want to configure is the maximum number of epochs + # Leave it as -1, and it will keep going forever till interrupted + # Or set it as a number, and it will stop after that number of epochs + max_epochs: 1 + min_epochs: null + max_steps: -1 + min_steps: null + max_time: null + + # Number of datasamples to train for each step, a data sample is considered + # a "substep" in wandb logs, and a "step" is tracked as "trainer/global_step" + # + # This decides the number of datasample, to learn together from, before backproping + # any weight changes at the end of the batch. + # + # Recommended to be a big enough number (like 128/256) where it prevents the training + # loss from flucuating in the process. But not too big of a number where the increased + # GPU vRAM / offloaded RAM usage will cause the training to crash. + # + # You are also recommended to configure this to a large enough number to fully utilize + # your GPU processing time %, and avoid idle time for the GPU between batches + # + # This number is divided by the number of GPUs, and nodes configured + # So if you have 4 GPUs, 2 nodes, and this is configured as 128 + # Each GPU will process 128/4/2 = 16 datasamples per step batch, via accumulate_grad_batches + target_batch_size: 16 + + # Microbatching chunks which we split our data by, this substentially increase vram usage + # for each GPU step, but increase throughput of the training process substentially. + # + # So if you have 16 datasample per batch per GPU. And microbatch_size of 2, you have 8 substep + microbatch_size: 2 + +######################################## +## Training model settings +######################################## +model: + # Model to start the finetune/training process from + load_model: ../model/L24-D2048-neox-v5base-init.pth + + # Context length to use for the training process + # the larger the number (and batch size) the larger the vram usage + # + # Note that if the datasample context length is larger then the ctx_len + # its training process would be split into ctx_len sized chunks. + # + # This allows the training of extreamly large context length (eg. 100k), + # without eating up too much vram by keeping the training context length + # to a resonable number sutible to the current GPU setup + ctx_len: 1024 + + # Data samples would be cut down to the respective max ctx_len_cutoffs + # values if its larger then ctx_len. If the data sample is larger then + # the largest len_cutoff, the remaining data will be discarded + ctx_len_cutoffs: [] + # Experimental settings, number of tokens to skip in the data sample + # prefix, for the respective cutoff length. 
Used to speed up the process + ctx_len_warmup_steps: [] + + # Learning rate of the training process + # --- + + # Initia learning rate of the process + lr_init: 8e-4 + # Final learning rate after the learning rate period + # learning rate will stay at final value from then onwards + lr_final: 4e-4 + + # Number of epoch to reduce the learning rate from lr_init to lr_final + # 1 means a single epoch (so lr would be lr_final from epoch 2 onwards) + # 0 means lr_final will apply immediately + # -1 means we take the current max_step / max_epoch as the period + lr_period: 1 + # lr_period type if its set, defaults to epoch + lr_period_type: epoch + + # Adam optimizer settings + # You probably want to leave this alone, unless you know what you are doing + beta1: 0.9 + beta2: 0.99 + adam_eps: 1.0e-08 + weight_decay: 0.01 + + # torch.set_float32_matmul_precision, used to optimize operations with tensor cores + # this should be set as null, for non cuda core GPUs + torch_set_float32_matmul_precision: 'high' + # torch_set_float32_matmul_precision: null + + # Segmented based learning, used to work around training of large context length + # beyond what can be supported by the current GPU vram architecture + # + # This is not 1:1 equivalent to the same training process with required vram + # as the training process is split into multiple segments, part by part. + # with limited learnings from the previous segment. + bptt_learning: true + + # Segmented range to performing backprop learning on + # 1 means to apply only for the last segment + # -1 means to apply for all segments + bptt_learning_range: -1 + +data: + # dataset_path for the prebuilt dataset, using HF `load_from_disk()` + # + # Use this if you have built your own dataset and saved it with `save_to_disk()` + # with source left as null. Other wise configure this to a directory which the + # dataset will be built and tokenized by the huggingface dataset process. + data_path: ../datapath/enwiki_10k-world-4096/ + + # Other wise provide the source path, which is used as huggingface dataset path + # this will be used to populate the dataset_path + # + # Use either the following + # - hugging face dataset + # - Directory path to a directory containing dataset files + # - Path to a single dataset file + # - hugging face dataset mode (ie: text,csv,etc - use data_dir, to configure the path then) + # - null + # + # If source is disabled, all other params, except data_path, is ignored + source: "teven/enwiki_10k" + # source: text + # source: /home/ubuntu/RWKV-LM-LoRA/dataset-text/enwik8.txt + + # Use data_dir, if you are using source=text/json/etc + # this should be relative to the trainer script path + source_data_dir: null + + # After loading the dataset, split out test data used for unit-test, + # This process is skipped if the dataset includes a test split + # This process is skipped if set to zero + test_split: 0.01 + test_split_shuffle: false + + # Tokenizer to use, use either the inbuilt 'neox', or 'world' tokenizer + # If using a custom tokenizer, provide the tokenizer file path + # --- + tokenizer: world + + # Minimum / Maximum token size of the dataset to use + # useful for filtering out small noisy data samples from large datasets + # (eg. 
removal of small articles of less then 512 tokens from wikipedia) + # + # This is ignored, if set to -1 + min_token_size: 1024 + max_token_size: -1 + + # Rechunking of text dataset, this is done only when source is set as 'text' + # and will merge the various sentencees, into larger chunks up to the target size + # + # Defaults to 4096 + # + # This is ignored, if source is not set as text + # This is ignored, if set to zero + # --- + text_rechunk_size: 4096 + + # Apply text rechunk to the dataset, even if its not a 'text' source + # This is done only after dataset filtering, and if source is not 'text' + # --- + text_rechunk_force: true + + # Custom text column to use, useful for dataset with alternative training columns labels + # This is checked before multi column merging, default is null (disabled) + # eg: 'code' + # --- + # custom_text_key: 'code' + + # Multi Column merging process, default setting is used to support and merge + # "instruction", "input", "output", datasets. To disable set multi_column_keys to [] + # + # A minimum of 2 columns is required, with non empty data, for the merge to occur + # If no match is found, this will fallback to the default prompt/completion or text column, + # or throw an error if the default fallback is not found + # --- + # multi_column_keys: ['instruction', 'input', 'output'] + # multi_column_prefix: ['Instruction:\n', 'Input:\n', 'Output:\n'] + # multi_column_train_mask: [true, false, true] + # multi_column_separator: '\n\n' + + # If processing prompt/completion jsonl pairs, the prompt is masked by default + # use this flag to disable this default behaviour + # --- + # disable_prompt_completion_mask: false + +# Path to the current checkpoint to continue training from +# Enable this to the last checkpoint after the first run +# (if it crash and you want to resume) +# ckpt_path: ../checkpoint/trainer-validaiton/infctx-unit-test-baseline/epoch=0-step=20.ckpt +ckpt_path: null diff --git a/notebook/trainer-v5-unit-test/short-enwiki-train.ipynb b/notebook/trainer-v5-unit-test/short-enwiki-train.ipynb index 19f6e28e..ef857267 100644 --- a/notebook/trainer-v5-unit-test/short-enwiki-train.ipynb +++ b/notebook/trainer-v5-unit-test/short-enwiki-train.ipynb @@ -28,7 +28,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 2, "metadata": {}, "outputs": [ { @@ -37,9 +37,9 @@ "text": [ "ENABLE_WANDB: False\n", "GPU_DEVICES: auto\n", - "NOTEBOOK_DIR: /home/harrison/Documents/RWKV-infctx-trainer/notebook/trainer-v5-unit-test\n", - "TRAINER_DIR: /home/harrison/Documents/RWKV-infctx-trainer/RWKV-v5\n", - "PROJECT_DIR: /home/harrison/Documents/RWKV-infctx-trainer\n" + "NOTEBOOK_DIR: /home/ubuntu/rwkv-proj/RWKV-infctx-trainer/notebook/trainer-v5-unit-test\n", + "TRAINER_DIR: /home/ubuntu/rwkv-proj/RWKV-infctx-trainer/RWKV-v5\n", + "PROJECT_DIR: /home/ubuntu/rwkv-proj/RWKV-infctx-trainer\n" ] } ], @@ -82,23 +82,15 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "/usr/lib/python3/dist-packages/requests/__init__.py:87: RequestsDependencyWarning: urllib3 (2.0.7) or chardet (4.0.0) doesn't match a supported version!\n", - " warnings.warn(\"urllib3 ({}) or chardet ({}) doesn't match a supported \"\n", - "[2023-11-05 09:14:14,460] [INFO] [real_accelerator.py:158:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n", - "/home/harrison/.local/lib/python3.10/site-packages/torch/cuda/__init__.py:546: UserWarning: Can't initialize 
NVML\n", - " warnings.warn(\"Can't initialize NVML\")\n", - "/home/harrison/.local/lib/python3.10/site-packages/torch/cuda/__init__.py:651: UserWarning: CUDA initialization: Unexpected error from cudaGetDeviceCount(). Did you run some cuda functions before calling NumCudaDevices() that might have already set an error? Error 804: forward compatibility was attempted on non supported HW (Triggered internally at ../c10/cuda/CUDAFunctions.cpp:109.)\n", - " return torch._C._cuda_getDeviceCount() if nvml_count < 0 else nvml_count\n", - "No ROCm runtime is found, using ROCM_HOME='/opt/rocm-5.7.0'\n", - "No CUDA runtime is found, using CUDA_HOME='/usr/local/cuda'\n", - "[RWKV.model] Running RWKV model using 'torch-jit' with torch '2.0.1+cu117'\n", + "[2023-11-15 04:13:08,468] [INFO] [real_accelerator.py:158:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n", + "[RWKV.model] Running RWKV model using 'torch-jit' with torch '2.1.0'\n", "---- Initializing model ----\n", "No of layers: 6\n", "Embedding size: 512\n", @@ -122,42 +114,52 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "/usr/lib/python3/dist-packages/requests/__init__.py:87: RequestsDependencyWarning: urllib3 (2.0.7) or chardet (4.0.0) doesn't match a supported version!\n", - " warnings.warn(\"urllib3 ({}) or chardet ({}) doesn't match a supported \"\n", - "Saving the dataset (1/1 shards): 100%|█| 763/763 [00:00<00:00, 27031.68 examples\n", - "Saving the dataset (1/1 shards): 100%|███| 8/8 [00:00<00:00, 4838.42 examples/s]\n" + "/home/ubuntu/anaconda3/envs/rwkv-infctx/lib/python3.11/site-packages/datasets/table.py:1421: FutureWarning: promote has been superseded by mode='default'.\n", + " table = cls._concat_blocks(blocks, axis=0)\n", + "Saving the dataset (1/1 shards): 100%|█| 751/751 [00:00<00:00, 11073.50 examples\n", + "Saving the dataset (1/1 shards): 100%|███| 8/8 [00:00<00:00, 2278.89 examples/s]\n" ] } ], "source": [ "# Preload the dataset\n", "!cd \"{TRAINER_DIR}\" && \\\n", - " python3 preload_datapath.py \"{NOTEBOOK_DIR}/config/enwiki_10k-world-4096.yaml\"" + " python3 preload_datapath.py \"{NOTEBOOK_DIR}/config/enwiki_10k-world-4x1024.yaml\"" ] }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "/usr/lib/python3/dist-packages/requests/__init__.py:87: RequestsDependencyWarning: urllib3 (2.0.7) or chardet (4.0.0) doesn't match a supported version!\n", - " warnings.warn(\"urllib3 ({}) or chardet ({}) doesn't match a supported \"\n", - "[2023-11-05 09:58:20,060] [INFO] [real_accelerator.py:158:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n", - "No ROCm runtime is found, using ROCM_HOME='/opt/rocm-5.7.0'\n", - "[RWKV.model] Running RWKV model using 'torch-jit' with torch '2.0.0+cu117'\n", - "/home/harrison/.local/lib/python3.10/site-packages/lightning/pytorch/cli.py:518: LightningCLI's args parameter is intended to run from within Python like if it were from the command line. 
To prevent mistakes it is not recommended to provide both args and command line arguments, got: sys.argv[1:]=['fit', '-c', '/home/harrison/Documents/RWKV-infctx-trainer/notebook/trainer-v5-unit-test/config/enwiki_10k-world-4096.yaml', '--trainer.logger.init_args.name=infctx-v5-unit-test (train-ctx=4096, data-ctx=4096, deepspeed_stage_1)', '--trainer.strategy=deepspeed_stage_1', '--trainer.devices=auto', '--model.load_model=../model/L6-D512-world-init.pth'], args=['fit', '-c', '/home/harrison/Documents/RWKV-infctx-trainer/notebook/trainer-v5-unit-test/config/enwiki_10k-world-4096.yaml', '--trainer.logger.init_args.name=infctx-v5-unit-test (train-ctx=4096, data-ctx=4096, deepspeed_stage_1)', '--trainer.strategy=deepspeed_stage_1', '--trainer.devices=auto', '--model.load_model=../model/L6-D512-world-init.pth'].\n", + "[2023-11-15 20:39:11,866] [INFO] [real_accelerator.py:158:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n", + "[RWKV.model] Running RWKV model using 'torch-jit' with torch '2.1.0'\n", + "/home/ubuntu/anaconda3/envs/rwkv-infctx/lib/python3.11/site-packages/lightning/pytorch/cli.py:518: LightningCLI's args parameter is intended to run from within Python like if it were from the command line. To prevent mistakes it is not recommended to provide both args and command line arguments, got: sys.argv[1:]=['fit', '-c', '/home/ubuntu/rwkv-proj/RWKV-infctx-trainer/notebook/trainer-v5-unit-test/config/enwiki_10k-world-4x1024.yaml', '--trainer.logger.init_args.name=infctx-v5-unit-test (train-ctx=1024, data-ctx=4096, deepspeed_stage_1)', '--trainer.strategy=deepspeed_stage_1', '--trainer.devices=auto', '--trainer.max_steps=2', '--model.load_model=../model/L6-D512-world-init.pth'], args=['fit', '-c', '/home/ubuntu/rwkv-proj/RWKV-infctx-trainer/notebook/trainer-v5-unit-test/config/enwiki_10k-world-4x1024.yaml', '--trainer.logger.init_args.name=infctx-v5-unit-test (train-ctx=1024, data-ctx=4096, deepspeed_stage_1)', '--trainer.strategy=deepspeed_stage_1', '--trainer.devices=auto', '--trainer.max_steps=2', '--model.load_model=../model/L6-D512-world-init.pth'].\n", "Seed set to 3941088705\n", + "/home/ubuntu/anaconda3/envs/rwkv-infctx/lib/python3.11/site-packages/torch/_utils.py:831: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly. To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()\n", + " return self.fget.__get__(instance, owner)()\n", + "---\n", + "[RWKV.TimeMix] Compiling CUDA kernel with HEAD_SIZE=64\n", + "Using /home/ubuntu/.cache/torch_extensions/py311_cu118 as PyTorch extensions root...\n", + "Detected CUDA files, patching ldflags\n", + "Emitting ninja build file /home/ubuntu/.cache/torch_extensions/py311_cu118/wkv5/build.ninja...\n", + "Building extension module wkv5...\n", + "Allowing ninja to set a default number of workers... 
(overridable by setting the environment variable MAX_JOBS=N)\n", + "ninja: no work to do.\n", + "Loading extension module wkv5...\n", + "[RWKV.TimeMix] CUDA kernel compiled & loaded globally\n", + "---\n", "GPU available: True (cuda), used: True\n", "TPU available: False, using: 0 TPU cores\n", "IPU available: False, using: 0 IPUs\n", @@ -171,13 +173,14 @@ " - accumulate_grad_batches: 16\n", " - effective_batch_size: 16\n", "\n", - "Saving the dataset (1/1 shards): 100%|█| 763/763 [00:00<00:00, 25624.79 examples\n", - "Saving the dataset (1/1 shards): 100%|███| 8/8 [00:00<00:00, 4457.28 examples/s]\n", + "/home/ubuntu/anaconda3/envs/rwkv-infctx/lib/python3.11/site-packages/datasets/table.py:1421: FutureWarning: promote has been superseded by mode='default'.\n", + " table = cls._concat_blocks(blocks, axis=0)\n", + "Saving the dataset (1/1 shards): 100%|█| 751/751 [00:00<00:00, 12483.54 examples\n", + "Saving the dataset (1/1 shards): 100%|███| 8/8 [00:00<00:00, 1897.02 examples/s]\n", "[rank: 0] Seed set to 3941088705\n", "initializing deepspeed distributed: GLOBAL_RANK: 0, MEMBER: 1/1\n", "Enabling DeepSpeed BF16. Model parameters and inputs will be cast to `bfloat16`.\n", - "/usr/lib/python3/dist-packages/requests/__init__.py:87: RequestsDependencyWarning: urllib3 (2.0.7) or chardet (4.0.0) doesn't match a supported version!\n", - " warnings.warn(\"urllib3 ({}) or chardet ({}) doesn't match a supported \"\n", + "/home/ubuntu/anaconda3/envs/rwkv-infctx/lib/python3.11/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:630: Checkpoint directory ../checkpoint/infctx-v5-unit-test-baseline-4x1024 exists and is not empty.\n", "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n", "#\n", "# RWKV lighting_trainer.py important notes \n", @@ -192,14 +195,16 @@ " - lr_init: 8.000e-04 (0.0008)\n", " - lr_final: 4.000e-04 (0.0004)\n", "\n", - "Using /home/harrison/.cache/torch_extensions/py310_cu117 as PyTorch extensions root...\n", + "Using /home/ubuntu/.cache/torch_extensions/py311_cu118 as PyTorch extensions root...\n", "Detected CUDA files, patching ldflags\n", - "Emitting ninja build file /home/harrison/.cache/torch_extensions/py310_cu117/fused_adam/build.ninja...\n", + "Emitting ninja build file /home/ubuntu/.cache/torch_extensions/py311_cu118/fused_adam/build.ninja...\n", "Building extension module fused_adam...\n", "Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)\n", "ninja: no work to do.\n", "Loading extension module fused_adam...\n", - "Time to load fused_adam op: 0.04100203514099121 seconds\n", + "Time to load fused_adam op: 0.07130622863769531 seconds\n", + "/home/ubuntu/anaconda3/envs/rwkv-infctx/lib/python3.11/site-packages/deepspeed/ops/adam/fused_adam.py:96: UserWarning: The torch.cuda.*DtypeTensor constructors are no longer recommended. It's best to use methods such as torch.tensor(data, dtype=*, device='cuda') to create tensors. 
(Triggered internally at /opt/conda/conda-bld/pytorch_1695392026823/work/torch/csrc/tensor/python_tensor.cpp:83.)\n", + " self._dummy_overflow_buf = get_accelerator().IntTensor([0])\n", "Loading `train_dataloader` to estimate number of stepping batches.\n", "\n", " | Name | Type | Params\n", @@ -213,21 +218,324 @@ "0 Non-trainable params\n", "87.6 M Total params\n", "350.405 Total estimated model params size (MB)\n", - "Epoch 0: 0%| | 0/763 [00:00\n", + " cli_main()\n", + " File \"/home/ubuntu/rwkv-proj/RWKV-infctx-trainer/RWKV-v5/lightning_trainer.py\", line 271, in cli_main\n", + " LightningCLI(\n", + " File \"/home/ubuntu/anaconda3/envs/rwkv-infctx/lib/python3.11/site-packages/lightning/pytorch/cli.py\", line 386, in __init__\n", + " self._run_subcommand(self.subcommand)\n", + " File \"/home/ubuntu/anaconda3/envs/rwkv-infctx/lib/python3.11/site-packages/lightning/pytorch/cli.py\", line 677, in _run_subcommand\n", + " fn(**fn_kwargs)\n", + " File \"/home/ubuntu/anaconda3/envs/rwkv-infctx/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py\", line 545, in fit\n", + " call._call_and_handle_interrupt(\n", + " File \"/home/ubuntu/anaconda3/envs/rwkv-infctx/lib/python3.11/site-packages/lightning/pytorch/trainer/call.py\", line 43, in _call_and_handle_interrupt\n", + " return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)\n", + " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", + " File \"/home/ubuntu/anaconda3/envs/rwkv-infctx/lib/python3.11/site-packages/lightning/pytorch/strategies/launchers/subprocess_script.py\", line 102, in launch\n", + " return function(*args, **kwargs)\n", + " ^^^^^^^^^^^^^^^^^^^^^^^^^\n", + " File \"/home/ubuntu/anaconda3/envs/rwkv-infctx/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py\", line 581, in _fit_impl\n", + " self._run(model, ckpt_path=ckpt_path)\n", + " File \"/home/ubuntu/anaconda3/envs/rwkv-infctx/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py\", line 990, in _run\n", + " results = self._run_stage()\n", + " ^^^^^^^^^^^^^^^^^\n", + " File \"/home/ubuntu/anaconda3/envs/rwkv-infctx/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py\", line 1036, in _run_stage\n", + " self.fit_loop.run()\n", + " File \"/home/ubuntu/anaconda3/envs/rwkv-infctx/lib/python3.11/site-packages/lightning/pytorch/loops/fit_loop.py\", line 202, in run\n", + " self.advance()\n", + " File \"/home/ubuntu/anaconda3/envs/rwkv-infctx/lib/python3.11/site-packages/lightning/pytorch/loops/fit_loop.py\", line 359, in advance\n", + " self.epoch_loop.run(self._data_fetcher)\n", + " File \"/home/ubuntu/anaconda3/envs/rwkv-infctx/lib/python3.11/site-packages/lightning/pytorch/loops/training_epoch_loop.py\", line 136, in run\n", + " self.advance(data_fetcher)\n", + " File \"/home/ubuntu/anaconda3/envs/rwkv-infctx/lib/python3.11/site-packages/lightning/pytorch/loops/training_epoch_loop.py\", line 240, in advance\n", + " batch_output = self.automatic_optimization.run(trainer.optimizers[0], batch_idx, kwargs)\n", + " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", + " File \"/home/ubuntu/anaconda3/envs/rwkv-infctx/lib/python3.11/site-packages/lightning/pytorch/loops/optimization/automatic.py\", line 187, in run\n", + " self._optimizer_step(batch_idx, closure)\n", + " File \"/home/ubuntu/anaconda3/envs/rwkv-infctx/lib/python3.11/site-packages/lightning/pytorch/loops/optimization/automatic.py\", line 265, in _optimizer_step\n", + " 
call._call_lightning_module_hook(\n", + " File \"/home/ubuntu/anaconda3/envs/rwkv-infctx/lib/python3.11/site-packages/lightning/pytorch/trainer/call.py\", line 157, in _call_lightning_module_hook\n", + " output = fn(*args, **kwargs)\n", + " ^^^^^^^^^^^^^^^^^^^\n", + " File \"/home/ubuntu/anaconda3/envs/rwkv-infctx/lib/python3.11/site-packages/lightning/pytorch/core/module.py\", line 1282, in optimizer_step\n", + " optimizer.step(closure=optimizer_closure)\n", + " File \"/home/ubuntu/anaconda3/envs/rwkv-infctx/lib/python3.11/site-packages/lightning/pytorch/core/optimizer.py\", line 151, in step\n", + " step_output = self._strategy.optimizer_step(self._optimizer, closure, **kwargs)\n", + " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", + " File \"/home/ubuntu/anaconda3/envs/rwkv-infctx/lib/python3.11/site-packages/lightning/pytorch/strategies/ddp.py\", line 263, in optimizer_step\n", + " optimizer_output = super().optimizer_step(optimizer, closure, model, **kwargs)\n", + " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", + " File \"/home/ubuntu/anaconda3/envs/rwkv-infctx/lib/python3.11/site-packages/lightning/pytorch/strategies/strategy.py\", line 230, in optimizer_step\n", + " return self.precision_plugin.optimizer_step(optimizer, model=model, closure=closure, **kwargs)\n", + " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", + " File \"/home/ubuntu/anaconda3/envs/rwkv-infctx/lib/python3.11/site-packages/lightning/pytorch/plugins/precision/deepspeed.py\", line 123, in optimizer_step\n", + " closure_result = closure()\n", + " ^^^^^^^^^\n", + " File \"/home/ubuntu/anaconda3/envs/rwkv-infctx/lib/python3.11/site-packages/lightning/pytorch/loops/optimization/automatic.py\", line 140, in __call__\n", + " self._result = self.closure(*args, **kwargs)\n", + " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", + " File \"/home/ubuntu/anaconda3/envs/rwkv-infctx/lib/python3.11/site-packages/torch/utils/_contextlib.py\", line 115, in decorate_context\n", + " return func(*args, **kwargs)\n", + " ^^^^^^^^^^^^^^^^^^^^^\n", + " File \"/home/ubuntu/anaconda3/envs/rwkv-infctx/lib/python3.11/site-packages/lightning/pytorch/loops/optimization/automatic.py\", line 126, in closure\n", + " step_output = self._step_fn()\n", + " ^^^^^^^^^^^^^^^\n", + " File \"/home/ubuntu/anaconda3/envs/rwkv-infctx/lib/python3.11/site-packages/lightning/pytorch/loops/optimization/automatic.py\", line 315, in _training_step\n", + " training_step_output = call._call_strategy_hook(trainer, \"training_step\", *kwargs.values())\n", + " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", + " File \"/home/ubuntu/anaconda3/envs/rwkv-infctx/lib/python3.11/site-packages/lightning/pytorch/trainer/call.py\", line 309, in _call_strategy_hook\n", + " output = fn(*args, **kwargs)\n", + " ^^^^^^^^^^^^^^^^^^^\n", + " File \"/home/ubuntu/anaconda3/envs/rwkv-infctx/lib/python3.11/site-packages/lightning/pytorch/strategies/strategy.py\", line 381, in training_step\n", + " return self._forward_redirection(self.model, self.lightning_module, \"training_step\", *args, **kwargs)\n", + " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", + " File \"/home/ubuntu/anaconda3/envs/rwkv-infctx/lib/python3.11/site-packages/lightning/pytorch/strategies/strategy.py\", line 628, in __call__\n", + " wrapper_output = wrapper_module(*args, **kwargs)\n", + " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", + " File 
\"/home/ubuntu/anaconda3/envs/rwkv-infctx/lib/python3.11/site-packages/torch/nn/modules/module.py\", line 1518, in _wrapped_call_impl\n", + " return self._call_impl(*args, **kwargs)\n", + " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", + " File \"/home/ubuntu/anaconda3/envs/rwkv-infctx/lib/python3.11/site-packages/torch/nn/modules/module.py\", line 1527, in _call_impl\n", + " return forward_call(*args, **kwargs)\n", + " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", + " File \"/home/ubuntu/anaconda3/envs/rwkv-infctx/lib/python3.11/site-packages/deepspeed/utils/nvtx.py\", line 15, in wrapped_fn\n", + " ret_val = func(*args, **kwargs)\n", + " ^^^^^^^^^^^^^^^^^^^^^\n", + " File \"/home/ubuntu/anaconda3/envs/rwkv-infctx/lib/python3.11/site-packages/deepspeed/runtime/engine.py\", line 1814, in forward\n", + " loss = self.module(*inputs, **kwargs)\n", + " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", + " File \"/home/ubuntu/anaconda3/envs/rwkv-infctx/lib/python3.11/site-packages/torch/nn/modules/module.py\", line 1518, in _wrapped_call_impl\n", + " return self._call_impl(*args, **kwargs)\n", + " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", + " File \"/home/ubuntu/anaconda3/envs/rwkv-infctx/lib/python3.11/site-packages/torch/nn/modules/module.py\", line 1527, in _call_impl\n", + " return forward_call(*args, **kwargs)\n", + " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", + " File \"/home/ubuntu/anaconda3/envs/rwkv-infctx/lib/python3.11/site-packages/lightning/pytorch/strategies/strategy.py\", line 621, in wrapped_forward\n", + " out = method(*_args, **_kwargs)\n", + " ^^^^^^^^^^^^^^^^^^^^^^^^^\n", + " File \"/home/ubuntu/rwkv-proj/RWKV-infctx-trainer/RWKV-v5/src/model.py\", line 1150, in training_step\n", + " total_loss = self.compute_loss(batch, batch_idx, True)\n", + " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", + " File \"/home/ubuntu/rwkv-proj/RWKV-infctx-trainer/RWKV-v5/src/model.py\", line 1010, in compute_loss\n", + " segment_loss, new_shift_states, new_wkv_states, steps = checkpointed_step(\n", + " ^^^^^^^^^^^^^^^^^^\n", + " File \"/home/ubuntu/rwkv-proj/RWKV-infctx-trainer/RWKV-v5/src/model.py\", line 867, in checkpointed_step\n", + " targets.view(-1),\n", + " ^^^^^^^^^^^^^^^^\n", + "RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) 
instead.\n" ] } ], "source": [ + "# Short training process - for quick testing / debugging\n", "!cd \"{TRAINER_DIR}\" && \\\n", " export WANDB_MODE=\"{WANDB_MODE}\" && \\\n", " python3 lightning_trainer.py fit \\\n", - " -c \"{NOTEBOOK_DIR}/config/enwiki_10k-world-4096.yaml\" \\\n", - " --trainer.logger.init_args.name=\"{WANDB_PREFIX} (train-ctx=32, data-ctx=4096, {DEEPSPEED_STRAT})\" \\\n", + " -c \"{NOTEBOOK_DIR}/config/enwiki_10k-world-4x1024.yaml\" \\\n", + " --trainer.logger.init_args.name=\"{WANDB_PREFIX} (train-ctx=1024, data-ctx=4096, {DEEPSPEED_STRAT})\" \\\n", + " --trainer.strategy=\"{DEEPSPEED_STRAT}\" \\\n", + " --trainer.devices=\"{GPU_DEVICES}\" \\\n", + " --trainer.fast_dev_run=2 \\\n", + " --model.load_model=\"../model/L6-D512-world-init.pth\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Empty out the checkpoint\n", + "!cd \"{PROJECT_DIR}\" && rm -rf \"./checkpoint/infctx-v5-unit-test-baseline-4x1024/\"" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[2023-11-15 06:57:49,830] [INFO] [real_accelerator.py:158:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n", + "[RWKV.model] Running RWKV model using 'torch-jit' with torch '2.1.0'\n", + "/home/ubuntu/anaconda3/envs/rwkv-infctx/lib/python3.11/site-packages/lightning/pytorch/cli.py:518: LightningCLI's args parameter is intended to run from within Python like if it were from the command line. To prevent mistakes it is not recommended to provide both args and command line arguments, got: sys.argv[1:]=['fit', '-c', '/home/ubuntu/rwkv-proj/RWKV-infctx-trainer/notebook/trainer-v5-unit-test/config/enwiki_10k-world-4x1024.yaml', '--trainer.logger.init_args.name=infctx-v5-unit-test (train-ctx=1024, data-ctx=4096[][-=], deepspeed_stage_1)', '--trainer.strategy=deepspeed_stage_1', '--trainer.devices=auto', '--model.load_model=../model/L6-D512-world-init.pth'], args=['fit', '-c', '/home/ubuntu/rwkv-proj/RWKV-infctx-trainer/notebook/trainer-v5-unit-test/config/enwiki_10k-world-4x1024.yaml', '--trainer.logger.init_args.name=infctx-v5-unit-test (train-ctx=1024, data-ctx=4096[][-=], deepspeed_stage_1)', '--trainer.strategy=deepspeed_stage_1', '--trainer.devices=auto', '--model.load_model=../model/L6-D512-world-init.pth'].\n", + "Seed set to 3941088705\n", + "/home/ubuntu/anaconda3/envs/rwkv-infctx/lib/python3.11/site-packages/torch/_utils.py:831: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly. 
To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()\n", + " return self.fget.__get__(instance, owner)()\n", + "====================================================================\n", + "[WARNING]: bptt_truncated_learning is set as true (was configured as false), due to incomplete implementation of CUDA kernel for bptt_learning\n", + "====================================================================\n", + "---\n", + "[RWKV.TimeMix] Compiling CUDA kernel with HEAD_SIZE=64\n", + "Using /home/ubuntu/.cache/torch_extensions/py311_cu118 as PyTorch extensions root...\n", + "Detected CUDA files, patching ldflags\n", + "Emitting ninja build file /home/ubuntu/.cache/torch_extensions/py311_cu118/wkv5/build.ninja...\n", + "Building extension module wkv5...\n", + "Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)\n", + "ninja: no work to do.\n", + "Loading extension module wkv5...\n", + "[RWKV.TimeMix] CUDA kernel compiled & loaded globally\n", + "---\n", + "GPU available: True (cuda), used: True\n", + "TPU available: False, using: 0 TPU cores\n", + "IPU available: False, using: 0 IPUs\n", + "HPU available: False, using: 0 HPUs\n", + "\n", + "\n", + "[RWKV.Trainer] Applying 'target_batch_size' with the following:\n", + " - target_batch_size: 16\n", + " - num_nodes: 1\n", + " - num_devices: 1\n", + " - accumulate_grad_batches: 16\n", + " - effective_batch_size: 16\n", + "\n", + "/home/ubuntu/anaconda3/envs/rwkv-infctx/lib/python3.11/site-packages/datasets/table.py:1421: FutureWarning: promote has been superseded by mode='default'.\n", + " table = cls._concat_blocks(blocks, axis=0)\n", + "Saving the dataset (1/1 shards): 100%|█| 751/751 [00:00<00:00, 11376.86 examples\n", + "Saving the dataset (1/1 shards): 100%|███| 8/8 [00:00<00:00, 2405.51 examples/s]\n", + "[rank: 0] Seed set to 3941088705\n", + "initializing deepspeed distributed: GLOBAL_RANK: 0, MEMBER: 1/1\n", + "Enabling DeepSpeed BF16. Model parameters and inputs will be cast to `bfloat16`.\n", + "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n", + "#\n", + "# RWKV lighting_trainer.py important notes \n", + "# https://github.com/RWKV/RWKV-infctx-trainer \n", + "#\n", + "# - Ensure your host is not running cuda 12.0 (use either 11.8, or >=12.1), as this is known to have freeze issues\n", + "# - The terms used in wandb / the progress bar can be confusing, see the github README.md for beter clarifications\n", + "# - When resuming from checkpoint, the estimated time is inaccurate\n", + "#\n", + "\n", + "[RWKV.model] Configuring optimizer with\n", + " - lr_init: 8.000e-04 (0.0008)\n", + " - lr_final: 4.000e-04 (0.0004)\n", + "\n", + "Using /home/ubuntu/.cache/torch_extensions/py311_cu118 as PyTorch extensions root...\n", + "Detected CUDA files, patching ldflags\n", + "Emitting ninja build file /home/ubuntu/.cache/torch_extensions/py311_cu118/fused_adam/build.ninja...\n", + "Building extension module fused_adam...\n", + "Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)\n", + "ninja: no work to do.\n", + "Loading extension module fused_adam...\n", + "Time to load fused_adam op: 0.07390880584716797 seconds\n", + "/home/ubuntu/anaconda3/envs/rwkv-infctx/lib/python3.11/site-packages/deepspeed/ops/adam/fused_adam.py:96: UserWarning: The torch.cuda.*DtypeTensor constructors are no longer recommended. 
It's best to use methods such as torch.tensor(data, dtype=*, device='cuda') to create tensors. (Triggered internally at /opt/conda/conda-bld/pytorch_1695392026823/work/torch/csrc/tensor/python_tensor.cpp:83.)\n", + " self._dummy_overflow_buf = get_accelerator().IntTensor([0])\n", + "Loading `train_dataloader` to estimate number of stepping batches.\n", + "\n", + " | Name | Type | Params\n", + "--------------------------------------\n", + "0 | emb | Embedding | 33.6 M\n", + "1 | blocks | ModuleList | 20.5 M\n", + "2 | ln_out | LayerNorm | 1.0 K \n", + "3 | head | Linear | 33.6 M\n", + "--------------------------------------\n", + "87.6 M Trainable params\n", + "0 Non-trainable params\n", + "87.6 M Total params\n", + "350.405 Total estimated model params size (MB)\n", + "Epoch 0: 21%|▍ | 160/751 [00:35<02:10, 4.53it/s, v_num=yn39, train/loss=8.250]/home/ubuntu/anaconda3/envs/rwkv-infctx/lib/python3.11/site-packages/torch/nn/modules/module.py:1879: UserWarning: Positional args are being deprecated, use kwargs instead. Refer to https://pytorch.org/docs/master/generated/torch.nn.Module.html#torch.nn.Module.state_dict for details.\n", + " warnings.warn(\n", + "Epoch 0: 100%|██| 751/751 [02:56<00:00, 4.25it/s, v_num=yn39, train/loss=6.810]\n", + "Validation: | | 0/? [00:00
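To make the new microbatching path concrete, here is a small standalone sketch (stock PyTorch only; the helper name and record values are made up, and the real `dataloader_collator_fn` in `src/data.py` additionally tracks per-field lengths) of how records of different lengths are padded into one microbatch, with token 0 and a 0 attention mask as padding and `data_ctx_len` preserving the true lengths for the tokens/second logging:

import torch

def pad_collate(records):
    # Simplified version of dataloader_collator_fn: pad every field to the
    # longest record in the microbatch, using 0 for tokens and attention mask
    batch_size = len(records)
    max_len = max(len(r["input_ids"]) for r in records)
    batch = {
        "input_ids":      torch.zeros(batch_size, max_len, dtype=torch.long),
        "token_type_ids": torch.zeros(batch_size, max_len, dtype=torch.long),
        "attention_mask": torch.zeros(batch_size, max_len, dtype=torch.long),
        "data_ctx_len":   torch.zeros(batch_size, dtype=torch.int32),
    }
    for i, r in enumerate(records):
        n = len(r["input_ids"])
        batch["input_ids"][i, :n] = r["input_ids"]
        batch["token_type_ids"][i, :n] = r["token_type_ids"]
        batch["attention_mask"][i, :n] = r["attention_mask"]
        batch["data_ctx_len"][i] = n
    return batch

records = [
    {"input_ids": torch.tensor([5, 6, 7]), "token_type_ids": torch.zeros(3, dtype=torch.long), "attention_mask": torch.ones(3, dtype=torch.long)},
    {"input_ids": torch.tensor([8, 9]),    "token_type_ids": torch.zeros(2, dtype=torch.long), "attention_mask": torch.ones(2, dtype=torch.long)},
]
out = pad_collate(records)
print(out["input_ids"])       # tensor([[5, 6, 7], [8, 9, 0]])
print(out["attention_mask"])  # tensor([[1, 1, 1], [1, 1, 0]])
print(out["data_ctx_len"])    # tensor([3, 2], dtype=torch.int32)

Because the padded positions carry a 0 attention mask, they are excluded from the masked loss in `compute_loss`, so mixing sample lengths inside a microbatch only costs the compute spent on the padded tokens.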