Merge pull request #75 from RWKV/rwkv-x-eagle-notebooks
Rwkv x eagle notebooks
PicoCreator authored Feb 6, 2024
2 parents d606d4a + e02907a commit 6b04127
Showing 53 changed files with 8,710 additions and 766 deletions.
5 changes: 5 additions & 0 deletions RWKV-v5/config-example.yaml
@@ -373,6 +373,10 @@ data:
# source: "teven/enwiki_00k" # Hugging face dataset
# source: text # Text mode, used with source_data_dir

# Dataset split to use from HF dataset
# ---
# source_dataset_split: train

# Additional source dataset params, used to grab subsets of the dataset
# ---
# source_dataset_params:
@@ -419,6 +423,7 @@ data:

# Custom text column to use, useful for datasets with alternative training column labels
# This is checked before multi-column merging; default is null (disabled)
# If set, this takes priority
# eg: 'code'
# ---
# custom_text_key: 'code'
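The two options added above (source_dataset_split and custom_text_key) can be read as the config-level equivalent of picking a Hugging Face dataset split and a text column by hand. The sketch below only illustrates that mapping; the helper name load_text_column and the default values are assumptions, not code from RWKV-v5/src/data.py.

# Illustrative sketch only, not the trainer's loader code.
from datasets import load_dataset

def load_text_column(source, source_dataset_split="train", custom_text_key=None):
    # source_dataset_split selects which split of the HF dataset is pulled
    ds = load_dataset(source, split=source_dataset_split)
    # custom_text_key, if set, takes priority over the default text column lookup
    text_key = custom_text_key if custom_text_key is not None else "text"
    return ds[text_key]

# e.g. the SuperWIKI entry later in this PR would roughly map to:
#   load_text_column("RyokoExtra/SuperWIKI-Cleaned", "lang25", "text")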
11 changes: 8 additions & 3 deletions RWKV-v5/datapack-example.yaml
@@ -35,15 +35,20 @@ datapack:

# Mixing mode to use; this is used to alternate between datasets
#
# - batch : Meaning one dataset's worth per batch, partial batches are discarded
# - sample : Dataset is mixed on a per-sample level
mixing_mode: "batch"
# - concat : Keep It Simple Silly, let's just concat the datasets together
# - shuffle : Dataset is mixed on a per-sample level
#
# (@TODO: Advanced operations)
# - batch : Meaning one dataset's worth per batch, partial batches are discarded
mixing_mode: "shuffle"

# # Mixing distribution to use
# # - weighted : Dataset batches/mixture is distributed randomly, but weighted by dataset size
# # - uniform : Dataset batches/mixture is distributed randomly, but with uniform probability
# distribution: "weighted"

# (@TODO: Advanced operations)
#
# Mixed batch percentage
#
# % of batches which will contain a mixture of records from multiple datasets
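To make the new mixing_mode values and the distribution option concrete, here is a minimal, hypothetical sketch using the Hugging Face datasets library. It only illustrates what "concat" vs "shuffle" and "weighted" vs "uniform" mean at the sample level; it is not the datapack builder's actual implementation.

# Hypothetical illustration of the mixing options, not the repo's datapack code.
from datasets import Dataset, concatenate_datasets, interleave_datasets

ds_a = Dataset.from_dict({"text": ["a1", "a2", "a3", "a4"]})
ds_b = Dataset.from_dict({"text": ["b1", "b2"]})

# mixing_mode: "concat" -- simply append one dataset after the other
concat_mix = concatenate_datasets([ds_a, ds_b])

# mixing_mode: "shuffle" -- same data, but mixed at the per-sample level
shuffle_mix = concatenate_datasets([ds_a, ds_b]).shuffle(seed=42)

# distribution: how samples are drawn across datasets when interleaving
sizes = [len(ds_a), len(ds_b)]
weighted = interleave_datasets([ds_a, ds_b], probabilities=[s / sum(sizes) for s in sizes], seed=42)
uniform = interleave_datasets([ds_a, ds_b], probabilities=[0.5, 0.5], seed=42)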
418 changes: 232 additions & 186 deletions RWKV-v5/src/data.py

Large diffs are not rendered by default.

83 changes: 43 additions & 40 deletions RWKV-v5/src/model.py
@@ -812,14 +812,14 @@ def compute_loss(self, batch, batch_idx, is_training_run: bool = False, is_valid
assert isinstance(seq, torch.Tensor) and seq.ndim == 2
ori_seq_mask = batch['attention_mask']

# Get the dataset index
dataset_index = 0
dataset_name = "dataset_0"
if "dataset_index" in batch:
dataset_index = batch["dataset_index"]
dataset_name = f"dataset_{dataset_index}"
if "dataset_name" in batch and dataset_name is not None:
dataset_name = batch["dataset_name"]
# # Get the dataset index
# dataset_index = 0
# dataset_name = "dataset_0"
# if "dataset_index" in batch:
# dataset_index = batch["dataset_index"]
# dataset_name = f"dataset_{dataset_index}"
# if "dataset_name" in batch and dataset_name is not None:
# dataset_name = batch["dataset_name"]

# Check if the attention mask is set; if not, initialize it
if ori_seq_mask is None or ori_seq_mask.ndim != 2:
@@ -913,21 +913,24 @@ def compute_loss(self, batch, batch_idx, is_training_run: bool = False, is_valid

# If total_mask_sum is 0, we skip, as there are no tokens of value to learn from anyway
total_mask_sum = torch.sum(seq_mask)
# Do a quick return, if there are no tokens of value to learn from due to full masking
if num_devices > 1 and total_mask_sum == 0:
return 0
avg_mask_sum = ( total_mask_sum / B )

# # Do a quick return, if there are no tokens of value to learn from due to full masking
# # DO NOT DO THIS : This causes multi-node / multi-GPU training to go out of sync
# if num_devices <= 1 and total_mask_sum == 0:
# return 0

# Checkpoint steps
def checkpointed_step(idx, targets, mask, last_shift_states,
last_wkv_states):
# Skip if there are no tokens of value to learn from
if idx.shape[1] == 0:
# Prepare dummy loss
train_loss = torch.tensor(0, dtype=self.emb.weight.dtype).requires_grad_()
sample_loss = train_loss.clone().detach().requires_grad_(False)
# # Skip if there are no tokens of value to learn from
# if idx.shape[1] == 0:
# # Prepare dummy loss
# train_loss = torch.tensor(0, dtype=self.emb.weight.dtype).requires_grad_()
# sample_loss = train_loss.clone().detach().requires_grad_(False)

# Return the checkpoint values
return sample_loss, train_loss, last_shift_states, last_wkv_states, 0
# # Return the checkpoint values
# return sample_loss, train_loss, last_shift_states, last_wkv_states, 0

# Get the logits, and the new states
logits, new_shift_states, new_wkv_states = self(
@@ -947,7 +950,7 @@ def checkpointed_step(idx, targets, mask, last_shift_states,

# to encourage the logits to be close to 0
# factor_divisor is typically the total token count
L2Wrap_factor = 1e-4 / total_mask_sum
L2Wrap_factor = 1e-4 / avg_mask_sum

# Submask count
submask_count = torch.sum(submask)
@@ -983,7 +986,7 @@ def checkpointed_step(idx, targets, mask, last_shift_states,
train_token_count = torch.sum(train_mask)

# Adjust the factor accordingly
L2Wrap_factor = L2Wrap_factor * (submask_count / train_token_count)
# L2Wrap_factor = L2Wrap_factor * (submask_count / train_token_count)

else:
train_loss = torch.sum(token_loss * submask) / total_mask_sum
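For clarity, the arithmetic behind the L2Wrap change above: the penalty factor is now divided by the per-sample average mask count (avg_mask_sum = total_mask_sum / B) rather than the whole batch's total, so its strength no longer shrinks as the batch grows. The snippet below is a toy illustration with made-up sizes, not the actual compute_loss() code.

# Toy illustration of the normalisation change, with assumed shapes.
import torch

B, T = 4, 8                           # batch size and sequence length (toy values)
seq_mask = torch.ones(B, T)           # 1 = token contributes to the loss, 0 = masked out

total_mask_sum = torch.sum(seq_mask)  # all learnable tokens in the batch (32 here)
avg_mask_sum = total_mask_sum / B     # learnable tokens per sample (8 here)

old_L2Wrap_factor = 1e-4 / total_mask_sum  # old: shrinks as the batch size grows
new_L2Wrap_factor = 1e-4 / avg_mask_sum    # new: independent of the batch size

# The quick "return 0" on fully masked batches is also commented out above, since
# every rank must run the same forward/backward collectives or multi-node / multi-GPU
# training goes out of sync.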
@@ -1254,20 +1257,20 @@ def checkpointed_step(idx, targets, mask, last_shift_states,
# Log the line values
wandb.log({
# The original loss and ctx_len (averaged by batch size)
'train/ctx_len': ctx_len,
'train/data_ctxlen': ctx_len,
'train/data_loss': sampling_loss,
# "train/dataset_index": dataset_index,

# The selective training tokens, and loss
'train/tokens': tokens,
'train/loss': training_loss,
"train/dataset_index": dataset_index,
'train/learn_tokens': tokens,
'train/learn_loss': training_loss,

# Dataset based tracking
f'dataset/train/{dataset_index}.loss': training_loss,
f'dataset/train/{dataset_index}.data_loss': sampling_loss,
f'dataset/train/{dataset_index}.tokens': tokens,
f'dataset/train/{dataset_index}.ctx_len': ctx_len,
f'dataset/train/{dataset_index}.name': dataset_name,
# # Dataset based tracking (not working)
# f'dataset/train/{dataset_index}.loss': training_loss,
# f'dataset/train/{dataset_index}.data_loss': sampling_loss,
# f'dataset/train/{dataset_index}.tokens': tokens,
# f'dataset/train/{dataset_index}.ctx_len': ctx_len,
# f'dataset/train/{dataset_index}.name': dataset_name,

# Perf tracking
f'perf/kTokens_per_sec.gpu.{global_rank}': self._counting_tokens / max(time.time() - self._counting_time_start, 1),
@@ -1286,19 +1289,19 @@ def checkpointed_step(idx, targets, mask, last_shift_states,
# Log the line values
wandb.log({
# The original loss and ctx_len (averaged by batch size)
'validation/ctx_len': T,
'validation/data_ctxlen': T,
'validation/data_loss': sampling_loss,
# "validation/dataset_index": dataset_index,

# The selective training tokens, and loss
'validation/tokens': training_tokens,
'validation/loss': training_loss,
"validation/dataset_index": dataset_index,

# Dataset based tracking
f'dataset/validation/{dataset_index}.loss': training_loss,
f'dataset/validation/{dataset_index}.data_loss': sampling_loss,
f'dataset/validation/{dataset_index}.ctx_len': T,
f'dataset/validation/{dataset_index}.name': dataset_name,
'validation/learn_tokens': training_tokens,
'validation/learn_loss': training_loss,

# # Dataset based tracking (not working)
# f'dataset/validation/{dataset_index}.loss': training_loss,
# f'dataset/validation/{dataset_index}.data_loss': sampling_loss,
# f'dataset/validation/{dataset_index}.ctx_len': T,
# f'dataset/validation/{dataset_index}.name': dataset_name,

# Step and trainer tracking
'global_rank': global_rank,
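The W&B key renames above split the metrics into two families: data_* describes the raw sampled batch, while learn_* covers only the tokens selected for training by the mask. Below is a hypothetical sketch of that naming convention (it assumes wandb.init() has been called elsewhere; this is not the module's actual logging code).

# Hypothetical sketch of the renamed W&B metric families.
import wandb

def log_train_step(ctx_len, data_loss, learn_tokens, learn_loss):
    wandb.log({
        "train/data_ctxlen": ctx_len,        # context length of the batch as sampled
        "train/data_loss": data_loss,        # loss over every unmasked token
        "train/learn_tokens": learn_tokens,  # tokens that actually contribute gradients
        "train/learn_loss": learn_loss,      # loss over those selected tokens only
    })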
6 changes: 3 additions & 3 deletions docker/github-worker-cuda-11-8/entrypoint.sh
@@ -27,7 +27,7 @@ else
--token "${RUNNER_TOKEN}" \
--name "${RUNNER_NAME}" \
--replace \
--labels "nolane,${CUDA_VER},${RUNNER_LABELS}"
--labels "nolane,any-gpu,gpu-count-any,${CUDA_VER},${RUNNER_LABELS}"

# Run it in background, and get the PID
./run.sh &
@@ -41,7 +41,7 @@ else
--token "${RUNNER_TOKEN}" \
--name "${RUNNER_NAME}-lane1" \
--replace \
--labels "lane1,${CUDA_VER},${RUNNER_LABELS}"
--labels "lane1,any-gpu,gpu-count-any,${CUDA_VER},${RUNNER_LABELS}"

# Run it in background, and get the PID
./run.sh &
@@ -55,7 +55,7 @@ else
--token "${RUNNER_TOKEN}" \
--name "${RUNNER_NAME}-lane2" \
--replace \
--labels "lane2,${CUDA_VER},${RUNNER_LABELS}"
--labels "lane2,any-gpu,gpu-count-any,${CUDA_VER},${RUNNER_LABELS}"

# Run it in background, and get the PID
./run.sh &
4 changes: 2 additions & 2 deletions notebook/finetune-example/Eagle-x-ALMA-prompt-completion.yaml
@@ -74,8 +74,8 @@ model:
load_model: ../model/L6-D512-neox-init.pth

# Starting and ending learning rate
lr_init: 1e-5
lr_final: 1e-5
lr_init: 3e-5
lr_final: 3e-5

# Training context length, note that the dataset can be
# larger than the context size, in which the trainer
4 changes: 2 additions & 2 deletions notebook/finetune-example/Eagle-x-capybara-chat.yaml
@@ -79,8 +79,8 @@ model:
load_model: ../model/L6-D512-neox-init.pth

# Starting and ending learning rate
lr_init: 1e-5
lr_final: 1e-5
lr_init: 3e-5
lr_final: 3e-5

# Training context length, note that the dataset can be
# larger than the context size, in which the trainer
4 changes: 2 additions & 2 deletions notebook/finetune-example/Eagle-x-openhermes1-instruct.yaml
@@ -74,8 +74,8 @@ model:
load_model: ../model/L6-D512-neox-init.pth

# Starting and ending learning rate
lr_init: 1e-5
lr_final: 1e-5
lr_init: 3e-5
lr_final: 3e-5

# Training context length, note that the dataset can be
# larger than the context size, in which the trainer
4 changes: 2 additions & 2 deletions notebook/finetune-example/Eagle-x-textbooks.yaml
@@ -79,8 +79,8 @@ model:
load_model: ../model/L6-D512-neox-init.pth

# Starting and ending learning rate
lr_init: 1e-5
lr_final: 1e-5
lr_init: 3e-5
lr_final: 3e-5

# Training context length, note that the dataset can be
# larger than the context size, in which the trainer
87 changes: 56 additions & 31 deletions notebook/finetune-example/Eagle-x-zMultipack-build.yaml
@@ -25,35 +25,14 @@ datapack:
# secret: <example S3 secret>
# endpoint_url: <example S3 endpoint>

# Batch size to use to alternate between datasets
# This should be a multiple of the GPU and node count
#
# Uses `8 * (3 * 4 * 5 * 6 * 7) = 20160` by default, as it should align across
# a large number of batch size combinations. This helps reduce the amount of
# misaligned batches, and thus reduce the amount of wasted training time
batchsize: 512

# Mixing mode to use; this is used to alternate between datasets
#
# - batch : Meaning one dataset's worth per batch, partial batches are discarded
# - sample : Dataset is mixed on a per-sample level
mixing_mode: "batch"

# # Mixing distribution to use
# # - weighted : Dataset batches/mixture is distributed randomly, but weighted by dataset size
# # - uniform : Dataset batches/mixture is distributed randomly, but with uniform probability
# distribution: "weighted"

# Mixed batch percentage
#
# % of batches which will contain a mixture of records from multiple datasets
# instead of limiting each batch to a single dataset
# - concat : Keep It Simple Silly, let's just concat the datasets together
# - shuffle : Dataset is mixed on a per-sample level
#
# Use 0 to disable mixed batches; sampled mixing_mode is the equivalent of mixed batch 1.0
#
# NOTE: This is a guideline percentage, and is not guaranteed to be exact
# if a partial batch is built, it may get converted to a mixed batch
mixed_batch_percentage: 0.5
# (@TODO: Advanced operations)
# - batch : Meaning one dataset's worth per batch, partial batches are discarded
mixing_mode: "shuffle"

#
# Default settings used across all datasets in the datapack
@@ -115,7 +94,7 @@ default:
# If given an int value, that number of data samples is used.
#
# Due to the limitations in the trainer process, there is always a minimum of 1 test sample
test_split: 8 # Intentionally set to a low sample count for test, because the real eval is humans
test_split: 0.01 # Intentionally set to a low sample count for test, because the real eval is humans
test_split_shuffle: true

# Tokenizer to use, use either the inbuilt 'neox', or 'world' tokenizer
@@ -265,10 +244,14 @@ dataset:

- # Textbooks are all you need
# https://huggingface.co/datasets/TanvirOnHF/muse_textbooks
source: "TanvirOnHF/muse_textbooks"
source: "teven/enwiki_100k"

# Optional, provide a name for the dataset
name: "muse_textbooks"
name: "enwiki_100k"

# Minimum / Maximum token size of the dataset to use
min_token_size: 1024
max_token_size: -1

# Various overwrite settings
# ---
@@ -277,6 +260,29 @@ dataset:
packing_enable: False
max_token_size: -1

- # SuperWiki (Multi-lingual)
# https://huggingface.co/datasets/RyokoExtra/SuperWIKI-Cleaned
source: "RyokoExtra/SuperWIKI-Cleaned"

# Optional, provide a name for the dataset
name: "super_wiki"

# Various overwrite settings
# ---
text_rechunk_size: 32768
text_rechunk_force: true
packing_enable: False
max_token_size: -1

source_dataset_split: lang25

# Custom text column to use, useful for datasets with alternative training column labels
# This is checked before multi-column merging; default is null (disabled)
# If set, this takes priority
# eg: 'code'
# ---
custom_text_key: 'text'

# All other settings found in default can be overridden here
# ---
# ...
@@ -297,8 +303,7 @@ dataset:
# https://huggingface.co/datasets/kristaller486/ALMA-prompt-completion
source: "kristaller486/ALMA-prompt-completion"
name: "ALMA-prompt-completion"

# Prompt completion, notiong else
# Prompt completion, nothing else

- # Instruct, input, output format
# https://huggingface.co/datasets/teknium/openhermes
@@ -380,6 +385,26 @@ dataset:
conversation_input_key_mask: {'input': false, 'output': true}
conversation_sender_suffix: {'input': "", 'output': ""}

######################################################
# Note: We found the ML-generated textbooks
# so low in perplexity that it hurt the model,
# so we are using the original enwiki_100k & superwiki
######################################################
# - # Textbooks are all you need
# # https://huggingface.co/datasets/TanvirOnHF/muse_textbooks
# source: "TanvirOnHF/muse_textbooks"

# # Optional, provide a name for the dataset
# name: "muse_textbooks"

# # Various overwrite settings
# # ---
# text_rechunk_size: 32768
# text_rechunk_force: true
# packing_enable: False
# max_token_size: -1
######################################################



