Commit dde05b8 ("Test")
Joseph Attieh committed Feb 10, 2025
1 parent 284dda8

Showing 5 changed files with 55 additions and 20 deletions.
1 change: 1 addition & 0 deletions mammoth/distributed/components.py
@@ -135,6 +135,7 @@ def load_state_dict(self, model: NMTModel, state_dict: Dict[str, Any]):
         return mismatch._replace(missing_keys=missing_keys)


+
 @dataclass # type: ignore
 class DistributedAttentionLayersBlock(DistributedComponent, ABC):
     """Represents a distributed AdaptedAttentionLayers object"""
30 changes: 21 additions & 9 deletions mammoth/model_builder.py
@@ -10,6 +10,7 @@
 from typing import Optional, List, Dict, Tuple
 from mammoth.modules.x_tf import TransformerWrapper
 from x_transformers.x_transformers import TokenEmbedding
+from mammoth.constants import DefaultTokens

 from mammoth.distributed.components import (
     DistributedAdapter,
@@ -34,16 +35,17 @@
 # embedding
 import torch.nn.functional as F
 class ByteEmbedding(Module):
-    def __init__(self, dim, num_tokens, l2norm_embed = False):
+    def __init__(self, dim, num_tokens, padding_idx):
         super().__init__()
-        self.emb = nn.Embedding(num_tokens, dim)
+        self.emb = nn.Embedding(num_tokens, dim, padding_idx=padding_idx)
         one_hot_matrix = F.one_hot(torch.arange(num_tokens)).float()
         one_hot_embed = torch.cat((one_hot_matrix, torch.zeros((num_tokens, dim - num_tokens))), dim=1)
+        one_hot_embed[padding_idx] = torch.zeros(dim).unsqueeze(0)
         self.emb.weight = torch.nn.parameter.Parameter(one_hot_embed, requires_grad=False)
     def forward(self, x):
         token_emb = self.emb(x.long())
         return token_emb

 TRANSFORMER_WRAPPER_OPTS = {
     'post_emb_norm',
     'tie_embedding',
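A minimal, self-contained sketch of what the updated ByteEmbedding construction produces: a frozen one-hot table whose padding row is all zeros. The sizes below are toy values; the real code uses the model dimension and vocabulary size, and the construction assumes dim >= num_tokens, which holds for byte vocabularies.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

dim, num_tokens, padding_idx = 8, 5, 0    # toy sizes; real byte models use dim >> 256

# Same construction as ByteEmbedding.__init__ above
one_hot = F.one_hot(torch.arange(num_tokens)).float()                             # (num_tokens, num_tokens)
table = torch.cat((one_hot, torch.zeros((num_tokens, dim - num_tokens))), dim=1)  # pad to (num_tokens, dim)
table[padding_idx] = torch.zeros(dim)                                              # zero out the padding row

emb = nn.Embedding(num_tokens, dim, padding_idx=padding_idx)
emb.weight = nn.Parameter(table, requires_grad=False)

out = emb(torch.tensor([[1, 3, padding_idx]]))   # (1, 3, dim)
assert torch.all(out[0, 2] == 0)                 # padding embeds to the zero vector
assert not emb.weight.requires_grad              # the table is fixed, never trained
```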
@@ -273,12 +275,20 @@ def build_xcoder(
     for lang in all_langs:
         if lang not in token_embs:
             vocab = vocabs_dict[(side_alt_str, lang)]
-            Embedding = ByteEmbedding if model_opts.use_embeddingless else TokenEmbedding
-            token_embs[lang] = Embedding(
-                dim=model_opts.model_dim,
-                num_tokens=len(vocab),
-                l2norm_embed=l2norm_embed
-            )
+            padding_idx = vocab[DefaultTokens.PAD]
+            if model_opts.use_embeddingless:
+                token_embs[lang] = ByteEmbedding(
+                    dim=model_opts.model_dim,
+                    num_tokens=len(vocab),
+                    padding_idx=padding_idx
+                )
+            else:
+                token_embs[lang] = TokenEmbedding(
+                    dim=model_opts.model_dim,
+                    num_tokens=len(vocab),
+                    l2norm_embed=l2norm_embed
+                )

     # Create AdaptedAttentionLayersStack objects and TransformerWrapper objects
     tasks = task_queue_manager.get_my_tasks()
     if single_task:
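One practical note on the branch above (not stated in the commit): ByteEmbedding pads a num_tokens-wide one-hot block up to dim, so it only works when the model dimension is at least the vocabulary size. That holds for byte vocabularies (256 bytes plus a few specials) but not for ordinary subword vocabularies, which keep using TokenEmbedding. A hedged sketch with made-up sizes:

```python
# The sizes and the flag are hypothetical stand-ins for model_opts and the vocab object.
from mammoth.model_builder import ByteEmbedding           # defined earlier in this file
from x_transformers.x_transformers import TokenEmbedding

model_dim, vocab_size, padding_idx = 512, 260, 0          # ~256 bytes plus special tokens
use_embeddingless = True

if use_embeddingless:
    assert model_dim >= vocab_size                        # required by the one-hot construction
    emb = ByteEmbedding(dim=model_dim, num_tokens=vocab_size, padding_idx=padding_idx)
else:
    emb = TokenEmbedding(dim=model_dim, num_tokens=vocab_size, l2norm_embed=False)
```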
Expand Down Expand Up @@ -310,6 +320,8 @@ def build_xcoder(
emb_dim=model_opts.model_dim,
token_emb=token_embs[lang],
initialize_embeddings=not (model_opts.use_embeddingless),
scale_outputs= model_opts.use_embeddingless,
scale_embeddings = model_opts.use_embeddingless,
**transformer_wrapper_kwargs,
)
transformer_wrappers[task.corpus_id] = transformer_wrapper
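Both new flags are simply tied to use_embeddingless here. A plausible reading (the commit gives no rationale): frozen one-hot byte embeddings have unit norm and cannot learn their own magnitude, so a learned scale initialized to sqrt(emb_dim), as in the Scaler module added in x_tf.py below, brings them up to the magnitude the rest of the network expects. A small numeric illustration:

```python
import math
import torch

emb_dim = 512
one_hot_row = torch.zeros(emb_dim)
one_hot_row[65] = 1.0                       # a frozen byte embedding row has L2 norm 1
scaled = one_hot_row * math.sqrt(emb_dim)   # what the sqrt(emb_dim) scale yields at initialization

print(one_hot_row.norm().item())            # 1.0
print(scaled.norm().item())                 # ~22.63, i.e. sqrt(512)
```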
33 changes: 31 additions & 2 deletions mammoth/modules/x_tf.py
@@ -227,6 +227,14 @@ def forward(self, x):

 # embedding

+class Scaler(Module):
+    def __init__(self, emb_dim):
+        super().__init__()
+        self.embed_scale = nn.Parameter(torch.Tensor([0]))
+        nn.init.constant_(self.embed_scale, math.sqrt(emb_dim))
+    def forward(self, x):
+        return x * self.embed_scale
+
 class TokenEmbedding(Module):
     def __init__(self, dim, num_tokens, l2norm_embed = False):
         super().__init__()
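A quick sanity check of the new Scaler module (assuming it can be imported from mammoth.modules.x_tf, where this diff defines it): it is a single learnable scalar initialized to sqrt(emb_dim).

```python
import math
import torch
from mammoth.modules.x_tf import Scaler   # added by this commit

scaler = Scaler(emb_dim=512)
x = torch.randn(2, 7, 512)                # (batch, seq, emb_dim)
y = scaler(x)

assert y.shape == x.shape
assert torch.allclose(y, x * math.sqrt(512))   # holds at initialization, before training
print(scaler.embed_scale.item())               # ~22.63, and learnable, unlike a fixed constant
```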
@@ -2008,7 +2016,9 @@ def __init__(
         mixture_of_softmax = False,
         mixture_of_softmax_k = 4,
         sigsoftmax_logits = False,
-        initialize_embeddings=True
+        initialize_embeddings=True,
+        scale_embeddings: bool = False,  # NEW: whether to multiply embeddings by sqrt(emb_dim)
+        scale_outputs: bool = False,     # NEW: whether to scale final outputs
     ):
         super().__init__()
         self.initialize_embeddings = initialize_embeddings
@@ -2017,6 +2027,9 @@ def __init__(
         self.emb_dim = emb_dim
         self.num_tokens = num_tokens

+        self.scale_embeddings = scale_embeddings
+        self.scale_outputs = scale_outputs
+
         self.max_seq_len = max_seq_len
         self.max_mem_len = max_mem_len
         self.shift_mem_down = shift_mem_down
@@ -2027,7 +2040,10 @@ def __init__(
             token_emb = TokenEmbedding(emb_dim, num_tokens, l2norm_embed = l2norm_embed)

         self.token_emb = token_emb

+        self.embed_scale = Scaler(emb_dim)
+        self.out_embed_scale = Scaler(emb_dim)
+
         no_abs_pos_emb = max_seq_len == 0 or not (use_abs_pos_emb and not attn_layers.disable_abs_pos_emb)

         if no_abs_pos_emb:
@@ -2177,6 +2193,9 @@ def forward(
         pos_emb = self.pos_emb(x, pos = pos, seq_start_pos = seq_start_pos) if not external_pos_emb else pos
         x = self.token_emb(x) + pos_emb

+        if self.scale_embeddings:
+            x = self.embed_scale(x)
+
         # add additional embeddings

         assert not (exists(self.embeds) ^ (len(embed_ids) > 0)), '`embed_num_tokens` must be defined on `TransformerWrapper`'
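Note that the scale is applied after the positional embedding has been added, so token and positional components are scaled together; the classic Transformer formulation instead multiplies only the token embedding by sqrt(d) before adding positions. A condensed, hypothetical restatement of the path above:

```python
# token_emb is an embedding module, pos_emb a precomputed positional tensor,
# and embed_scale a Scaler instance; all three stand in for the attributes used in forward().
def embed_tokens(tokens, token_emb, pos_emb, embed_scale, scale_embeddings: bool):
    x = token_emb(tokens) + pos_emb    # embeddings plus positions, as in the original path
    if scale_embeddings:
        x = embed_scale(x)             # learned sqrt(emb_dim) scale applied to the sum
    return x
```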
@@ -2356,6 +2375,16 @@ def forward(

         # different returns

+        if self.scale_outputs:
+            if return_logits_and_embeddings:
+                x = self.out_embed_scale(x)
+                logits = self.out_embed_scale(logits)
+            elif return_embeddings:
+                x = self.out_embed_scale(x)
+            else:
+                logits = self.out_embed_scale(logits)
+
+
         if return_logits_and_embeddings:
             out = (logits, x)
         elif return_embeddings:
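A side effect worth keeping in mind (not discussed in the commit): when the scaled tensor is the logits, multiplying by a learned positive scalar is equivalent to a softmax temperature of 1/scale, so the initial sqrt(emb_dim) value makes the output distribution much sharper until training adjusts it. A small demonstration:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 10)
scale = 22.6                                   # roughly sqrt(512), the Scaler's initial value

p_plain = F.softmax(logits, dim=-1)
p_scaled = F.softmax(logits * scale, dim=-1)   # same as softmax at temperature 1/22.6

print(p_plain.max(dim=-1).values)    # moderate peak probabilities
print(p_scaled.max(dim=-1).values)   # close to 1.0: scaling sharpens the distribution
```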
7 changes: 0 additions & 7 deletions mammoth/opts.py
@@ -204,13 +204,6 @@ def model_opts(parser):
     # Embedding Options
     group = parser.add_argument_group('Model- Embeddings')

-    group.add(
-        '--enable_embeddingless',
-        '-enable_embeddingless',
-        action='store_true',
-        help="Enable the use of byte-based embeddingless models" +
-        "(Shaham et. al, 2021) https://aclanthology.org/2021.naacl-main.17/",
-    )

     # Encoder-Decoder Options
     group = parser.add_argument_group('Model- Encoder-Decoder')
4 changes: 2 additions & 2 deletions mammoth/utils/parse.py
@@ -102,8 +102,8 @@ def _validate_tasks(cls, opts):
                     " We default it to 0 (start of training) for you."
                 )
                 corpus['introduce_at_training_step'] = 0
-            enable_embeddingless = corpus.get('enable_embeddingless', False)
-            opts.enable_embeddingless = enable_embeddingless
+            use_embeddingless = corpus.get('use_embeddingless', False)
+            opts.use_embeddingless = use_embeddingless
             # Check sharing groups
             enc_sharing_group = corpus.get('enc_sharing_group', None)
             assert enc_sharing_group is None or isinstance(enc_sharing_group, list)
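With the CLI flag removed from opts.py, the switch is now read per task from the configuration. A hedged sketch of the lookup; the dictionary below is a made-up stand-in for one parsed task entry, not the full MAMMOTH config schema:

```python
# Only the 'use_embeddingless' key is taken from the diff; everything else is hypothetical.
corpus = {"src_tgt": "en-de", "use_embeddingless": True}

use_embeddingless = corpus.get("use_embeddingless", False)   # defaults to False, as in _validate_tasks
print(use_embeddingless)
```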
