Miscellaneous fixes to the x-transformers implementation #79

Merged
merged 15 commits on Nov 4, 2024
65 changes: 59 additions & 6 deletions mammoth/distributed/components.py
@@ -87,7 +87,50 @@ def needs_communication(self) -> bool:
return self.group is not None


# TODO: This is a misnomer: Not an entire XCoder, but just one AttentionLayers block
@dataclass # type: ignore
class DistributedTransformerWrapper(DistributedComponent, ABC):
task_id: str
side: Side

def get_name(self) -> str:
return f'{self.side.name}_{self.task_id}'

def get_module(self, model: NMTModel) -> nn.Module:
parent = model.encoder if self.side == Side.encoder else model.decoder
tw = parent[self.task_id]
return tw

def named_parameters(self, model: NMTModel):
module = self.get_module(model)
for name, p in module.named_parameters():
# TransformerWrapper contains the AttentionLayers and the embeddings.
# However, we want to treat these as distinct DistributedComponents
if name.startswith('attn_layers.'):
continue
if name.startswith('token_emb.'):
continue
yield name, p

def state_dict(self, model: NMTModel, prefix='', keep_vars=False) -> Dict[str, Any]:
module = self.get_module(model)
destination: Dict[str, Any] = OrderedDict()
for name, sub_module in module._modules.items():
if name.endswith('attn_layers'):
# stored separately
continue
sub_module.state_dict(destination=destination, prefix=prefix + name + '.', keep_vars=keep_vars)
return destination

def load_state_dict(self, model: NMTModel, state_dict: Dict[str, Any]):
module = self.get_module(model)
mismatch = module.load_state_dict(state_dict, strict=False)
missing_keys = [
name for name in mismatch.missing_keys
if not (name.startswith('attn_layers.') or name.startswith('token_emb.'))
]
return mismatch._replace(missing_keys=missing_keys)
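
A standalone sanity check (toy module, not part of this diff): the `_replace` above works because `nn.Module.load_state_dict` returns a namedtuple of missing and unexpected keys.

```python
import torch.nn as nn

# Loading an empty state dict non-strictly reports every parameter as missing.
module = nn.Linear(4, 4)
mismatch = module.load_state_dict({}, strict=False)
print(mismatch.missing_keys)   # ['weight', 'bias']

# The result behaves like a namedtuple, so _replace returns a copy with filtered keys.
filtered = mismatch._replace(
    missing_keys=[name for name in mismatch.missing_keys if name != 'bias']
)
print(filtered.missing_keys)   # ['weight']
```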


@dataclass # type: ignore
class DistributedAttentionLayersBlock(DistributedComponent, ABC):
layer_stack_index: int
@@ -106,22 +149,32 @@ def named_parameters(self, model: NMTModel):
for name, p in module.named_parameters():
# Encoders and decoders contain embeddings and adapters as submodules.
# However, we want to treat these as distinct DistributedComponents
if 'embeddings' not in name and 'adapter' not in name:
yield name, p
if 'adapter' in name:
continue
yield name, p

def state_dict(self, model: NMTModel, prefix='', keep_vars=False) -> Dict[str, Any]:
module = self.get_module(model)
destination: Dict[str, Any] = OrderedDict()
for name, sub_module in module._modules.items():
for name, sub_module in module.get_sub_modules().items():
if name == 'adapters':
# Adapters are stored separately
continue
sub_module.state_dict(destination=destination, prefix=prefix + name + '.', keep_vars=keep_vars)
return destination

def load_state_dict(self, model: NMTModel, state_dict: Dict[str, Any]):
module = self.get_module(model)
mismatch = module.load_state_dict(state_dict, strict=False)
missing_keys = [
name for name in mismatch.missing_keys
if not name.startswith('layers.')
]
return mismatch._replace(missing_keys=missing_keys)


@dataclass
class DistributedEncoder(DistributedAttentionLayersBlock):
class DistributedEncoderAttentionLayersBlock(DistributedAttentionLayersBlock):
@property
def side(self) -> Side:
return Side.encoder
@@ -136,7 +189,7 @@ def get_module(self, model: NMTModel) -> nn.Module:


@dataclass
class DistributedDecoder(DistributedAttentionLayersBlock):
class DistributedDecoderAttentionLayersBlock(DistributedAttentionLayersBlock):
@property
def side(self) -> Side:
return Side.decoder
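
To illustrate what the prefix filtering in `DistributedTransformerWrapper` buys, here is a minimal sketch with a stand-in `nn.Module` (not the real x-transformers `TransformerWrapper`): parameters under `attn_layers.` and `token_emb.` are treated as separate `DistributedComponent`s, so the wrapper component only reports the rest.

```python
import torch.nn as nn

# Stand-in wrapper using the same attribute names the filters above assume.
wrapper = nn.Module()
wrapper.token_emb = nn.Embedding(100, 16)   # handled as its own DistributedComponent
wrapper.attn_layers = nn.Linear(16, 16)     # handled by the AttentionLayers block components
wrapper.to_logits = nn.Linear(16, 100)      # stays with the wrapper component

SEPARATE_PREFIXES = ('attn_layers.', 'token_emb.')
own_params = [
    name for name, p in wrapper.named_parameters()
    if not name.startswith(SEPARATE_PREFIXES)
]
print(own_params)   # ['to_logits.weight', 'to_logits.bias']
```
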
27 changes: 23 additions & 4 deletions mammoth/distributed/tasks.py
@@ -17,9 +17,10 @@
DistributedComponent,
DistributedComponentBuilder,
DistributedComponentGradientSync,
DistributedDecoder,
DistributedDecoderAttentionLayersBlock,
DistributedEmbedding,
DistributedEncoder,
DistributedEncoderAttentionLayersBlock,
DistributedTransformerWrapper,
Side,
)
from mammoth.distributed.contexts import DeviceContext, WorldContext
@@ -369,9 +370,27 @@ def create_all_distributed_components(
lang=task.tgt_lang,
)
)
builder.add(
DistributedTransformerWrapper(
global_ranks={global_rank},
task_ids={task.corpus_id},
group=None,
side=Side.encoder,
task_id=task.corpus_id,
)
)
builder.add(
DistributedTransformerWrapper(
global_ranks={global_rank},
task_ids={task.corpus_id},
group=None,
side=Side.decoder,
task_id=task.corpus_id,
)
)
for layer_stack_index, encoder_id in enumerate(task.encoder_id):
builder.add(
DistributedEncoder(
DistributedEncoderAttentionLayersBlock(
global_ranks={global_rank},
task_ids={task.corpus_id},
group=None,
@@ -381,7 +400,7 @@
)
for layer_stack_index, decoder_id in enumerate(task.decoder_id):
builder.add(
DistributedDecoder(
DistributedDecoderAttentionLayersBlock(
global_ranks={global_rank},
task_ids={task.corpus_id},
group=None,
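
A rough illustration of what the two new `builder.add` calls register per task (made-up rank and corpus id; `group=None` means the component is task-local, with no gradient communication group):

```python
from mammoth.distributed.components import DistributedTransformerWrapper, Side

component = DistributedTransformerWrapper(
    global_ranks={0},
    task_ids={'train_en-de'},
    group=None,
    side=Side.encoder,
    task_id='train_en-de',
)
print(component.get_name())      # 'encoder_train_en-de'
print(component.group is None)   # True: nothing to all-reduce for this component
```
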
68 changes: 53 additions & 15 deletions mammoth/model_builder.py
@@ -14,8 +14,8 @@
from mammoth.distributed.components import (
DistributedAdapter,
DistributedComponent,
DistributedDecoder,
DistributedEncoder,
DistributedDecoderAttentionLayersBlock,
DistributedEncoderAttentionLayersBlock,
Side,
)
from mammoth.modules.adapters import (
@@ -31,6 +31,14 @@
from mammoth.utils.logging import logger
from mammoth.utils.misc import use_gpu

TRANSFORMER_WRAPPER_OPTS = {
'post_emb_norm',
'tie_embedding',
'use_abs_pos_emb',
'scaled_sinu_pos_emb',
'emb_frac_gradient',
}


def _combine_ordered_dicts(input_dicts: Dict[str, OrderedDict]) -> OrderedDict:
result = []
@@ -59,17 +67,34 @@ def get_attention_layers_kwargs(
is_last = layer_stack_index == len(depths) - 1
pre_norm_has_final_norm = is_last
kwargs = model_opts.x_transformers_opts if model_opts.x_transformers_opts else dict()
kwargs = {key: val for key, val in kwargs.items() if key not in TRANSFORMER_WRAPPER_OPTS}
kwargs.update({
'dim': model_opts.model_dim,
'depth': depth,
'heads': model_opts.heads,
'causal': causal,
'cross_attend': cross_attend,
'pre_norm_has_final_norm': pre_norm_has_final_norm,
})
return kwargs


def get_transformer_wrapper_kwargs(
side: Side,
model_opts,
):
"""Return arguments for x_transformers.TransformerWrapper"""
assert side in {Side.encoder, Side.decoder}, f'Invalid side "{side}"'
kwargs = model_opts.x_transformers_opts if model_opts.x_transformers_opts else dict()
kwargs = {key: val for key, val in kwargs.items() if key in TRANSFORMER_WRAPPER_OPTS}
max_seq_len = 0 if model_opts.max_length is None else model_opts.max_length
kwargs.update({
'max_seq_len': max_seq_len,
})
if side == Side.encoder:
kwargs['return_only_embed'] = True
return kwargs
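
A toy example of the resulting split, using the `TRANSFORMER_WRAPPER_OPTS` set defined above (the option names below are merely plausible x-transformers settings, not taken from this PR): keys in the set go to `TransformerWrapper`, everything else to the `AttentionLayers` blocks.

```python
x_transformers_opts = {
    'post_emb_norm': True,   # wrapper-level
    'tie_embedding': True,   # wrapper-level
    'ff_glu': True,          # layer-level
    'attn_flash': True,      # layer-level
}

wrapper_opts = {k: v for k, v in x_transformers_opts.items() if k in TRANSFORMER_WRAPPER_OPTS}
layer_opts = {k: v for k, v in x_transformers_opts.items() if k not in TRANSFORMER_WRAPPER_OPTS}
print(wrapper_opts)   # {'post_emb_norm': True, 'tie_embedding': True}
print(layer_opts)     # {'ff_glu': True, 'attn_flash': True}
```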


def build_xcoder(
side: Side,
model_opts,
@@ -96,10 +121,10 @@
]
distributed_xcoder_class: type
if side == Side.encoder:
distributed_xcoder_class = DistributedEncoder
distributed_xcoder_class = DistributedEncoderAttentionLayersBlock
side_str = 'encoder'
else:
distributed_xcoder_class = DistributedDecoder
distributed_xcoder_class = DistributedDecoderAttentionLayersBlock
side_str = 'decoder'
if single_task:
my_components = [
@@ -197,6 +222,10 @@ def build_xcoder(
if single_task:
tasks = [task for task in tasks if task.corpus_id == single_task]
transformer_wrappers = dict()
transformer_wrapper_kwargs = get_transformer_wrapper_kwargs(
side=side,
model_opts=model_opts,
)
for task in tasks:
if side == Side.encoder:
xcoder_ids = task.encoder_id
@@ -212,22 +241,13 @@

lang = task.src_lang if side == Side.encoder else task.tgt_lang
vocab = vocabs_dict[(side_alt_str, lang)]
max_seq_len = 0 if model_opts.max_length is None else model_opts.max_length
post_emb_norm = True
tie_embedding = True
use_abs_pos_emb = True
emb_frac_gradient = 1.
# Using custom extended TransformerWrapper to allow passing in an embedding
transformer_wrapper = TransformerWrapper(
num_tokens=len(vocab),
max_seq_len=max_seq_len,
attn_layers=adapted_attention_layers_stack,
emb_dim=model_opts.model_dim,
post_emb_norm=post_emb_norm,
tie_embedding=tie_embedding,
use_abs_pos_emb=use_abs_pos_emb,
emb_frac_gradient=emb_frac_gradient,
token_emb=token_embs[lang],
**transformer_wrapper_kwargs,
)
transformer_wrappers[task.corpus_id] = transformer_wrapper

@@ -310,3 +330,21 @@ def build_model(
# logger.info(model)
logger.info('Building model - done!')
return model


def validate_optimizer_coverage(model, optimizer):
trainable_model_params = {
name: p for name, p in model.named_parameters()
if p.requires_grad
}
optimized_params = set()
for group in optimizer.param_groups:
optimized_params.update(group['params'])
missing_params = [
name for name, p in trainable_model_params.items()
if p not in optimized_params
]
if len(missing_params) > 0:
raise Exception(f'Missing optimizer for params: {sorted(missing_params)}')
else:
logger.info('All non-frozen parameters have an optimizer')
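
A usage sketch for the new coverage check (toy model and optimizer; the import path is assumed):

```python
import torch
import torch.nn as nn
from mammoth.model_builder import validate_optimizer_coverage  # assumed import path

model = nn.Linear(4, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
validate_optimizer_coverage(model, optimizer)   # passes: every trainable param is covered

# A trainable parameter that no param_group knows about makes the check raise.
model.extra_bias = nn.Parameter(torch.zeros(4))
# validate_optimizer_coverage(model, optimizer)
# -> Exception: Missing optimizer for params: ['extra_bias']
```
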
8 changes: 4 additions & 4 deletions mammoth/models/model.py
@@ -75,10 +75,10 @@ def forward(self, src, decoder_input, src_mask, metadata=None):
return_embeddings=True,
)

encoder_output, alphas = self.attention_bridge(encoder_output, src_mask)
if self.attention_bridge.is_fixed_length:
# turn off masking in the transformer decoder
src_mask = None
# encoder_output, alphas = self.attention_bridge(encoder_output, src_mask)
# if self.attention_bridge.is_fixed_length:
# # turn off masking in the transformer decoder
# src_mask = None

retval = active_decoder(
decoder_input,
9 changes: 8 additions & 1 deletion mammoth/modules/adapters.py
@@ -192,13 +192,13 @@ def _inject_adapters(self):
adapted_layer_types = []
adapted_layers = nn.ModuleList()
adapted_layer_dropouts = []
adapter_layers_by_index = self._merge_active_adapters()
i = 0
for layer_type, layer_struct, layer_dropout in zip(
self._base_layer_types,
self._base_layers,
self._base_layer_dropouts,
):
adapter_layers_by_index = self._merge_active_adapters()
if layer_type == 'f':
# Adapters apply to feedforward layers
adapter_layers = adapter_layers_by_index[i]
@@ -225,3 +225,10 @@
def forward(self, *args, **kwargs):
self._inject_adapters()
return super().forward(*args, **kwargs)

def get_sub_modules(self):
omit_submodules = {'layers'}
return {
name: sub_module for name, sub_module in self._modules.items()
if name not in omit_submodules
}
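
For context, a toy stand-in (not the real `AdaptedAttentionLayers`): `get_sub_modules()` exposes the direct children minus the ones checkpointed separately, which is what the updated `DistributedAttentionLayersBlock.state_dict()` now iterates over instead of raw `_modules`.

```python
import torch.nn as nn

class ToyAttentionLayers(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(8, 8)])   # stored separately, so omitted
        self.final_norm = nn.LayerNorm(8)

    def get_sub_modules(self):
        omit_submodules = {'layers'}
        return {
            name: sub_module for name, sub_module in self._modules.items()
            if name not in omit_submodules
        }

print(list(ToyAttentionLayers().get_sub_modules()))   # ['final_norm']
```
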
2 changes: 1 addition & 1 deletion mammoth/modules/layer_stack.py
@@ -9,7 +9,7 @@
class AdaptedAttentionLayersStack(nn.Module):
"""
Wrapper that allows stacking multiple AdaptedAttentionLayers.
Represents one particular stacking: does not allow switching out entire layers
Represents one particular task-specific stacking: does not allow switching out entire layers
(but does delegate the switching out of adapters to its components)
"""
def __init__(self, attention_layers_stack: Sequence[AdaptedAttentionLayers]):