experimental token based dropout
PicoCreator committed Jan 18, 2024
1 parent 86808e9 commit ce4a461
Showing 1 changed file with 16 additions and 5 deletions.
RWKV-v5/src/model.py: 16 additions & 5 deletions
```diff
@@ -197,7 +197,8 @@ def __init__(self,
             position_loss_bias_in_validation: bool = False,
 
             # Selective loss settings
-            selective_token_loss_threshold: float = 0.0,
+            token_loss_threshold: float = 0.0,
+            token_dropout_rate: float = 0.0, # Dropout rate should be between 0-1
 
             # Backprop settings
             grad_cp: bool = True,
@@ -295,7 +296,8 @@ def __init__(self,
         # Save the position loss params, and selective loss settings
         self.position_loss_bias = position_loss_bias
         self.position_loss_bias_in_validation = position_loss_bias_in_validation
-        self.selective_token_loss_threshold = selective_token_loss_threshold
+        self.token_loss_threshold = token_loss_threshold
+        self.token_dropout_rate = token_dropout_rate
 
         dim_att = dim_att or n_embd
         dim_ffn = dim_ffn or int((n_embd * 3.5) // 32 * 32)
@@ -922,15 +924,24 @@ def checkpointed_step(idx, targets, mask, last_shift_states,
                 train_token_count = 0
                 train_mask = submask
 
-            elif self.selective_token_loss_threshold > 0.0:
+            elif self.token_loss_threshold > 0.0 or self.token_dropout_rate > 0.0:
 
                 # Sample loss, without backprop
                 with torch.no_grad():
                     sample_loss = (torch.sum(token_loss * submask) / total_mask_sum).clone().detach().requires_grad_(False)
 
+                # Building the training mask
+                train_mask = submask
+
                 # Selective loss gating
-                above_threshold = token_loss > self.selective_token_loss_threshold
-                train_mask = submask * above_threshold
+                if self.token_loss_threshold > 0.0:
+                    above_threshold = token_loss > self.token_loss_threshold
+                    train_mask = train_mask * above_threshold
+
+                # Dropout logic
+                if self.token_dropout_rate > 0.0:
+                    dropout_mask = torch.rand(train_mask.shape, device=train_mask.device) > self.token_dropout_rate
+                    train_mask = train_mask * dropout_mask
 
                 # The training loss to use
                 train_loss = torch.sum(token_loss * train_mask) / total_mask_sum
```
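For context, here is a minimal standalone sketch of the masking logic this commit introduces, written against plain PyTorch. The tensor shapes, the random `token_loss` stand-in, and the example hyperparameter values are illustrative assumptions for demonstration; only the gating and dropout steps mirror the diff above.

```python
import torch

# Stand-ins for values the real checkpointed_step already has on hand:
# per-token loss (batch, seq_len) and a mask of valid (non-padding) tokens.
token_loss = torch.rand(2, 8)   # illustrative stand-in for per-token cross-entropy
submask = torch.ones(2, 8)      # illustrative stand-in for the validity mask
token_loss_threshold = 0.5      # example value; train only on tokens above this loss
token_dropout_rate = 0.25       # example value; randomly drop this fraction of tokens

# Building the training mask
train_mask = submask

# Selective loss gating: keep only tokens whose loss exceeds the threshold
if token_loss_threshold > 0.0:
    train_mask = train_mask * (token_loss > token_loss_threshold)

# Token dropout: each surviving token is kept with probability (1 - rate)
if token_dropout_rate > 0.0:
    dropout_mask = torch.rand(train_mask.shape, device=train_mask.device) > token_dropout_rate
    train_mask = train_mask * dropout_mask

# Normalize by the full mask sum, as in the diff, not by the surviving-token count
total_mask_sum = submask.sum()
train_loss = torch.sum(token_loss * train_mask) / total_mask_sum
print(train_loss)
```

Note the design choice visible in the diff: `train_loss` is still divided by the full `total_mask_sum` rather than by the number of surviving tokens, so thresholding and dropout shrink each batch's gradient contribution instead of renormalizing it over the kept tokens.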
