diff --git a/README.md b/README.md
index 252f584c..c315eafa 100644
--- a/README.md
+++ b/README.md
@@ -2015,16 +2015,6 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
 }
 ```
 
-```bibtex
-@article{Xie2023ResiDualTW,
-    title   = {ResiDual: Transformer with Dual Residual Connections},
-    author  = {Shufang Xie and Huishuai Zhang and Junliang Guo and Xu Tan and Jiang Bian and Hany Hassan Awadalla and Arul Menezes and Tao Qin and Rui Yan},
-    journal = {ArXiv},
-    year    = {2023},
-    volume  = {abs/2304.14802}
-}
-```
-
 ```bibtex
 @inproceedings{Dehghani2023ScalingVT,
     title   = {Scaling Vision Transformers to 22 Billion Parameters},
diff --git a/x_transformers/x_transformers.py b/x_transformers/x_transformers.py
index cdfff25c..a4292414 100644
--- a/x_transformers/x_transformers.py
+++ b/x_transformers/x_transformers.py
@@ -9,7 +9,7 @@ import torch
 from torch.amp import autocast
 import torch.nn.functional as F
 
-from torch import nn, einsum, Tensor
+from torch import nn, einsum, Tensor, cat, stack, arange
 from torch.utils._pytree import tree_flatten, tree_unflatten
 from torch.nn import Module, ModuleList, ModuleDict
 
@@ -18,14 +18,22 @@
 from contextlib import nullcontext
 from dataclasses import dataclass
 
+from loguru import logger
+
+from x_transformers.attend import Attend, Intermediates
+from x_transformers.autoregressive_wrapper import AutoregressiveWrapper
+
 import einx
 from einops.layers.torch import Rearrange
 from einops import rearrange, repeat, reduce, pack, unpack
 
-from loguru import logger
+# einstein notation
 
-from x_transformers.attend import Attend, Intermediates
-from x_transformers.autoregressive_wrapper import AutoregressiveWrapper
+# b - batch
+# n - sequence
+# d - feature dimension
+# h - attention heads
+# i, j - sequence (source, target)
 
 # constants
 
@@ -220,7 +228,7 @@ def dropout_seq(seq, mask, dropout):
     num_keep = max(1, int(keep_prob * n))
     keep_indices = logits.topk(num_keep, dim = 1).indices
 
-    batch_indices = torch.arange(b, device = device)
+    batch_indices = arange(b, device = device)
     batch_indices = rearrange(batch_indices, 'b -> b 1')
 
     seq = seq[batch_indices, keep_indices]
@@ -228,7 +236,7 @@
     if exists(mask):
         seq_counts = mask.sum(dim = -1)
         seq_keep_counts = torch.ceil(seq_counts * keep_prob).int()
-        keep_mask = torch.arange(num_keep, device = device) < rearrange(seq_keep_counts, 'b -> b 1')
+        keep_mask = arange(num_keep, device = device) < rearrange(seq_keep_counts, 'b -> b 1')
 
         mask = mask[batch_indices, keep_indices] & keep_mask
 
@@ -274,7 +282,7 @@ def forward(self, x, pos = None, seq_start_pos = None):
         assert seq_len <= self.max_seq_len, f'you are passing in a sequence length of {seq_len} but your absolute positional embedding has a max sequence length of {self.max_seq_len}'
 
         if not exists(pos):
-            pos = torch.arange(seq_len, device = device)
+            pos = arange(seq_len, device = device)
 
         if exists(seq_start_pos):
             pos = (pos - seq_start_pos[..., None]).clamp(min = 0)
@@ -290,7 +298,7 @@ def __init__(self, dim, theta = 10000):
         self.scale = nn.Parameter(torch.ones(1) * dim ** -0.5)
 
         half_dim = dim // 2
-        freq_seq = torch.arange(half_dim).float() / half_dim
+        freq_seq = arange(half_dim).float() / half_dim
         inv_freq = theta ** -freq_seq
         self.register_buffer('inv_freq', inv_freq, persistent = False)
 
@@ -298,13 +306,13 @@ def forward(self, x, pos = None, seq_start_pos = None):
         seq_len, device = x.shape[1], x.device
 
         if not exists(pos):
-            pos = torch.arange(seq_len, device = device)
+            pos = arange(seq_len, device = device)
 
         if exists(seq_start_pos):
             pos = pos - seq_start_pos[..., None]
 
         emb = einsum('i, j -> i j', pos, self.inv_freq)
-        emb = torch.cat((emb.sin(), emb.cos()), dim = -1)
+        emb = cat((emb.sin(), emb.cos()), dim = -1)
         return emb * self.scale
 
 class RelativePositionBias(Module):
@@ -344,8 +352,8 @@ def device(self):
 
     def forward(self, i, j):
         device = self.device
-        q_pos = torch.arange(j - i, j, dtype = torch.long, device = device)
-        k_pos = torch.arange(j, dtype = torch.long, device = device)
+        q_pos = arange(j - i, j, dtype = torch.long, device = device)
+        k_pos = arange(j, dtype = torch.long, device = device)
         rel_pos = einx.subtract('j, i -> i j', k_pos, q_pos)
         rp_bucket = self._relative_position_bucket(rel_pos, causal = self.causal, num_buckets = self.num_buckets, max_distance = self.max_distance)
         values = self.relative_attention_bias(rp_bucket)
@@ -376,7 +384,7 @@ def __init__ (
         if not soft_onehot:
             return
 
-        self.register_buffer('positions', torch.arange(max_pos))
+        self.register_buffer('positions', arange(max_pos))
 
     def forward(self, query, attn_logits):
 
@@ -445,13 +453,13 @@ def forward(self, i, j):
         n, device = j, self.device
 
         # get the (n x n) matrix of distances
-        seq_arange = torch.arange(n, device = device)
-        context_arange = torch.arange(n, device = device)
+        seq_arange = arange(n, device = device)
+        context_arange = arange(n, device = device)
         indices = einx.subtract('i, j -> i j', seq_arange, context_arange)
         indices += (n - 1)
 
         # input to continuous positions MLP
-        pos = torch.arange(-n + 1, n, device = device).float()
+        pos = arange(-n + 1, n, device = device).float()
         pos = rearrange(pos, '... -> ... 1')
 
         if self.log_distance:
@@ -525,8 +533,8 @@ def forward(self, i, j):
         if exists(self.bias) and self.bias.shape[-1] >= j and self.bias.shape[-2] >= i:
             return self.bias[..., -i:, -j:]
 
-        seq_arange = torch.arange(j - i, j, device = device)
-        context_arange = torch.arange(j, device = device)
+        seq_arange = arange(j - i, j, device = device)
+        context_arange = arange(j, device = device)
         bias = -einx.subtract('j, i -> 1 i j', context_arange, seq_arange).abs()
 
         bias = bias * self.slopes
@@ -642,7 +650,7 @@ def __init__(
         # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
         base *= base_rescale_factor ** (dim / (dim - 2))
 
-        inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
+        inv_freq = 1. / (base ** (arange(0, dim, 2).float() / dim))
         self.register_buffer('inv_freq', inv_freq)
 
         assert interpolation_factor >= 1.
@@ -652,7 +660,7 @@ def __init__(
             self.register_buffer('scale', None)
             return
 
-        scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
+        scale = (arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
 
         self.scale_base = scale_base
         self.register_buffer('scale', scale)
@@ -660,7 +668,7 @@ def __init__(
     def forward_from_seq_len(self, seq_len):
         device = self.inv_freq.device
 
-        t = torch.arange(seq_len, device = device)
+        t = arange(seq_len, device = device)
         return self.forward(t)
 
     @autocast('cuda', enabled = False)
@@ -671,7 +679,7 @@ def forward(self, t):
             t = rearrange(t, 'n -> 1 n')
 
         freqs = torch.einsum('b i , j -> b i j', t.type_as(self.inv_freq), self.inv_freq) / self.interpolation_factor
-        freqs = torch.stack((freqs, freqs), dim = -1)
+        freqs = stack((freqs, freqs), dim = -1)
         freqs = rearrange(freqs, '... d r -> ... (d r)')
 
         if not exists(self.scale):
@@ -679,7 +687,7 @@ def forward(self, t):
 
         power = (t - (max_pos // 2)) / self.scale_base
         scale = self.scale ** rearrange(power, '... n -> ... n 1')
-        scale = torch.stack((scale, scale), dim = -1)
+        scale = stack((scale, scale), dim = -1)
         scale = rearrange(scale, '... d r -> ... (d r)')
 
         return freqs, scale
@@ -687,7 +695,7 @@ def forward(self, t):
 def rotate_half(x):
     x = rearrange(x, '... (d r) -> ... d r', r = 2)
     x1, x2 = x.unbind(dim = -1)
-    x = torch.stack((-x2, x1), dim = -1)
+    x = stack((-x2, x1), dim = -1)
     return rearrange(x, '... d r -> ... (d r)')
 
 @autocast('cuda', enabled = False)
@@ -703,7 +711,7 @@ def apply_rotary_pos_emb(t, freqs, scale = 1):
     # partial rotary embeddings, Wang et al. GPT-J
     t, t_unrotated = t[..., :rot_dim], t[..., rot_dim:]
     t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale)
-    out = torch.cat((t, t_unrotated), dim = -1)
+    out = cat((t, t_unrotated), dim = -1)
 
     return out.type(orig_dtype)
 
@@ -904,7 +912,7 @@ def __init__(
         init_alpha0 = torch.zeros((num_residual_streams, num_input_views))
         init_alpha0[layer_index % num_residual_streams, :] = 1.
 
-        self.static_alpha = nn.Parameter(torch.cat([init_alpha0, torch.eye(num_residual_streams)], dim = 1))
+        self.static_alpha = nn.Parameter(cat([init_alpha0, torch.eye(num_residual_streams)], dim = 1))
 
         self.dynamic_alpha_fn = nn.Parameter(torch.zeros(dim, num_residual_streams + num_input_views))
         self.dynamic_alpha_scale = nn.Parameter(torch.ones(()) * 1e-2)
@@ -973,7 +981,7 @@ def forward(self, x, **kwargs):
         splitted = x.split(feats_per_shift, dim = -1)
         segments_to_shift, rest = splitted[:segments], splitted[segments:]
         segments_to_shift = [shift(*args, mask = mask) for args in zip(segments_to_shift, shifts)]
-        x = torch.cat((*segments_to_shift, *rest), dim = -1)
+        x = cat((*segments_to_shift, *rest), dim = -1)
         return self.fn(x, **kwargs)
 
 class FoldAxially(Module):
@@ -1080,7 +1088,7 @@ def __init__(self, dim, prev_layer_ind):
 
     def forward(self, x, prev_layers: list[Tensor]):
         skip = prev_layers[self.prev_layer_ind]
-        concatted_skip = torch.cat((skip, x), dim = -1)
+        concatted_skip = cat((skip, x), dim = -1)
         return self.combine(concatted_skip)
 
 # feedforward
@@ -1476,7 +1484,9 @@ def forward(
 
         if self.use_latent_kv:
             assert not qkv_receive_diff_residuals
-            v_input = k_input = self.to_latent_kv(k_input)
+
+            latent_kv_input = self.to_latent_kv(k_input)
+            k_input = v_input = latent_kv_input
 
         # query, key, value projection
 
@@ -1516,12 +1526,12 @@ def forward(
                 mk, k = unpack(k, mem_packed_shape, 'b h * d')
                 mv, v = unpack(v, mem_packed_shape, 'b h * d')
 
-            k = torch.cat((ck, k), dim = -2)
-            v = torch.cat((cv, v), dim = -2)
+            k = cat((ck, k), dim = -2)
+            v = cat((cv, v), dim = -2)
 
             if exists(mem):
-                k = torch.cat((mk, k), dim = -2)
-                v = torch.cat((mv, v), dim = -2)
+                k = cat((mk, k), dim = -2)
+                v = cat((mv, v), dim = -2)
 
         if return_intermediates:
             mem_len = mem.shape[-2] if exists(mem) else 0
@@ -1557,7 +1567,7 @@ def forward(
         elif not exists(input_mask):
             input_mask = pad_at_dim(mem_mask, (0, seq_len), dim = -1, value = True)
         else:
-            input_mask = torch.cat((mem_mask, input_mask), dim = -1)
+            input_mask = cat((mem_mask, input_mask), dim = -1)
 
         # i, j determined for relative positional bias, excluding memory key / values
 
@@ -1572,8 +1582,8 @@ def forward(
                 mem_k = l2norm(mem_k)
                 mem_k = mem_k * self.qk_norm_k_scale
 
-            k = torch.cat((mem_k, k), dim = -2)
-            v = torch.cat((mem_v, v), dim = -2)
+            k = cat((mem_k, k), dim = -2)
+            v = cat((mem_v, v), dim = -2)
 
             if exists(input_mask):
                 input_mask = pad_at_dim(input_mask, (self.num_mem_kv, 0), dim = -1, value = True)
@@ -1597,8 +1607,8 @@ def forward(
                 masks.append(~attn_mask)
 
             if exists(self.max_attend_past):
-                range_q = torch.arange(j - i, j, device = device)
-                range_k = torch.arange(j, device = device)
+                range_q = arange(j - i, j, device = device)
+                range_k = arange(j, device = device)
                 dist = einx.subtract('i, j -> 1 1 i j', range_q, range_k)
                 max_attend_past_mask = dist > self.max_attend_past
                 max_attend_past_mask = pad_at_dim(max_attend_past_mask, (num_mem_kv, 0), value = False, dim = -1) # handle memory key / values
@@ -1756,8 +1766,6 @@ def __init__(
         sandwich_norm = False,
         softclamp_output = False,
         softclamp_output_value = 30.,
-        resi_dual = False,
-        resi_dual_scale = 1.,
         zero_init_branch_output = False,
         layer_dropout = 0.,
         cross_attn_tokens_dropout = 0.,
@@ -1837,19 +1845,11 @@ def __init__(
             assert alibi_num_heads <= heads, 'number of ALiBi heads must be less than the total number of heads'
             self.rel_pos = AlibiPositionalBias(heads = alibi_num_heads, total_heads = heads, **rel_pos_kwargs)
 
-        assert at_most_one_of(sandwich_norm, resi_dual), 'either sandwich norm or resiDual is selected, but not both'
         assert not (not pre_norm and sandwich_norm), 'sandwich norm cannot be used when not using prenorm'
 
-        if resi_dual:
-            pre_norm = False
-
         self.pre_norm = pre_norm
         self.sandwich_norm = sandwich_norm
 
-        self.resi_dual = resi_dual
-        assert 0 < resi_dual_scale <= 1., 'resiDual prenorm residual must be scaled by a factor greater than 0 and less than or equal to 1.'
-        self.resi_dual_scale = resi_dual_scale
-
         self.residual_attn = residual_attn
         self.cross_residual_attn = cross_residual_attn
         assert not (flash_attn and (residual_attn or cross_residual_attn)), 'flash attention is not compatible with residual attention'
@@ -2008,7 +2008,7 @@ def __init__(
 
         # whether it has post norm
 
-        self.final_norm = norm_fn() if pre_norm or resi_dual else nn.Identity()
+        self.final_norm = norm_fn() if pre_norm else nn.Identity()
 
         # whether unet or not
 
@@ -2181,7 +2181,7 @@ def forward(
         # handle left padded sequences
 
        if exists(seq_start_pos):
-            seq_arange = torch.arange(x.shape[-2], device = x.device, dtype = torch.long)
+            seq_arange = arange(x.shape[-2], device = x.device, dtype = torch.long)
             left_pad_mask = seq_arange >= seq_start_pos[..., None]
 
             if exists(self_attn_kv_mask):
@@ -2199,7 +2199,7 @@ def forward(
             mem_len = maybe_mem.shape[1] if exists(maybe_mem) else 0
 
             if not exists(pos):
-                pos = torch.arange(x.shape[1] + mem_len, device = x.device) - mem_len
+                pos = arange(x.shape[1] + mem_len, device = x.device) - mem_len
 
             rotary_pos_emb = self.rotary_pos_emb(pos)
 
@@ -2241,10 +2241,6 @@ def forward(
             x = x + self.stream_emb
             x = rearrange(x, 'b n s d -> (b s) n d')
 
-        # outer residual - for resiDual paper
-
-        outer_residual = x * self.resi_dual_scale
-
         # get layers to be executed
 
         layer_variables = (
@@ -2365,9 +2361,6 @@ def forward(
             if not exists(first_cross_attn_inter) and layer_type == 'c':
                 first_cross_attn_inter = inter
 
-            if self.resi_dual:
-                outer_residual = outer_residual + out * self.resi_dual_scale
-
             if exists(post_branch_norm):
                 out = post_branch_norm(out)
 
@@ -2401,10 +2394,7 @@ def forward(
         if is_multistream:
             x = reduce(x, '(b s) n d -> b n d', 'sum', s = streams)
 
-        if self.resi_dual:
-            x = x + final_norm(outer_residual)
-        else:
-            x = final_norm(x)
+        x = final_norm(x)
 
         if not return_hiddens:
             return x
@@ -2450,7 +2440,7 @@ def forward(
             if isinstance(prefix_attn_len, int):
                 prefix_attn_len = torch.full((b,), prefix_attn_len, device = device)
 
-            prefix_mask = torch.arange(n, device = device) < rearrange(prefix_attn_len, 'b -> b 1 1 1')
+            prefix_mask = arange(n, device = device) < rearrange(prefix_attn_len, 'b -> b 1 1 1')
             forwarded_mask = forwarded_mask | prefix_mask
 
         if exists(attn_mask):
@@ -2779,13 +2769,13 @@ def forward(
             prepend_seq, prepend_dim = prepend_embeds.shape[1:]
             assert prepend_dim == x.shape[-1], 'prepended embeddings need to have same dimensions as text model dimensions'
 
-            x = torch.cat((prepend_embeds, x), dim = -2)
+            x = cat((prepend_embeds, x), dim = -2)
 
             if exists(prepend_mask) or exists(mask):
                 mask = default(mask, lambda: torch.ones((b, n), device = device, dtype = torch.bool))
                 prepend_mask = default(prepend_mask, lambda: torch.ones((b, prepend_seq), device = device, dtype = torch.bool))
 
-                mask = torch.cat((prepend_mask, mask), dim = -1)
+                mask = cat((prepend_mask, mask), dim = -1)
 
         # whether to reduce the gradient going to the embedding, from cogview paper, corroborated by GLM-130B model
 
@@ -2951,7 +2941,7 @@ def forward(
 
         if return_mems:
             hiddens = intermediates.hiddens
-            new_mems = [torch.cat(pair, dim = -2) for pair in zip(mems, hiddens)] if exists(mems) else hiddens
+            new_mems = [cat(pair, dim = -2) for pair in zip(mems, hiddens)] if exists(mems) else hiddens
             new_mems = [t[..., -self.max_mem_len:, :].detach() for t in new_mems]
 
             if not return_intermediates: