Commit cfdd2d8: ab-neg archs
Mickus Timothee committed Dec 18, 2023
1 parent 42e39a2
Showing 4 changed files with 130 additions and 30 deletions.
2 changes: 1 addition & 1 deletion mammoth/models/model.py
@@ -67,7 +67,7 @@ def forward(self, src, tgt, lengths, bptt=False, with_align=False, metadata=None

enc_state, memory_bank, lengths, mask = self.encoder(src, lengths)

memory_bank, alphas = self.attention_bridge(memory_bank, mask)
memory_bank, alphas = self.attention_bridge(memory_bank, mask, lengths)
if self.attention_bridge.is_fixed_length:
# turn off masking in the transformer decoder
lengths = None
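
Aside (not part of the commit): a minimal sketch of the shapes now flowing into the bridge, assuming a toy batch of two sentences and model_dim 8; the commented lines mirror the updated call above.

import torch

lengths = torch.tensor([5, 3])              # true source lengths, now passed to the bridge
memory_bank = torch.randn(5, 2, 8)          # encoder states, (src_len, batch, model_dim)
mask = torch.arange(5)[None, None, :] >= lengths[:, None, None]   # (batch, 1, src_len), True = padding
# with a bridge built via AttentionBridge.from_opts(opts):
#     memory_bank, alphas = attention_bridge(memory_bank, mask, lengths)
#     if attention_bridge.is_fixed_length:
#         lengths = None   # fixed-length output, so downstream masking can be dropped
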
143 changes: 115 additions & 28 deletions mammoth/modules/attention_bridge.py
@@ -4,10 +4,12 @@

import torch
import torch.nn as nn
import torch.nn.functional as F
from mammoth.rmsnorm_torch import RMSNorm
from mammoth.modules.transformer_encoder import TransformerEncoderLayer

from mammoth.modules.multi_headed_attn import MultiHeadedAttention
from mammoth.modules.embeddings import PositionalEncoding


class BaseAttentionBridgeLayer(nn.Module):
@@ -38,6 +40,65 @@ def __init__(self, normalized_shape, ab_layer_norm_type):
def forward(self, input):
return self.norm(input) # possibly an nn.Identity

@classmethod
def from_opts(cls, opts):
"""Alternate constructor."""
return cls(
opts.model_dim,
opts.ab_layer_norm,
)


class OptionalResidualConnection(nn.Module):
"""Maybe apply a residual connection"""
def __init__(self, ab_residual_connection_mode):
super().__init__()
self.ab_residual_connection_mode = ab_residual_connection_mode

def forward(self, input, output, mask, lengths):
# case 1. there's a residual connection to apply and it's standard
if (
input.size() == output.size()
and self.ab_residual_connection_mode in {'same_size', 'average_uneven', 'exp_uneven'}
):
return input + output
# case 2. there's no residual connection to apply
elif self.ab_residual_connection_mode in {'none', 'same_size'}:
return output

# we'll need some extra care to residually connect across different tensor shapes
# both cases 3. and 4. will differ in how they smear the input to match the output
B, S_i, H = input.shape
B_o, S_o, H_o = output.shape
assert B == B_o and H == H_o, 'tensor shapes differ by more than one dim'

# case 3. residual connection: sum the input sequence and spread it evenly across the output
if self.ab_residual_connection_mode in {'average_uneven', 'average_all'}:
if mask is not None:
input = input.masked_fill(mask.transpose(1, 2), 0.0)
smeared_ipt = input.sum(1, keepdim=True) / S_o
return output + smeared_ipt

# case 4. residual connection: soft-match input & output indices
elif self.ab_residual_connection_mode in {'exp_uneven', 'exp_all'}:
# lengths - 1 so that positions end up between 0 and 1
inputs_hardpos = torch.arange(S_i, device=input.device)[None, :, None].expand(B, S_i, 1)
lengths = (lengths - 1)[:, None, None].expand(B, S_i, 1)
inputs_relpos = (inputs_hardpos / lengths).expand(B, S_i, S_o)
outputs_hardpos = torch.arange(S_o, device=input.device)[None, None, :].expand(B, 1, S_o)
outputs_relpos = (outputs_hardpos / (S_o - 1)).expand(B, S_i, S_o)
raw_relmatch = 1 - (outputs_relpos - inputs_relpos).abs()
if mask is not None:
# zero-out invalid positions
raw_relmatch = raw_relmatch.masked_fill(mask.transpose(1, 2), -float('inf'))
# Softmax to ensure the sum of what we distribute equals the sum of the inputs
# Temperature to make it peaky
exp_relmatch = F.softmax(raw_relmatch / 0.1, dim=1)
smeared_ipt = torch.einsum('BSH,BST->BTH', input, exp_relmatch)
return output + smeared_ipt
else:
raise RuntimeError(f'unsupported ab_residual_connection_mode: {self.ab_residual_connection_mode}')


class PerceiverAttentionBridgeLayer(BaseAttentionBridgeLayer):
def __init__(
@@ -95,28 +156,27 @@ def forward(self, intermediate_output, encoder_output, mask=None):
mask: binary mask 1/0 indicating which keys have
zero/non-zero attention ``(batch, query_len, key_len)`` -> # [bsz, 1, len]
"""
S, B, F = encoder_output.shape
S, B, H = encoder_output.shape
if intermediate_output is not None:
cross_query = intermediate_output
else:
cross_query = self.latent_array.unsqueeze(0).expand(B, -1, -1)
encoder_output = encoder_output.transpose(0, 1)
encoder_output = self.cross_attention_norm(encoder_output.transpose(0, 1))

# sublayer 1: projects to fixed size
cross_attention_output, alphas = self.cross_attention_block(
encoder_output, encoder_output, cross_query, mask=mask, attn_type='context'
)
cross_attention_output = self.cross_attention_norm(cross_attention_output + cross_query)
cross_attention_output = self.cross_ff_block(cross_attention_output) + cross_attention_output
cross_attention_output = self.cross_ff_norm(cross_attention_output)
cross_attention_opt = cross_attention_output + cross_query
cross_attention_opt = self.cross_ff_block(self.cross_ff_norm(cross_attention_opt)) + cross_attention_opt

# sublayer 2: performs self-attention
cross_attention_opt = self.self_attention_norm(cross_attention_opt)
self_attention_output, _ = self.self_attention_block(
cross_attention_output, cross_attention_output, cross_attention_output, mask=None, attn_type='self'
cross_attention_opt, cross_attention_opt, cross_attention_opt, mask=None, attn_type='self'
)
self_attention_output = self.self_attention_norm(self_attention_output + cross_attention_output)
self_attention_output = self.self_ff_block(self_attention_output) + self_attention_output
self_attention_output = self.self_ff_norm(self_attention_output)
self_attention_output = self_attention_output + cross_attention_opt
self_attention_output = self.self_ff_block(self.self_ff_norm(self_attention_output)) + self_attention_output

return alphas, self_attention_output

@@ -134,7 +194,7 @@ def __init__(
hidden_ab_size,
model_type,
model_dim,
ab_layer_norm=None,
# ab_layer_norm=None,
):
"""Attention Heads Layer:"""
super(LinAttentionBridgeLayer, self).__init__()
@@ -150,8 +210,8 @@ def __init__(
self.relu = nn.ReLU()
self.softmax = nn.Softmax(dim=1)
self.attention_hops = r
self.M = None # TODO : remove
self.norm = AttentionBridgeNorm(d, ab_layer_norm)
# self.M = None # TODO : remove
# self.norm = AttentionBridgeNorm(d, ab_layer_norm)

@classmethod
def from_opts(cls, opts):
@@ -162,7 +222,7 @@ def from_opts(cls, opts):
opts.hidden_ab_size,
opts.model_type,
opts.model_dim,
opts.ab_layer_norm,
# opts.ab_layer_norm,
)

def forward(self, intermediate_output, encoder_output, mask=None):
@@ -192,11 +252,11 @@ def forward(self, intermediate_output, encoder_output, mask=None):
alphas = alphas.view(B, self.attention_hops, L) # [bsz, hop, len]
output = torch.bmm(alphas, intermediate_output)

output = self.norm(output)
# output = self.norm(output)
# TODO: why cache? not sure what else is looking at layer.M
self.M = torch.transpose(
output, 0, 1
).contiguous() # [r,bsz,nhid] torch.transpose(output, 0, 1).contiguous() #[r,bsz,nhid]
# self.M = torch.transpose(
# output, 0, 1
# ).contiguous() # [r,bsz,nhid] torch.transpose(output, 0, 1).contiguous() #[r,bsz,nhid]
return alphas, output

@property
@@ -210,15 +270,15 @@ class SimpleAttentionBridgeLayer(BaseAttentionBridgeLayer):
latent key space to produce coherent mixtures of value vectors.
"""

def __init__(self, input_size, hidden_size, fixed_seqlen, ab_layer_norm):
def __init__(self, input_size, hidden_size, fixed_seqlen):
super().__init__()
self.query_matrix = nn.Parameter(torch.zeros(fixed_seqlen, hidden_size))
self.keys_proj = nn.Linear(input_size, hidden_size)
self.values_proj = nn.Linear(input_size, input_size)
self.d_sqrt = hidden_size**0.5
self.R = fixed_seqlen
self.softmax = nn.Softmax(dim=-1)
self.norm = AttentionBridgeNorm(input_size, ab_layer_norm)
# self.norm = AttentionBridgeNorm(input_size, ab_layer_norm)

@property
def is_fixed_length(self):
@@ -240,7 +300,8 @@ def forward(self, intermediate_output, encoder_output, mask=None):
mask_reshaped = mask.view(B, 1, L)
raw_scores = raw_scores.masked_fill(mask_reshaped, -float('inf'))
attention_weights = self.softmax(raw_scores / self.d_sqrt)
output = self.norm(attention_weights @ values)
# output = self.norm(attention_weights @ values)
output = attention_weights @ values
return attention_weights, output

@classmethod
@@ -249,7 +310,7 @@ def from_opts(cls, opts):
opts.model_dim,
opts.hidden_ab_size,
opts.ab_fixed_length,
opts.ab_layer_norm,
# opts.ab_layer_norm,
)


@@ -285,19 +346,22 @@ def from_opts(cls, opts):
opts.dropout[0],
opts.attention_dropout[0],
max_relative_positions=opts.max_relative_positions,
pos_ffn_activation_fn=opts.pos_ffn_activation_fn,
# norm_first=True,
# batch_first=True,
)


class FeedForwardAttentionBridgeLayer(BaseAttentionBridgeLayer):
"""Simple feedforward bridge component"""

def __init__(self, input_size, hidden_size, ab_layer_norm):
def __init__(self, input_size, hidden_size):
super().__init__()
self.module = nn.Sequential(
nn.Linear(input_size, hidden_size),
nn.ReLU(),
nn.Linear(hidden_size, input_size),
AttentionBridgeNorm(input_size, ab_layer_norm),
# AttentionBridgeNorm(input_size, ab_layer_norm),
)

@property
@@ -317,7 +381,7 @@ def from_opts(cls, opts):
return cls(
opts.model_dim,
opts.hidden_ab_size,
opts.ab_layer_norm,
# opts.ab_layer_norm,
)


@@ -326,11 +390,15 @@ class AttentionBridge(nn.Module):
N-layered attention-bridge between encoders->decoders
"""

def __init__(self, layers):
def __init__(self, layers, layer_norms, residual_connection, pos_encoding, final_norm):
"""Attention Heads Layer"""
super(AttentionBridge, self).__init__()
self.layers = nn.ModuleList(layers)
self.layer_norms = nn.ModuleList(layer_norms)
self.residual_connection = residual_connection
self.is_fixed_length = any(x.is_fixed_length for x in layers)
self.pos_encoding = pos_encoding
self.final_norm = final_norm

@classmethod
def from_opts(cls, opts):
@@ -346,6 +414,11 @@ def from_opts(cls, opts):

# preconstruct layers using .from_opts(...)
layers = [layer_type_to_cls[layer_type].from_opts(opts) for layer_type in opts.ab_layers]
layer_norms = [
AttentionBridgeNorm.from_opts(opts) if layer_type not in {'transformer', 'perceiver'}
else nn.Identity()
for layer_type in opts.ab_layers
]

# FIXME: locking-in edge case behavior
if any(layer == 'perceiver' for layer in opts.ab_layers):
@@ -359,21 +432,35 @@ def from_opts(cls, opts):
for perceiver_layer in layers[1:]:
perceiver_layer.latent_array = None

return cls(layers)
residual_connection = OptionalResidualConnection(opts.ab_residual_connection_mode)
pos_encoding = nn.Identity()
if opts.ab_final_pos_enc and any(layer.is_fixed_length for layer in layers):
pos_encoding = PositionalEncoding(opts.dropout[0], opts.model_dim, max_len=opts.ab_fixed_length)
final_norm = AttentionBridgeNorm.from_opts(opts) if opts.ab_final_norm else nn.Identity()

return cls(layers, layer_norms, residual_connection, pos_encoding, final_norm)

def forward(self, enc_output, mask):
def forward(self, enc_output, mask, lengths):
"""Forward pass for the bridge layers"""
out = enc_output.transpose(0, 1)
if self.layers and isinstance(self.layers[0], PerceiverAttentionBridgeLayer):
out = None
alphas = None
orig_mask = mask
for layer in self.layers:
for layer, layer_norm in zip(self.layers, self.layer_norms):
mask_ = orig_mask if isinstance(layer, PerceiverAttentionBridgeLayer) else mask
trace = out = layer_norm(out)
alphas, out = layer(out, enc_output, mask_)
if not isinstance(layer, (PerceiverAttentionBridgeLayer, TransformerAttentionBridgeLayer)):
    # Perceiver and Transformer layers handle residual connections natively
    out = self.residual_connection(trace, out, mask_, lengths)
if layer.is_fixed_length:
# In this case, we've ensured all batch items have a constant
# sequence length, so the mask is no longer required.
mask = None
out = torch.transpose(out, 0, 1).contiguous()
out = self.pos_encoding(out)
out = self.final_norm(out)
return out, alphas # [hop, bsz, nhid], [bsz, hop, srcseqlen]
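
Aside (not part of the commit): the exp_uneven / exp_all branch distributes each input state over the output by soft-matching relative positions. A self-contained sketch of that computation, with a hypothetical helper name, the same 0.1 temperature, and padding masks left out for brevity:

import torch
import torch.nn.functional as F

def soft_position_residual(inp, out, lengths, temperature=0.1):
    # inp: (B, S_i, H) variable-length input states; out: (B, S_o, H) fixed-length output
    B, S_i, H = inp.shape
    S_o = out.size(1)
    # relative position of each input token, in [0, 1]
    in_rel = torch.arange(S_i, dtype=torch.float)[None, :, None] / (lengths - 1).float()[:, None, None]
    # relative position of each output slot, in [0, 1]
    out_rel = torch.arange(S_o, dtype=torch.float)[None, None, :] / (S_o - 1)
    # similarity peaks where relative positions coincide: (B, S_i, S_o)
    raw_relmatch = 1 - (out_rel - in_rel).abs()
    # low-temperature softmax over input positions makes the match peaky
    exp_relmatch = F.softmax(raw_relmatch / temperature, dim=1)
    smeared = torch.einsum('bsh,bst->bth', inp, exp_relmatch)
    return out + smeared

# toy shapes: batch 2, input length 5, fixed output length 3, hidden size 4
inp, out = torch.randn(2, 5, 4), torch.zeros(2, 3, 4)
print(soft_position_residual(inp, out, torch.tensor([5, 4])).shape)  # torch.Size([2, 3, 4])
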
3 changes: 2 additions & 1 deletion mammoth/modules/layer_stack_encoder.py
@@ -22,7 +22,8 @@ def from_opts(cls, opts, embeddings, task_queue_manager):
encoders = nn.ModuleList()
for layer_stack_index, n_layers in enumerate(opts.enc_layers):
stacks = nn.ModuleDict()
is_on_top = layer_stack_index == len(opts.enc_layers) - 1
is_on_top = layer_stack_index == len(opts.enc_layers) - 1
is_on_top = is_on_top and not opts.ab_layers
for module_id in task_queue_manager.get_encoders(layer_stack_index):
if module_id in stacks:
# several tasks using the same layer stack
12 changes: 12 additions & 0 deletions mammoth/opts.py
@@ -640,6 +640,18 @@ def model_opts(parser):
choices=['none', 'rmsnorm', 'layernorm'],
help="""Use layer normalization after lin, simple and feedforward bridge layers""",
)
group.add('--ab_final_norm', '-ab_final_norm', action='store_true', help='normalize AB output')
# group.add('--enc_final_norm', '-enc_final_norm', action='store_true', help='normalize enc output')
group.add(
'--ab_residual_connection_mode',
'-ab_residual_connection_mode',
type=str,
default='none',
choices=['none', 'same_size', 'average_uneven', 'average_all', 'exp_uneven', 'exp_all'],
help='apply residual connections & define algorithm for uneven input/output matrices',
)
group.add('--ab_zero_init', '-ab_zero_init', action='store_true', help='initialize AB params to 0')
group.add('--ab_final_pos_enc', '-ab_final_pos_enc', action='store_true', help='add positional encodings.')

# adapter options are in a dict "adapters", and in the corpus options
group = parser.add_argument_group("Adapters")
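
Aside (not part of the commit): a hypothetical opts excerpt showing how the new flags could combine with existing bridge options; attribute names come from the options above, values are illustrative only.

from argparse import Namespace

opts_excerpt = Namespace(
    ab_fixed_length=50,                          # existing: fixed bridge output length
    ab_layer_norm='rmsnorm',                     # existing: per-layer norm type ('none'/'rmsnorm'/'layernorm')
    ab_residual_connection_mode='exp_uneven',    # new: soft position-matched residuals for uneven shapes
    ab_final_norm=True,                          # new: normalize the bridge output once, at the end
    ab_final_pos_enc=True,                       # new: positional encodings on the fixed-length output
    ab_zero_init=False,                          # new: keep default parameter initialization
)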
