From dde05b852aafdf4ce588d4c68c9b0108abbb57c0 Mon Sep 17 00:00:00 2001
From: Joseph Attieh
Date: Mon, 10 Feb 2025 14:43:26 +0200
Subject: [PATCH] Test

---
 mammoth/distributed/components.py |  1 +
 mammoth/model_builder.py          | 30 +++++++++++++++++++++++---------
 mammoth/modules/x_tf.py           | 33 +++++++++++++++++++++++++++++++--
 mammoth/opts.py                   |  7 -------
 mammoth/utils/parse.py            |  4 ++--
 5 files changed, 55 insertions(+), 20 deletions(-)

diff --git a/mammoth/distributed/components.py b/mammoth/distributed/components.py
index 5be8a0e2..86f1157e 100644
--- a/mammoth/distributed/components.py
+++ b/mammoth/distributed/components.py
@@ -135,6 +135,7 @@ def load_state_dict(self, model: NMTModel, state_dict: Dict[str, Any]):
         return mismatch._replace(missing_keys=missing_keys)
 
 
+
 @dataclass  # type: ignore
 class DistributedAttentionLayersBlock(DistributedComponent, ABC):
     """Represents a distributed AdaptedAttentionLayers object"""
diff --git a/mammoth/model_builder.py b/mammoth/model_builder.py
index d3760687..4b5700e3 100644
--- a/mammoth/model_builder.py
+++ b/mammoth/model_builder.py
@@ -10,6 +10,7 @@
 from typing import Optional, List, Dict, Tuple
 from mammoth.modules.x_tf import TransformerWrapper
 from x_transformers.x_transformers import TokenEmbedding
+from mammoth.constants import DefaultTokens
 
 from mammoth.distributed.components import (
     DistributedAdapter,
@@ -34,16 +35,17 @@
 # embedding
 import torch.nn.functional as F
 class ByteEmbedding(Module):
-    def __init__(self, dim, num_tokens, l2norm_embed = False):
+    def __init__(self, dim, num_tokens, padding_idx):
         super().__init__()
-        self.emb = nn.Embedding(num_tokens, dim)
+        self.emb = nn.Embedding(num_tokens, dim, padding_idx=padding_idx)
         one_hot_matrix = F.one_hot(torch.arange(num_tokens)).float()
         one_hot_embed = torch.cat((one_hot_matrix, torch.zeros((num_tokens, dim - num_tokens))), dim=1)
+        one_hot_embed[padding_idx] = torch.zeros(dim).unsqueeze(0)
         self.emb.weight = torch.nn.parameter.Parameter(one_hot_embed, requires_grad=False)
     def forward(self, x):
         token_emb = self.emb(x.long())
         return token_emb
-
+
 TRANSFORMER_WRAPPER_OPTS = {
     'post_emb_norm',
     'tie_embedding',
@@ -273,12 +275,20 @@ def build_xcoder(
     for lang in all_langs:
         if lang not in token_embs:
             vocab = vocabs_dict[(side_alt_str, lang)]
-            Embedding = ByteEmbedding if model_opts.use_embeddingless else TokenEmbedding
-            token_embs[lang] = Embedding(
-                dim=model_opts.model_dim,
-                num_tokens=len(vocab),
-                l2norm_embed=l2norm_embed
-            )
+            padding_idx = vocab[DefaultTokens.PAD]
+            if model_opts.use_embeddingless:
+                token_embs[lang] = ByteEmbedding(
+                    dim=model_opts.model_dim,
+                    num_tokens=len(vocab),
+                    padding_idx=padding_idx
+                )
+            else:
+                token_embs[lang] = TokenEmbedding(
+                    dim=model_opts.model_dim,
+                    num_tokens=len(vocab),
+                    l2norm_embed=l2norm_embed
+                )
+
     # Create AdaptedAttentionLayersStack objects and TransformerWrapper objects
     tasks = task_queue_manager.get_my_tasks()
     if single_task:
@@ -310,6 +320,8 @@
             emb_dim=model_opts.model_dim,
             token_emb=token_embs[lang],
             initialize_embeddings=not (model_opts.use_embeddingless),
+            scale_outputs=model_opts.use_embeddingless,
+            scale_embeddings=model_opts.use_embeddingless,
             **transformer_wrapper_kwargs,
         )
         transformer_wrappers[task.corpus_id] = transformer_wrapper
diff --git a/mammoth/modules/x_tf.py b/mammoth/modules/x_tf.py
index 126c6b64..253b4210 100644
--- a/mammoth/modules/x_tf.py
+++ b/mammoth/modules/x_tf.py
@@ -227,6 +227,14 @@ def forward(self, x):
 
 # embedding
 
+class Scaler(Module):
+    def __init__(self, emb_dim):
+        super().__init__()
+        self.embed_scale = nn.Parameter(torch.Tensor([0]))
+        nn.init.constant_(self.embed_scale, math.sqrt(emb_dim))
+    def forward(self, x):
+        return x * self.embed_scale
+
 class TokenEmbedding(Module):
     def __init__(self, dim, num_tokens, l2norm_embed = False):
         super().__init__()
@@ -2008,7 +2016,9 @@ def __init__(
         mixture_of_softmax = False,
         mixture_of_softmax_k = 4,
         sigsoftmax_logits = False,
-        initialize_embeddings=True
+        initialize_embeddings=True,
+        scale_embeddings: bool = False,  # NEW: whether to multiply embeddings by sqrt(emb_dim)
+        scale_outputs: bool = False,  # NEW: whether to scale final outputs
     ):
         super().__init__()
         self.initialize_embeddings =initialize_embeddings
@@ -2017,6 +2027,9 @@
         self.emb_dim = emb_dim
         self.num_tokens = num_tokens
 
+        self.scale_embeddings = scale_embeddings
+        self.scale_outputs = scale_outputs
+
         self.max_seq_len = max_seq_len
         self.max_mem_len = max_mem_len
         self.shift_mem_down = shift_mem_down
@@ -2027,7 +2040,10 @@
             token_emb = TokenEmbedding(emb_dim, num_tokens, l2norm_embed = l2norm_embed)
         self.token_emb = token_emb
 
-
+
+        self.embed_scale = Scaler(emb_dim)
+        self.out_embed_scale = Scaler(emb_dim)
+
         no_abs_pos_emb = max_seq_len == 0 or not (use_abs_pos_emb and not attn_layers.disable_abs_pos_emb)
 
         if no_abs_pos_emb:
@@ -2177,6 +2193,9 @@
 
         pos_emb = self.pos_emb(x, pos = pos, seq_start_pos = seq_start_pos) if not external_pos_emb else pos
         x = self.token_emb(x) + pos_emb
+        if self.scale_embeddings:
+            x = self.embed_scale(x)
+
         # add additional embeddings
 
         assert not (exists(self.embeds) ^ (len(embed_ids) > 0)), '`embed_num_tokens` must be defined on `TransformerWrapper`'
@@ -2356,6 +2375,16 @@
 
         # different returns
 
+        if self.scale_outputs:
+            if return_logits_and_embeddings:
+                x = self.out_embed_scale(x)
+                logits = self.out_embed_scale(logits)
+            elif return_embeddings:
+                x = self.out_embed_scale(x)
+            else:
+                logits = self.out_embed_scale(logits)
+
+
         if return_logits_and_embeddings:
             out = (logits, x)
         elif return_embeddings:
diff --git a/mammoth/opts.py b/mammoth/opts.py
index 520c0602..45174332 100644
--- a/mammoth/opts.py
+++ b/mammoth/opts.py
@@ -204,13 +204,6 @@ def model_opts(parser):
 
     # Embedding Options
     group = parser.add_argument_group('Model- Embeddings')
-    group.add(
-        '--enable_embeddingless',
-        '-enable_embeddingless',
-        action='store_true',
-        help="Enable the use of byte-based embeddingless models" +
-        "(Shaham et. al, 2021) https://aclanthology.org/2021.naacl-main.17/",
-    )
 
     # Encoder-Decoder Options
     group = parser.add_argument_group('Model- Encoder-Decoder')
diff --git a/mammoth/utils/parse.py b/mammoth/utils/parse.py
index 38f4690b..7a9da8f8 100644
--- a/mammoth/utils/parse.py
+++ b/mammoth/utils/parse.py
@@ -102,8 +102,8 @@ def _validate_tasks(cls, opts):
                     " We default it to 0 (start of training) for you."
                 )
                 corpus['introduce_at_training_step'] = 0
-            enable_embeddingless = corpus.get('enable_embeddingless', False)
-            opts.enable_embeddingless = enable_embeddingless
+            use_embeddingless = corpus.get('use_embeddingless', False)
+            opts.use_embeddingless = use_embeddingless
             # Check sharing groups
             enc_sharing_group = corpus.get('enc_sharing_group', None)
             assert enc_sharing_group is None or isinstance(enc_sharing_group, list)
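
Reviewer note (not part of the commit): below is a minimal, self-contained sketch of the two pieces
this diff wires together -- a frozen one-hot "byte" embedding whose padding row is zeroed, and a
learned scalar initialised to sqrt(emb_dim) that rescales embeddings/outputs when `scale_embeddings`
and `scale_outputs` are enabled. Class names mirror the patch, but the vocabulary size, model
dimension, and the demo at the bottom are illustrative assumptions, not MAMMOTH code.

    import math

    import torch
    import torch.nn as nn
    import torch.nn.functional as F


    class ByteEmbedding(nn.Module):
        """Frozen one-hot embedding: token i maps to the i-th standard basis vector."""

        def __init__(self, dim, num_tokens, padding_idx):
            super().__init__()
            self.emb = nn.Embedding(num_tokens, dim, padding_idx=padding_idx)
            one_hot = F.one_hot(torch.arange(num_tokens)).float()
            # Pad the one-hot block with zeros up to the model dimension.
            table = torch.cat((one_hot, torch.zeros(num_tokens, dim - num_tokens)), dim=1)
            table[padding_idx] = 0.0  # the padding token embeds to the zero vector
            self.emb.weight = nn.Parameter(table, requires_grad=False)

        def forward(self, x):
            return self.emb(x.long())


    class Scaler(nn.Module):
        """Multiplies its input by a learned scalar initialised to sqrt(emb_dim)."""

        def __init__(self, emb_dim):
            super().__init__()
            self.embed_scale = nn.Parameter(torch.full((1,), math.sqrt(emb_dim)))

        def forward(self, x):
            return x * self.embed_scale


    if __name__ == "__main__":
        dim, num_bytes, pad = 512, 256, 0           # illustrative sizes, not taken from any config
        token_emb = ByteEmbedding(dim=dim, num_tokens=num_bytes, padding_idx=pad)
        scale = Scaler(dim)

        tokens = torch.tensor([[pad, 65, 66, 67]])  # PAD, then 'A', 'B', 'C' as byte ids
        x = scale(token_emb(tokens))                # what scale_embeddings=True applies in forward()
        print(x.shape)                              # torch.Size([1, 4, 512])
        print(x[0, 0].abs().sum().item())           # 0.0 -- the padding row stays zero after scaling
        print(x[0, 1].norm().item())                # ~sqrt(512): unit one-hot brought up to model scale

Since the global --enable_embeddingless flag is dropped from opts.py, the feature is now driven by the
per-task `use_embeddingless` value read in parse.py, which model_builder.py uses to pick ByteEmbedding
over TokenEmbedding and to enable both Scaler modules in TransformerWrapper.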