diff --git a/RWKV-v5/src/data.py b/RWKV-v5/src/data.py index 368e918a..b391a2df 100644 --- a/RWKV-v5/src/data.py +++ b/RWKV-v5/src/data.py @@ -46,6 +46,19 @@ def prepare_data_static(**kargs): # ===================================================== + # Util functions + #-------------------------------- + + # Apply the data_prefix_skip_mask to the given mask + # where relevent, and disables the training mask for the first X tokens + data_prefix_skip_mask_val = int(kargs["data_prefix_skip_mask"]) + def apply_data_prefix_skip_mask(mask): + mask_len = len(mask) + if data_prefix_skip_mask_val > 0 and mask_len: + for i in range(max(data_prefix_skip_mask_val, mask_len)): + mask[i] = 0 + return mask + # Special handling for binidx #-------------------------------- @@ -66,7 +79,7 @@ def gen(): yield { 'input_ids': tokens, 'token_type_ids': [0] * len(tokens), - 'attention_mask': [1] * len(tokens) + 'attention_mask': apply_data_prefix_skip_mask([1] * len(tokens)) } # Load the huggingface dataset from the generator @@ -375,7 +388,7 @@ def map_tokenizer(x): return { 'input_ids': input_ids, 'token_type_ids': token_type_ids, - 'attention_mask': attention_mask + 'attention_mask': apply_data_prefix_skip_mask(attention_mask) } # Multi column merging support @@ -443,7 +456,7 @@ def map_tokenizer(x): return { 'input_ids': input_ids, 'token_type_ids': token_type_ids, - 'attention_mask': attention_mask + 'attention_mask': apply_data_prefix_skip_mask(attention_mask) } # Prompt completion support @@ -472,12 +485,17 @@ def map_tokenizer(x): return { 'input_ids': input_ids, 'token_type_ids': token_type_ids, - 'attention_mask': attention_mask, + 'attention_mask': apply_data_prefix_skip_mask(attention_mask), } # Fallback to standard text tokenization if 'text' in x: - return encodeTokens(x['text']) + ret = encodeTokens(x['text']) + return { + 'input_ids': ret['input_ids'], + 'token_type_ids': ret['token_type_ids'], + 'attention_mask': apply_data_prefix_skip_mask(ret['attention_mask']), + } raise ValueError('Invalid dataset format, must contain either the configured "multi column" or prompt/completion or text') @@ -519,7 +537,7 @@ def rechunk_text(x): # with the newline token in between full_input_ids += x["input_ids"][i] + endOfDoc_tokenSet["input_ids"][0] full_token_type_ids += x["token_type_ids"][i] + endOfDoc_tokenSet["token_type_ids"][0] - full_attention_mask += x["attention_mask"][i] + endOfDoc_tokenSet["attention_mask"][0] + full_attention_mask += apply_data_prefix_skip_mask( x["attention_mask"][i] ) + endOfDoc_tokenSet["attention_mask"][0] # Total length, and sample count # note that thte "remainder" will be discarded @@ -540,7 +558,7 @@ def rechunk_text(x): # Push the sample to the output arrays out_input_ids.append(full_input_ids[start:end]) out_token_type_ids.append(full_token_type_ids[start:end]) - out_attention_mask.append(full_attention_mask[start:end]) + out_attention_mask.append(apply_data_prefix_skip_mask( full_attention_mask[start:end] )) # Prepare and return the output object ret = { @@ -565,6 +583,8 @@ def dataset_filter(x): return False if kargs["max_token_size"] > 0 and row_length > kargs["max_token_size"]: return False + if sum(x["attention_mask"]) <= 0: + return False return True src_dataset = src_dataset.filter(dataset_filter, num_proc=num_cpus) @@ -902,6 +922,18 @@ def __init__( # prompt/completion format masking support disable_prompt_completion_mask: bool = False, + # ---------------------------- + # Selective loss training + # ---------------------------- + + # Prefix token masking + # + # 
The rationale behind this, is that the first X tokens should not be "backpropped" + # for any new training record. As its unfair to expect the model (or a human) make + # any resonable guesses at that stage. As such this is used to "mask" the first X tokens + # from the loss calculation, and thus not backpropped. + data_prefix_skip_mask: int = 0, + # ---------------------------- # dataset packing support # ---------------------------- @@ -1022,4 +1054,4 @@ def val_dataloader(self): batch_size=1, # Pinned in GPU memory pin_memory=True - ) + ) \ No newline at end of file diff --git a/RWKV-v5/src/model.py b/RWKV-v5/src/model.py index 86cb7025..912a5684 100644 --- a/RWKV-v5/src/model.py +++ b/RWKV-v5/src/model.py @@ -126,7 +126,7 @@ def forward(self, x, last_state: BlockState): class L2Wrap(torch.autograd.Function): @staticmethod - def forward(ctx, loss, y, token_amount, currentMask): + def forward(ctx, loss, y, factor, currentMask): # Currently (8th July 2023), save_for_backward, causes an issue with # pytorch.compile (see: https://github.com/pytorch/pytorch/blob/e600505e3209eaf539e8bc99870ea55236cefbf5/torch/_dynamo/variables/higher_order_ops.py#L735) # @@ -135,15 +135,13 @@ def forward(ctx, loss, y, token_amount, currentMask): # # See also: # - checkpointed_step - ctx.save_for_backward(y, token_amount, currentMask) + ctx.save_for_backward(y, factor, currentMask) return loss @staticmethod def backward(ctx, grad_output): - y, token_amount, currentMask = ctx.saved_tensors + y, factor, currentMask = ctx.saved_tensors - # to encourage the logits to be close to 0 - factor = 1e-4 / token_amount maxx, ids = torch.max(y, -1, keepdim=True) gy = torch.zeros_like(y) gy.scatter_(-1, ids, maxx * factor) @@ -193,9 +191,15 @@ def __init__(self, adam_eps: float = 1.0e-08, weight_decay: float = 0.01, warmup_steps: int = -1, + # loss bias start position_loss_bias: float = 1.0, position_loss_bias_in_validation: bool = False, + + # Selective loss settings + token_loss_threshold: float = 0.0, + token_dropout_rate: float = 0.0, # Dropout rate should be between 0-1 + # Backprop settings grad_cp: bool = True, bptt_learning: bool = True, @@ -289,9 +293,11 @@ def __init__(self, print("====================================================================") self.bptt_truncated_learning = True - # Save the position loss params + # Save the position loss params, and selective loss settings self.position_loss_bias = position_loss_bias self.position_loss_bias_in_validation = position_loss_bias_in_validation + self.token_loss_threshold = token_loss_threshold + self.token_dropout_rate = token_dropout_rate dim_att = dim_att or n_embd dim_ffn = dim_ffn or int((n_embd * 3.5) // 32 * 32) @@ -803,34 +809,37 @@ def compute_loss(self, batch, batch_idx, is_training_run: bool): # should not be allowed num_devices = self.trainer.num_devices - ### --- - ### Positional loss bias handling - ### --- + # ### --- + # ### Positional loss bias handling + # ### --- - # Get the starting and ending loss bias - loss_bias_start = self.position_loss_bias - loss_bias_end = 2.0 - loss_bias_start + # # Get the starting and ending loss bias + # loss_bias_start = self.position_loss_bias + # loss_bias_end = 2.0 - loss_bias_start - # Skip loss bias calculation, if loss_bias_start is 1.0 - if loss_bias_start == 1.0 or (is_training_run == False and self.position_loss_bias_in_validation == False): - seq_mask = ori_seq_mask - else: - # Lets get the torch mask sum - total_mask_sum = torch.sum(ori_seq_mask) + # # Skip loss bias calculation, if loss_bias_start is 
1.0 + # if loss_bias_start == 1.0 or (is_training_run == False and self.position_loss_bias_in_validation == False): + # seq_mask = ori_seq_mask + # else: + # # Lets get the torch mask sum + # total_mask_sum = torch.sum(ori_seq_mask) - # Lets get a linear multiplier for the loss bias - # seq_mask_sum = torch.sum(ori_seq_mask) - bias_mask = torch.linspace(loss_bias_start, loss_bias_end, int(total_mask_sum.item()), device=ori_seq_mask.device) + # # Lets get a linear multiplier for the loss bias + # # seq_mask_sum = torch.sum(ori_seq_mask) + # bias_mask = torch.linspace(loss_bias_start, loss_bias_end, int(total_mask_sum.item()), device=ori_seq_mask.device) - # Boolean flag of seq_mask > 0 - seq_mask_index = ori_seq_mask[0] > 0 + # # Boolean flag of seq_mask > 0 + # seq_mask_index = ori_seq_mask[0] > 0 - # Apply the bias mask only to positive seq_mask values - final_mask = torch.zeros(ori_seq_mask.shape[1], device=ori_seq_mask.device) - final_mask[seq_mask_index] = ori_seq_mask[0][seq_mask_index] * bias_mask + # # Apply the bias mask only to positive seq_mask values + # final_mask = torch.zeros(ori_seq_mask.shape[1], device=ori_seq_mask.device) + # final_mask[seq_mask_index] = ori_seq_mask[0][seq_mask_index] * bias_mask - # And save it as seq_mask - seq_mask = final_mask.unsqueeze(0) + # # And save it as seq_mask + # seq_mask = final_mask.unsqueeze(0) + + # Since we are no longer doing positional loss above, use seq_mask directly + seq_mask = ori_seq_mask ### --- ### Training cutoff logic handling @@ -884,8 +893,8 @@ def compute_loss(self, batch, batch_idx, is_training_run: bool): return 0 # Checkpoint steps - def checkpointed_step(idx, targets, mask, prev_loss, last_shift_states, - last_wkv_states, prev_steps): + def checkpointed_step(idx, targets, mask, last_shift_states, + last_wkv_states): logits, new_shift_states, new_wkv_states = self( idx, last_shift_states, last_wkv_states) @@ -895,26 +904,80 @@ def checkpointed_step(idx, targets, mask, prev_loss, last_shift_states, targets = targets.contiguous() mask = mask.contiguous() - loss = F.cross_entropy(logits.view(-1, logits.size(-1)), + # Compute the token loss + token_loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), reduction="none") + submask = mask.view(-1)[:token_loss.shape[0]] + + # to encourage the logits to be close to 0 + # factor_divisor is typically the total token count + L2Wrap_factor = 1e-4 / total_mask_sum - submask = mask.view(-1)[:loss.shape[0]] - submask_sum = torch.sum(submask) - loss = torch.sum(loss * submask) / total_mask_sum + # Submask count + submask_count = torch.sum(submask) + + # Selective token loss logic + if submask_count <= 0.0: + train_loss = torch.tensor(0, dtype=self.emb.weight.dtype).requires_grad_() + sample_loss = train_loss.clone().detach().requires_grad_(False) + train_token_count = 0 + train_mask = submask + + elif self.token_loss_threshold > 0.0 or self.token_dropout_rate > 0.0: + + # Sample loss, without backprop + with torch.no_grad(): + sample_loss = (torch.sum(token_loss * submask) / total_mask_sum).clone().detach().requires_grad_(False) + + # Building the training mask + train_mask = submask + + # Selective loss gating + if self.token_loss_threshold > 0.0: + above_threshold = token_loss > self.token_loss_threshold + train_mask = train_mask * above_threshold + + # Dropout logic + if self.token_dropout_rate > 0.0: + dropout_mask = torch.rand(train_mask.shape, device=train_mask.device) > self.token_dropout_rate + train_mask = train_mask * dropout_mask + + # The training loss 
to use + train_loss = torch.sum(token_loss * train_mask) / total_mask_sum + train_token_count = torch.sum(train_mask) - loss = L2Wrap.apply(loss, logits, total_mask_sum, submask) - new_steps = prev_steps + submask_sum - new_loss = prev_loss + loss - return new_loss, new_shift_states, new_wkv_states, new_steps + # Adjust the factor accordingly + L2Wrap_factor = L2Wrap_factor * (submask_count / train_token_count) - total_loss = torch.tensor(0, dtype=self.emb.weight.dtype).requires_grad_() - steps = 0 + else: + train_loss = torch.sum(token_loss * submask) / total_mask_sum + sample_loss = train_loss.clone().detach().requires_grad_(False) + train_token_count = submask_count + train_mask = submask + + if train_loss <= 0.0: + segment_train_loss = torch.tensor(0, dtype=self.emb.weight.dtype).requires_grad_() + else: + # L2Wrap for the backprop process + segment_train_loss = L2Wrap.apply(train_loss, logits, L2Wrap_factor, train_mask) + + # Return the checkpoint values + return sample_loss, segment_train_loss, new_shift_states, new_wkv_states, train_token_count + + # Initialize the states, and compute the segment count states = BlockStateList.create(self.n_layer, B, C, self.n_head, self.head_size, seq.device, self.emb.weight.dtype) segment_count = math.ceil(T / self.ctx_len) + # Initialize the training loss, and the token count + training_loss = torch.tensor(0, dtype=self.emb.weight.dtype).requires_grad_() + training_tokens = 0 + + # Raw sample loss (before selective token training) + sampling_loss = 0 + ### --- ### Learning process logic (BPTT or not) ### --- @@ -1052,14 +1115,12 @@ def checkpointed_step(idx, targets, mask, prev_loss, last_shift_states, cur_msk = dummy_2d_zero # Segmented learning, applies the forward/pass over each chunk seperately - segment_loss, new_shift_states, new_wkv_states, steps = checkpointed_step( + segment_sample_loss, segment_train_loss, new_shift_states, new_wkv_states, segment_train_tokens = checkpointed_step( cur_idx, cur_tar, cur_msk, - torch.tensor(0, dtype=self.emb.weight.dtype, device=cur_device).requires_grad_(True), prv_shift_states, - prv_wkv_states, - steps, + prv_wkv_states ) states = BlockStateList(new_shift_states, new_wkv_states) @@ -1067,93 +1128,65 @@ def checkpointed_step(idx, targets, mask, prev_loss, last_shift_states, # segment_loss_arr[i] = segment_loss # Perform the backward pass accordingly, for valid segments (besides the last segment) - # In this version, we do backward passes together the forward passes in the main segment loop + # In this version, we do backward passes together with the forward passes in the main segment loop # Instead of after all segment losses are computed + # + # In the past, we have implemented to do all forward, and all backwards. 
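
Stepping back from the segment loop for a moment, here is a minimal, self-contained sketch of the selective token loss gating introduced above (loss-threshold gating plus random token dropout). The gating logic mirrors the new `checkpointed_step`, but the helper name, the toy shapes, and the standalone packaging are illustrative assumptions rather than the trainer's API.

```python
import torch
import torch.nn.functional as F

def selective_token_loss(logits, targets, mask,
                         token_loss_threshold=0.0, token_dropout_rate=0.0):
    # Per-token cross entropy, no reduction (as in checkpointed_step)
    token_loss = F.cross_entropy(logits.view(-1, logits.size(-1)),
                                 targets.view(-1), reduction="none")
    submask = mask.view(-1)[:token_loss.shape[0]]
    total_mask_sum = torch.sum(submask)

    train_mask = submask
    if token_loss_threshold > 0.0:
        # Only backprop tokens whose current loss is above the threshold
        train_mask = train_mask * (token_loss > token_loss_threshold)
    if token_dropout_rate > 0.0:
        # Randomly drop a fraction of the remaining training tokens
        keep = torch.rand(train_mask.shape, device=train_mask.device) > token_dropout_rate
        train_mask = train_mask * keep

    # Both losses are normalized by the full mask sum, so gating removes
    # gradient signal without rescaling the tokens that are kept
    sample_loss = (torch.sum(token_loss * submask) / total_mask_sum).detach()
    train_loss = torch.sum(token_loss * train_mask) / total_mask_sum
    return sample_loss, train_loss, torch.sum(train_mask)

# Toy usage, purely illustrative shapes: batch=1, T=8, vocab=16
logits = torch.randn(1, 8, 16, requires_grad=True)
targets = torch.randint(0, 16, (1, 8))
mask = torch.ones(1, 8)
sample_loss, train_loss, kept_tokens = selective_token_loss(
    logits, targets, mask, token_loss_threshold=1.0, token_dropout_rate=0.1)
train_loss.backward()
```

Normalizing the training loss by the full `total_mask_sum` (rather than the number of kept tokens) keeps its scale comparable to an ungated run, which appears to be why the diff also rescales the L2Wrap factor by `submask_count / train_token_count`.
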
But this was found to be "slow" if i >= start_learning_segment and i < start_learning_segment + backward_segment_count: # The learning loss, should be normalized against the accumulation steps # as we are bypassing the pytorch lightning normalization # https://lightning.ai/docs/pytorch/2.0.4/common/lightning_module.html#backward - learning_loss = segment_loss / gradient_accumulation_steps + learning_loss = segment_train_loss / gradient_accumulation_steps - # Perform the backward pass accordingly, for valid segments (besides the last segment) - if i == start_learning_segment + backward_segment_count - 1: - # This is the last backward pass, we let the default pytorch lightning handle the backward pass - # and return the segment loss as part of the total loss - total_loss = total_loss + segment_loss - else: - # Undocumented multiple backward pass support - # https://github.com/Lightning-AI/lightning/blob/678f642808c54e4c490caee4df5d357301c976bb/tests/trainer/optimization/test_manual_optimization.py#L251 - self.manual_backward(learning_loss, optimizer, retain_graph=True) - - # Accumulate without gradient, as we already did the backward pass - total_loss = total_loss + segment_loss.clone().detach().requires_grad_(False) + # Undocumented multiple backward pass support + # https://github.com/Lightning-AI/lightning/blob/678f642808c54e4c490caee4df5d357301c976bb/tests/trainer/optimization/test_manual_optimization.py#L251 + self.manual_backward(learning_loss, optimizer, retain_graph=True) + + # Accumulate without gradient, as we already did the backward pass + # This does mean, that a single backward pass is "wasted" at the end + training_loss = training_loss + segment_train_loss.clone().detach().requires_grad_(False) else: # Even if its not the segments we use for backward pass, we still need to accumulate the loss - total_loss = total_loss + segment_loss.clone().detach().requires_grad_(False) + training_loss = training_loss + segment_train_loss.clone().detach().requires_grad_(False) + # Add token count and raw sampling loss + training_tokens = training_tokens + segment_train_tokens + sampling_loss = sampling_loss + segment_sample_loss + # GC collect unused memory # gc.collect() # torch.cuda.empty_cache() - - # # Lets backpass the respective segments, in reverse - # # (including dummy backpass) - # for i in range(forward_segment_count-1, -1, -1): - # # Get the segment loss - # segment_loss = segment_loss_arr[i] - # - # # Compute the backward pass for the segment - # if i >= start_learning_segment and i < start_learning_segment + backward_segment_count: - # # The learning loss, should be normalized against the accumulation steps - # # as we are bypassing the pytorch lightning normalization - # # https://lightning.ai/docs/pytorch/2.0.4/common/lightning_module.html#backward - # learning_loss = segment_loss / gradient_accumulation_steps - # - # # Perform the backward pass accordingly, for valid segments (besides the start_learning_segment) - # if i > start_learning_segment: - # # Undocumented multiple backward pass support - # # https://github.com/Lightning-AI/lightning/blob/678f642808c54e4c490caee4df5d357301c976bb/tests/trainer/optimization/test_manual_optimization.py#L251 - # self.manual_backward(learning_loss, optimizer, retain_graph=True) - # - # # Accumulate without gradient, as we already did the backward pass - # total_loss = total_loss + segment_loss.clone().detach().requires_grad_(False) - # else: - # # This is the last backward pass, we let the default pytorch lightning handle the backward 
pass - # # and return the segment loss as part of the total loss - # total_loss = total_loss + segment_loss - # else: - # # Even if its not the segments we use for backward pass, we still need to accumulate the loss - # total_loss = total_loss + segment_loss.clone().detach().requires_grad_(False) - # - # # GC collect unused memory - # gc.collect() - # # torch.cuda.empty_cache() else: + # # Normal operations without BPTT + # segment_size = self.ctx_len for i in range(segment_count): if i < segment_count-1 and is_training_run: - total_loss, new_shift_states, new_wkv_states, steps = deepspeed_checkpoint( + segment_sample_loss, segment_train_loss, new_shift_states, new_wkv_states, segment_train_tokens = deepspeed_checkpoint( checkpointed_step, idx[:, i * segment_size:(i + 1) * segment_size], targets[:, i * segment_size:(i + 1) * segment_size], seq_mask[:, i * segment_size:(i + 1) * segment_size], - total_loss, states.shift_states, - states.wkv_states, - steps, + states.wkv_states ) else: - total_loss, new_shift_states, new_wkv_states, steps = checkpointed_step( + segment_sample_loss, segment_train_loss, new_shift_states, new_wkv_states, segment_train_tokens = checkpointed_step( idx[:, i * segment_size:(i + 1) * segment_size], targets[:, i * segment_size:(i + 1) * segment_size], seq_mask[:, i * segment_size:(i + 1) * segment_size], - total_loss, states.shift_states, - states.wkv_states, - steps, + states.wkv_states ) + + # Add them up + training_loss = training_loss + segment_train_loss + training_tokens = training_tokens + segment_train_tokens + sampling_loss = sampling_loss + segment_sample_loss + # Update the states states = BlockStateList(new_shift_states, new_wkv_states) gc.collect() # torch.cuda.empty_cache() @@ -1162,24 +1195,34 @@ def checkpointed_step(idx, targets, mask, prev_loss, last_shift_states, if wandb.run is not None and is_training_run: global_rank = self.global_rank global_device_count = self.trainer.num_devices * self.trainer.num_nodes + microbatch_size = self.trainer.microbatch_size # Get the total dataset context length batch_ctx_len = 0 if "data_ctx_len" in batch: batch_ctx_len = torch.sum(batch["data_ctx_len"]).item() else: - batch_ctx_len = T * self.trainer.microbatch_size + batch_ctx_len = T * microbatch_size # Increment the counting tokens, and log it accordingly self._counting_tokens += batch_ctx_len # Log the line values wandb.log({ - 'global_rank': global_rank, - 'data_ctx_len': batch_ctx_len / self.trainer.microbatch_size, - 'train/loss': total_loss, + # The original loss and ctx_len (averaged by batch size) + 'train/ctx_len': batch_ctx_len / microbatch_size, + 'train/data_loss': sampling_loss, + + # The selective training tokens, and loss + 'train/tokens': training_tokens / microbatch_size, + 'train/loss': training_loss, + + # Perf tracking f'perf/tokens_total.gpu.{global_rank}': self._counting_tokens, f'perf/tokens_per_sec.gpu.{global_rank}': self._counting_tokens / max(time.time() - self._counting_time_start, 1), + + # Step and trainer tracking + 'global_rank': global_rank, 'substep': (batch_idx * global_device_count + global_rank), 'trainer/global_step':self.global_step, 'trainer/learning_rate': self.trainer.optimizers[0].param_groups[0]['lr'], @@ -1187,8 +1230,8 @@ def checkpointed_step(idx, targets, mask, prev_loss, last_shift_states, }) # Throw if total loss is NaN - assert not torch.isnan(total_loss), "total_loss is NaN" - return total_loss + assert not torch.isnan(training_loss), "training_loss is NaN" + return training_loss # # Training and validation 
steps diff --git a/notebook/trainer-v5-validation/config/enwiki_10k-world-full.yaml b/notebook/trainer-v5-validation/config/enwiki_10k-world-full.yaml new file mode 100644 index 00000000..85b60f6a --- /dev/null +++ b/notebook/trainer-v5-validation/config/enwiki_10k-world-full.yaml @@ -0,0 +1,265 @@ +# lightning.pytorch==2.0.2 +seed_everything: 3941088705 +trainer: + + # + # Configure the deepspeed strategy, we recommend you start with `deepspeed_stage_2_offload` + # and adjust from there according to your training needs. `deepspeed_stage_3_offload` is useful + # for training LoRA on large models on a single GPU. + # + # In general you would want to use the following: + # + # - deepspeed_stage_1 : Each of your GPU has too much vram, and you do not know what to do + # + # - deepspeed_stage_2 : Optimal distributed training strategy, across multiple gpu each with sufficient vram + # - deepspeed_stage_2_offload : Reduce vram usage by offloading the optimizer state and work to cpu + # + # - deepspeed_stage_3 : Split up the model across multiple gpu, useful for large models, at a performance cost + # - deepspeed_stage_3_offload : Additional offloading, for even greater performance cost + # + # For more details see: + # https://lightning.ai/docs/pytorch/stable/advanced/model_parallel.html#deepspeed-zero-stage-2 + # + strategy: deepspeed_stage_2_offload + + # Logger setting for wandb, if you want to enable wandb, uncomment the whole logger section + # --- + logger: + class_path: lightning.pytorch.loggers.WandbLogger + init_args: + name: 'infctx-v5-unit-test-baseline (train-ctx=4096, data-ctx=full)' + project: 'RWKV-infctx-unit-test' + tags: ['RWKV', 'infctx'] + + # Checkpoint settings for the training process + callbacks: + class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + # Configure this to the path you want to save your checkpoints to + # note that a subdir will be created with the name `epoch=x-step=y.ckpt` + # + # to convert a checkpoint to a model, you can use the + # `python3 export_checkpoint.py ` script, + # which will create a `rwkv_model.pth` in the checkpoint directory. + # + # Do not use the `zero_to_fp32.py` script as that will have export format issues + dirpath: ../checkpoint/trainer-validaiton/infctx-v5-enwiki-10k-full + filename: null + + # Save the top/last K checkpoints + save_top_k: 3 + # Choose by the most recent checkpoints (step based) + monitor: 'step' + mode: max + + # If enabled (true), save a copy of the latest checkpoint to 'last.ckpt' + # useful to simply checkpoint resume scripts, at a price of disk performance + save_last: false + + # DO NOT set this as true, as the model weight exported will have format issues + # expert as checkpoint, and use the `export_checkpoint.py` script to convert to model instead + save_weights_only: false + + # How frequent you want to save a checkpoint for every step. 
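
The cadence comment that follows defines a checkpoint's coverage as `every_n_train_steps * accumulate_grad_batches` data samples. As a rough, hedged cross-check of the numbers in this config (the exact rounding the trainer applies is an assumption, but the result matches the `accumulate_grad_batches: 16` reported in the microbatch=1 notebook run further below):

```python
import math

# Values from this config and the single-GPU, microbatch-1 run below
target_batch_size = 16
num_nodes, num_devices, microbatch_size = 1, 1, 1
every_n_train_steps = 100

# target_batch_size is divided across nodes/devices/microbatches to get
# the gradient accumulation count for one optimizer step
accumulate_grad_batches = math.ceil(
    target_batch_size / (num_nodes * num_devices * microbatch_size))
print(accumulate_grad_batches)                        # 16

# So a checkpoint every 100 optimizer steps covers roughly this many samples
print(every_n_train_steps * accumulate_grad_batches)  # 1600
```
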
+ # This will happen for every X data sample, where X = every_n_train_steps * accumulate_grad_batches + # + # In general you will want to avoid putting a low number (expecially if accumulate_grad_batches <= 100) + # as the checkpoint process, will pause all the gpu training for some time, slowing down the overall process + # However you do not want to configure too high of a number, where you will lose too much progress if the training crashes + every_n_train_steps: 100 + every_n_epochs: null + save_on_train_epoch_end: true + train_time_interval: null + + # Other settings, you can probably leave alone + verbose: false + auto_insert_metric_name: true + + ######################################## + ## Training run parameter settings + ######################################## + + # Generally what you want to configure is the maximum number of epochs + # Leave it as -1, and it will keep going forever till interrupted + # Or set it as a number, and it will stop after that number of epochs + max_epochs: 1 + min_epochs: null + max_steps: -1 + min_steps: null + max_time: null + + # Number of datasamples to train for each step, a data sample is considered + # a "substep" in wandb logs, and a "step" is tracked as "trainer/global_step" + # + # This decides the number of datasample, to learn together from, before backproping + # any weight changes at the end of the batch. + # + # Recommended to be a big enough number (like 128/256) where it prevents the training + # loss from flucuating in the process. But not too big of a number where the increased + # GPU vRAM / offloaded RAM usage will cause the training to crash. + # + # You are also recommended to configure this to a large enough number to fully utilize + # your GPU processing time %, and avoid idle time for the GPU between batches + # + # This number is divided by the number of GPUs, and nodes configured + # So if you have 4 GPUs, and 2 nodes, and this is configured as 128 + # Each GPU will process 128/4/2 = 16 datasamples per step, via accumulate_grad_batches + target_batch_size: 16 + +######################################## +## Training model settings +######################################## +model: + # Model to start the finetune/training process from + load_model: ../model/L24-D2048-world-v5base-init.pth + + # Context length to use for the training process + # the larger the number (and batch size) the larger the vram usage + # + # Note that if the datasample context length is larger then the ctx_len + # its training process would be split into ctx_len sized chunks. + # + # This allows the training of extreamly large context length (eg. 100k), + # without eating up too much vram by keeping the training context length + # to a resonable number sutible to the current GPU setup + ctx_len: 4096 + + # Data samples would be cut down to the respective max ctx_len_cutoffs + # values if its larger then ctx_len. If the data sample is larger then + # the largest len_cutoff, the remaining data will be discarded + ctx_len_cutoffs: [] + # Experimental settings, number of tokens to skip in the data sample + # prefix, for the respective cutoff length. 
Used to speed up the process + ctx_len_warmup_steps: [] + + # Learning rate of the training process + # --- + + # Initia learning rate of the process + lr_init: 8e-4 + # Final learning rate after the learning rate period + # learning rate will stay at final value from then onwards + lr_final: 4e-4 + + # Number of epoch to reduce the learning rate from lr_init to lr_final + # 1 means a single epoch (so lr would be lr_final from epoch 2 onwards) + # 0 means lr_final will apply immediately + # -1 means we take the current max_step / max_epoch as the period + lr_period: 1 + # lr_period type if its set, defaults to epoch + lr_period_type: epoch + + # Adam optimizer settings + # You probably want to leave this alone, unless you know what you are doing + beta1: 0.9 + beta2: 0.99 + adam_eps: 1.0e-08 + weight_decay: 0.01 + + # torch.set_float32_matmul_precision, used to optimize operations with tensor cores + # this should be set as null, for non cuda core GPUs + torch_set_float32_matmul_precision: 'high' + # torch_set_float32_matmul_precision: null + + # Segmented based learning, used to work around training of large context length + # beyond what can be supported by the current GPU vram architecture + # + # This is not 1:1 equivalent to the same training process with required vram + # as the training process is split into multiple segments, part by part. + # with limited learnings from the previous segment. + bptt_learning: true + + # Segmented range to performing backprop learning on + # 1 means to apply only for the last segment + # -1 means to apply for all segments + bptt_learning_range: -1 + +data: + # dataset_path for the prebuilt dataset, using HF `load_from_disk()` + # + # Use this if you have built your own dataset and saved it with `save_to_disk()` + # with source left as null. Other wise configure this to a directory which the + # dataset will be built and tokenized by the huggingface dataset process. + data_path: ../datapath/enwiki_10k-world-4096/ + + # Other wise provide the source path, which is used as huggingface dataset path + # this will be used to populate the dataset_path + # + # Use either the following + # - hugging face dataset + # - Directory path to a directory containing dataset files + # - Path to a single dataset file + # - hugging face dataset mode (ie: text,csv,etc - use data_dir, to configure the path then) + # - null + # + # If source is disabled, all other params, except data_path, is ignored + source: "teven/enwiki_10k" + # source: text + # source: /home/ubuntu/RWKV-LM-LoRA/dataset-text/enwik8.txt + + # Use data_dir, if you are using source=text/json/etc + # this should be relative to the trainer script path + source_data_dir: null + + # After loading the dataset, split out test data used for unit-test, + # This process is skipped if the dataset includes a test split + # This process is skipped if set to zero + test_split: 0.01 + test_split_shuffle: false + + # Tokenizer to use, use either the inbuilt 'neox', or 'world' tokenizer + # If using a custom tokenizer, provide the tokenizer file path + # --- + tokenizer: world + + # Minimum / Maximum token size of the dataset to use + # useful for filtering out small noisy data samples from large datasets + # (eg. 
removal of small articles of less then 512 tokens from wikipedia) + # + # This is ignored, if set to -1 + min_token_size: -1 + max_token_size: -1 + + # Rechunking of text dataset, this is done only when source is set as 'text' + # and will merge the various sentencees, into larger chunks up to the target size + # + # Defaults to 4096 + # + # This is ignored, if source is not set as text + # This is ignored, if set to zero + # --- + # text_rechunk_size: 4096 + + # Apply text rechunk to the dataset, even if its not a 'text' source + # This is done only after dataset filtering, and if source is not 'text' + # --- + text_rechunk_force: false + + # Custom text column to use, useful for dataset with alternative training columns labels + # This is checked before multi column merging, default is null (disabled) + # eg: 'code' + # --- + # custom_text_key: 'code' + + # Multi Column merging process, default setting is used to support and merge + # "instruction", "input", "output", datasets. To disable set multi_column_keys to [] + # + # A minimum of 2 columns is required, with non empty data, for the merge to occur + # If no match is found, this will fallback to the default prompt/completion or text column, + # or throw an error if the default fallback is not found + # --- + # multi_column_keys: ['instruction', 'input', 'output'] + # multi_column_prefix: ['Instruction:\n', 'Input:\n', 'Output:\n'] + # multi_column_train_mask: [true, false, true] + # multi_column_separator: '\n\n' + + # If processing prompt/completion jsonl pairs, the prompt is masked by default + # use this flag to disable this default behaviour + # --- + # disable_prompt_completion_mask: false + +# Path to the current checkpoint to continue training from +# Enable this to the last checkpoint after the first run +# (if it crash and you want to resume) +# ckpt_path: ../checkpoint/trainer-validaiton/infctx-unit-test-baseline/epoch=0-step=20.ckpt +ckpt_path: null diff --git a/notebook/trainer-v5-validation/dataset-microbatch.ipynb b/notebook/trainer-v5-validation/dataset-microbatch.ipynb new file mode 100644 index 00000000..cf45e52c --- /dev/null +++ b/notebook/trainer-v5-validation/dataset-microbatch.ipynb @@ -0,0 +1,775 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Dataset microbatch testing\n", + "\n", + "Testing runs on multiple micro batch settings" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "ENABLE_WANDB: False\n", + "GPU_DEVICES: auto\n", + "NOTEBOOK_DIR: /home/picocreator/rwkv-proj/RWKV-infctx-trainer/notebook/trainer-v5-validation\n", + "TRAINER_DIR: /home/picocreator/rwkv-proj/RWKV-infctx-trainer/RWKV-v5\n", + "PROJECT_DIR: /home/picocreator/rwkv-proj/RWKV-infctx-trainer\n" + ] + } + ], + "source": [ + "GPU_DEVICES=\"auto\"\n", + "ENABLE_WANDB=False\n", + "WANDB_PREFIX=\"infctx-v5-microbatch\"\n", + "\n", + "print(\"ENABLE_WANDB:\", ENABLE_WANDB)\n", + "print(\"GPU_DEVICES:\", GPU_DEVICES)\n", + "\n", + "if ENABLE_WANDB:\n", + " WANDB_MODE=\"online\"\n", + "else:\n", + " WANDB_MODE=\"disabled\"\n", + "\n", + "# Computing the notebook, and various paths\n", + "import os\n", + "NOTEBOOK_DIR=os.path.dirname(os.path.abspath(\"__file__\"))\n", + "PROJECT_DIR=os.path.abspath(os.path.join(NOTEBOOK_DIR, \"../../\"))\n", + "TRAINER_DIR=os.path.abspath(os.path.join(PROJECT_DIR, \"./RWKV-v5/\"))\n", + "\n", + "print(\"NOTEBOOK_DIR:\", NOTEBOOK_DIR)\n", + "print(\"TRAINER_DIR:\", 
TRAINER_DIR)\n", + "print(\"PROJECT_DIR:\", PROJECT_DIR)" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[2024-01-18 11:19:39,010] [INFO] [real_accelerator.py:158:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n", + "[RWKV.model] Running RWKV model using 'torch-jit' with torch '2.1.1'\n", + "---- Initializing model ----\n", + "No of layers: 6\n", + "Embedding size: 512\n", + "Output model path: ../model/L6-D512-world-v5base-init.pth\n", + "Vocab size: 65536\n", + "Emb scale: 0.0001\n", + "Note: this process takes a significant time (and ram) for large models\n", + "---- ----- ----\n", + "Model exists, skipping init_model\n" + ] + } + ], + "source": [ + "# Init the model\n", + "!cd \"{TRAINER_DIR}\" && \\\n", + " python3 ./init_model.py \\\n", + " --n_layer 6 --n_embd 512 \\\n", + " --vocab_size world --skip-if-exists \\\n", + " \"../model/L6-D512-world-v5base-init.pth\"" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Map (num_proc=16): 100%|█████████| 10000/10000 [00:01<00:00, 9575.14 examples/s]\n", + "Filter (num_proc=16): 100%|█████| 10000/10000 [00:00<00:00, 12203.75 examples/s]\n", + "Map (num_proc=16): 100%|██████████| 9892/9892 [00:00<00:00, 20646.21 examples/s]\n", + "Saving the dataset (1/1 shards): 100%|█| 9892/9892 [00:00<00:00, 241357.37 examp\n", + "Saving the dataset (1/1 shards): 100%|█| 100/100 [00:00<00:00, 28064.93 examples\n" + ] + } + ], + "source": [ + "# Lets preload the requried dataset \n", + "!cd \"{TRAINER_DIR}\" && \\\n", + " python3 preload_datapath.py \"{NOTEBOOK_DIR}/config/enwiki_10k-world-full.yaml\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# microbatch=1\n", + "\n", + "Note: We are intentionally testing without rechunk, as that has known edge case issues." + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[2024-01-18 12:00:55,830] [INFO] [real_accelerator.py:158:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n", + "[RWKV.model] Running RWKV model using 'torch-jit' with torch '2.1.1'\n", + "/home/picocreator/anaconda3/envs/rwkv-infctx/lib/python3.11/site-packages/lightning/pytorch/cli.py:518: LightningCLI's args parameter is intended to run from within Python like if it were from the command line. 
To prevent mistakes it is not recommended to provide both args and command line arguments, got: sys.argv[1:]=['fit', '-c', '/home/picocreator/rwkv-proj/RWKV-infctx-trainer/notebook/trainer-v5-validation/config/enwiki_10k-world-full.yaml', '--model.load_model=../model/L6-D512-world-v5base-init.pth', '--trainer.callbacks.init_args.dirpath=../checkpoint/v5-enwiki-10k-full/', '--trainer.logger.init_args.name=infctx-v5-microbatch - Microbatch 1 - (deepspeed_stage_1)', '--trainer.strategy=deepspeed_stage_1', '--trainer.microbatch_size=1', '--trainer.devices=auto'], args=['fit', '-c', '/home/picocreator/rwkv-proj/RWKV-infctx-trainer/notebook/trainer-v5-validation/config/enwiki_10k-world-full.yaml', '--model.load_model=../model/L6-D512-world-v5base-init.pth', '--trainer.callbacks.init_args.dirpath=../checkpoint/v5-enwiki-10k-full/', '--trainer.logger.init_args.name=infctx-v5-microbatch - Microbatch 1 - (deepspeed_stage_1)', '--trainer.strategy=deepspeed_stage_1', '--trainer.microbatch_size=1', '--trainer.devices=auto'].\n", + "Seed set to 3941088705\n", + "/home/picocreator/anaconda3/envs/rwkv-infctx/lib/python3.11/site-packages/torch/_utils.py:831: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly. To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()\n", + " return self.fget.__get__(instance, owner)()\n", + "---\n", + "[RWKV.TimeMix] Compiling CUDA kernel with HEAD_SIZE=64\n", + "Using /home/picocreator/.cache/torch_extensions/py311_cu121 as PyTorch extensions root...\n", + "Detected CUDA files, patching ldflags\n", + "Emitting ninja build file /home/picocreator/.cache/torch_extensions/py311_cu121/wkv5/build.ninja...\n", + "Building extension module wkv5...\n", + "Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)\n", + "ninja: no work to do.\n", + "Loading extension module wkv5...\n", + "[RWKV.TimeMix] CUDA kernel compiled & loaded globally\n", + "---\n", + "GPU available: True (cuda), used: True\n", + "TPU available: False, using: 0 TPU cores\n", + "IPU available: False, using: 0 IPUs\n", + "HPU available: False, using: 0 HPUs\n", + "\n", + "\n", + "[RWKV.Trainer] Applying 'target_batch_size' with the following:\n", + " - target_batch_size: 16\n", + " - num_nodes: 1\n", + " - num_devices: 1\n", + " - microbatch_size: 1\n", + " - accumulate_grad_batches: 16\n", + " - effective_batch_size: 16\n", + "\n", + "Saving the dataset (1/1 shards): 100%|█| 9892/9892 [00:00<00:00, 595479.80 examp\n", + "Saving the dataset (1/1 shards): 100%|█| 100/100 [00:00<00:00, 28472.64 examples\n", + "[rank: 0] Seed set to 3941088705\n", + "initializing deepspeed distributed: GLOBAL_RANK: 0, MEMBER: 1/1\n", + "Enabling DeepSpeed BF16. 
Model parameters and inputs will be cast to `bfloat16`.\n", + "/home/picocreator/anaconda3/envs/rwkv-infctx/lib/python3.11/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:639: Checkpoint directory ../checkpoint/v5-enwiki-10k-full/ exists and is not empty.\n", + "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n", + "#\n", + "# RWKV lighting_trainer.py important notes \n", + "# https://github.com/RWKV/RWKV-infctx-trainer \n", + "#\n", + "# - Ensure your host is not running cuda 12.0 (use either 11.8, or >=12.1), as this is known to have freeze issues\n", + "# - The terms used in wandb / the progress bar can be confusing, see the github README.md for beter clarifications\n", + "# - When resuming from checkpoint, the estimated time is inaccurate\n", + "#\n", + "\n", + "[RWKV.model] Configuring optimizer with\n", + " - lr_init: 8.000e-04 (0.0008)\n", + " - lr_final: 4.000e-04 (0.0004)\n", + "\n", + "Using /home/picocreator/.cache/torch_extensions/py311_cu121 as PyTorch extensions root...\n", + "Detected CUDA files, patching ldflags\n", + "Emitting ninja build file /home/picocreator/.cache/torch_extensions/py311_cu121/fused_adam/build.ninja...\n", + "Building extension module fused_adam...\n", + "Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)\n", + "ninja: no work to do.\n", + "Loading extension module fused_adam...\n", + "Time to load fused_adam op: 0.05255270004272461 seconds\n", + "/home/picocreator/anaconda3/envs/rwkv-infctx/lib/python3.11/site-packages/deepspeed/ops/adam/fused_adam.py:96: UserWarning: The torch.cuda.*DtypeTensor constructors are no longer recommended. It's best to use methods such as torch.tensor(data, dtype=*, device='cuda') to create tensors. (Triggered internally at /opt/conda/conda-bld/pytorch_1699449201450/work/torch/csrc/tensor/python_tensor.cpp:83.)\n", + " self._dummy_overflow_buf = get_accelerator().IntTensor([0])\n", + "Loading `train_dataloader` to estimate number of stepping batches.\n", + "\n", + " | Name | Type | Params\n", + "--------------------------------------\n", + "0 | emb | Embedding | 33.6 M\n", + "1 | blocks | ModuleList | 20.5 M\n", + "2 | ln_out | LayerNorm | 1.0 K \n", + "3 | head | Linear | 33.6 M\n", + "--------------------------------------\n", + "87.6 M Trainable params\n", + "0 Non-trainable params\n", + "87.6 M Total params\n", + "350.405 Total estimated model params size (MB)\n", + "Epoch 0: 16%|▏| 1600/9892 [00:55<04:49, 28.62it/s, v_num=mu7h, train/loss=5.310/home/picocreator/anaconda3/envs/rwkv-infctx/lib/python3.11/site-packages/torch/nn/modules/module.py:1879: UserWarning: Positional args are being deprecated, use kwargs instead. Refer to https://pytorch.org/docs/master/generated/torch.nn.Module.html#torch.nn.Module.state_dict for details.\n", + " warnings.warn(\n", + "Epoch 0: 100%|█| 9892/9892 [05:45<00:00, 28.62it/s, v_num=mu7h, train/loss=4.090\n", + "Validation: | | 0/? 
[00:00=12.1), as this is known to have freeze issues\n", + "# - The terms used in wandb / the progress bar can be confusing, see the github README.md for beter clarifications\n", + "# - When resuming from checkpoint, the estimated time is inaccurate\n", + "#\n", + "\n", + "[RWKV.model] Configuring optimizer with\n", + " - lr_init: 8.000e-04 (0.0008)\n", + " - lr_final: 4.000e-04 (0.0004)\n", + "\n", + "Using /home/picocreator/.cache/torch_extensions/py311_cu121 as PyTorch extensions root...\n", + "Detected CUDA files, patching ldflags\n", + "Emitting ninja build file /home/picocreator/.cache/torch_extensions/py311_cu121/fused_adam/build.ninja...\n", + "Building extension module fused_adam...\n", + "Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)\n", + "ninja: no work to do.\n", + "Loading extension module fused_adam...\n", + "Time to load fused_adam op: 0.05180692672729492 seconds\n", + "/home/picocreator/anaconda3/envs/rwkv-infctx/lib/python3.11/site-packages/deepspeed/ops/adam/fused_adam.py:96: UserWarning: The torch.cuda.*DtypeTensor constructors are no longer recommended. It's best to use methods such as torch.tensor(data, dtype=*, device='cuda') to create tensors. (Triggered internally at /opt/conda/conda-bld/pytorch_1699449201450/work/torch/csrc/tensor/python_tensor.cpp:83.)\n", + " self._dummy_overflow_buf = get_accelerator().IntTensor([0])\n", + "Loading `train_dataloader` to estimate number of stepping batches.\n", + "\n", + " | Name | Type | Params\n", + "--------------------------------------\n", + "0 | emb | Embedding | 33.6 M\n", + "1 | blocks | ModuleList | 20.5 M\n", + "2 | ln_out | LayerNorm | 1.0 K \n", + "3 | head | Linear | 33.6 M\n", + "--------------------------------------\n", + "87.6 M Trainable params\n", + "0 Non-trainable params\n", + "87.6 M Total params\n", + "350.405 Total estimated model params size (MB)\n", + "Epoch 0: 16%|▏| 800/4946 [00:35<03:05, 22.41it/s, v_num=3o87, train/loss=5.060]/home/picocreator/anaconda3/envs/rwkv-infctx/lib/python3.11/site-packages/torch/nn/modules/module.py:1879: UserWarning: Positional args are being deprecated, use kwargs instead. Refer to https://pytorch.org/docs/master/generated/torch.nn.Module.html#torch.nn.Module.state_dict for details.\n", + " warnings.warn(\n", + "Epoch 0: 100%|█| 4946/4946 [03:42<00:00, 22.19it/s, v_num=3o87, train/loss=5.720\n", + "Validation: | | 0/? [00:00=12.1), as this is known to have freeze issues\n", + "# - The terms used in wandb / the progress bar can be confusing, see the github README.md for beter clarifications\n", + "# - When resuming from checkpoint, the estimated time is inaccurate\n", + "#\n", + "\n", + "[RWKV.model] Configuring optimizer with\n", + " - lr_init: 8.000e-04 (0.0008)\n", + " - lr_final: 4.000e-04 (0.0004)\n", + "\n", + "Using /home/picocreator/.cache/torch_extensions/py311_cu121 as PyTorch extensions root...\n", + "Detected CUDA files, patching ldflags\n", + "Emitting ninja build file /home/picocreator/.cache/torch_extensions/py311_cu121/fused_adam/build.ninja...\n", + "Building extension module fused_adam...\n", + "Allowing ninja to set a default number of workers... 
(overridable by setting the environment variable MAX_JOBS=N)\n", + "ninja: no work to do.\n", + "Loading extension module fused_adam...\n", + "Time to load fused_adam op: 0.05039358139038086 seconds\n", + "/home/picocreator/anaconda3/envs/rwkv-infctx/lib/python3.11/site-packages/deepspeed/ops/adam/fused_adam.py:96: UserWarning: The torch.cuda.*DtypeTensor constructors are no longer recommended. It's best to use methods such as torch.tensor(data, dtype=*, device='cuda') to create tensors. (Triggered internally at /opt/conda/conda-bld/pytorch_1699449201450/work/torch/csrc/tensor/python_tensor.cpp:83.)\n", + " self._dummy_overflow_buf = get_accelerator().IntTensor([0])\n", + "Loading `train_dataloader` to estimate number of stepping batches.\n", + "\n", + " | Name | Type | Params\n", + "--------------------------------------\n", + "0 | emb | Embedding | 33.6 M\n", + "1 | blocks | ModuleList | 20.5 M\n", + "2 | ln_out | LayerNorm | 1.0 K \n", + "3 | head | Linear | 33.6 M\n", + "--------------------------------------\n", + "87.6 M Trainable params\n", + "0 Non-trainable params\n", + "87.6 M Total params\n", + "350.405 Total estimated model params size (MB)\n", + "Epoch 0: 16%|▏| 400/2473 [00:30<02:37, 13.12it/s, v_num=jp9a, train/loss=6.780]/home/picocreator/anaconda3/envs/rwkv-infctx/lib/python3.11/site-packages/torch/nn/modules/module.py:1879: UserWarning: Positional args are being deprecated, use kwargs instead. Refer to https://pytorch.org/docs/master/generated/torch.nn.Module.html#torch.nn.Module.state_dict for details.\n", + " warnings.warn(\n", + "Epoch 0: 100%|█| 2473/2473 [03:04<00:00, 13.41it/s, v_num=jp9a, train/loss=6.660\n", + "Validation: | | 0/? [00:00
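
Finally, a minimal standalone sketch of the `data_prefix_skip_mask` behaviour added in `data.py` at the top of this diff. It is written against the stated intent of masking only the first X tokens, so the loop bound here is `min(skip, len(mask))`; the patch as posted uses `max(...)`, so read this as an interpretation of the intent rather than a copy of the shipped helper.

```python
# Sketch of the intended prefix-skip behaviour: zero out the training mask
# for the first `skip` tokens of a sample, without touching the rest.
def apply_data_prefix_skip_mask(mask, skip):
    mask = list(mask)
    if skip > 0 and len(mask) > 0:
        # Bound by the mask length so short samples do not raise IndexError
        for i in range(min(skip, len(mask))):
            mask[i] = 0
    return mask

print(apply_data_prefix_skip_mask([1] * 8, 3))  # [0, 0, 0, 1, 1, 1, 1, 1]
print(apply_data_prefix_skip_mask([1] * 2, 3))  # [0, 0]
```

Samples whose attention mask ends up all-zero are then discarded by the new `sum(x["attention_mask"]) <= 0` check in `dataset_filter`.
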