diff --git a/examples/enwik8_simple/train.py b/examples/enwik8_simple/train.py
index 339e4229..8f702ee3 100644
--- a/examples/enwik8_simple/train.py
+++ b/examples/enwik8_simple/train.py
@@ -39,7 +39,12 @@ def decode_tokens(tokens):
 model = TransformerWrapper(
     num_tokens = 256,
     max_seq_len = SEQ_LEN,
-    attn_layers = Decoder(dim = 512, depth = 6, heads = 8)
+    attn_layers = Decoder(
+        dim = 512,
+        depth = 6,
+        heads = 8,
+        rotary_pos_emb = True
+    )
 )
 
 model = AutoregressiveWrapper(model)
@@ -101,6 +106,11 @@ def __len__(self):
         prime = decode_tokens(inp)
         print(f'%s \n\n %s', (prime, '*' * 100))
 
-        sample = model.generate(inp, GENERATE_LENGTH)
+        sample = model.generate(
+            prompts = inp,
+            seq_len = GENERATE_LENGTH,
+            cache_kv = True
+        )
+
         output_str = decode_tokens(sample)
         print(output_str)
diff --git a/setup.py b/setup.py
index 80ccd18a..bfcd3b5b 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'x-transformers',
   packages = find_packages(exclude=['examples']),
-  version = '1.26.4',
+  version = '1.26.6',
   license='MIT',
   description = 'X-Transformers - Pytorch',
   author = 'Phil Wang',
diff --git a/x_transformers/autoregressive_wrapper.py b/x_transformers/autoregressive_wrapper.py
index 8bb4d85a..41bac84c 100644
--- a/x_transformers/autoregressive_wrapper.py
+++ b/x_transformers/autoregressive_wrapper.py
@@ -188,6 +188,10 @@ def generate(
         for _ in range(seq_len):
 
             if restrict_to_max_seq_len:
+                max_len_exceeded = out.shape[-1] > max_seq_len
+
+                assert not (cache_kv and max_len_exceeded and not self.net.can_cache_kv_outside_max_seq_len), 'the network cannot use cached key values when decoding outside the max sequence length. most likely because you are using absolute positional embedding. you can switch to rotary embeddings to resolve this issue'
+
                 x = out[:, -max_seq_len:]
 
                 if exists(cache):
diff --git a/x_transformers/x_transformers.py b/x_transformers/x_transformers.py
index e000198d..981a3dff 100644
--- a/x_transformers/x_transformers.py
+++ b/x_transformers/x_transformers.py
@@ -1495,7 +1495,9 @@ def __init__(
         self.l2norm_embed = l2norm_embed
         self.token_emb = TokenEmbedding(emb_dim, num_tokens, l2norm_embed = l2norm_embed)
 
-        if max_seq_len == 0 or not (use_abs_pos_emb and not attn_layers.has_pos_emb):
+        no_abs_pos_emb = max_seq_len == 0 or not (use_abs_pos_emb and not attn_layers.has_pos_emb)
+
+        if no_abs_pos_emb:
             self.pos_emb = always(0)
         elif scaled_sinu_pos_emb:
             self.pos_emb = ScaledSinusoidalEmbedding(emb_dim)
@@ -1536,6 +1538,7 @@ def __init__(
         # whether can do cached kv decoding
 
         self.can_cache_kv = self.num_memory_tokens == 0
+        self.can_cache_kv_outside_max_seq_len = no_abs_pos_emb
 
     def init_(self):
         if self.l2norm_embed:
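
What the change enables, as a minimal usage sketch (the hyperparameters, the prompt tensor, and the generation length here are illustrative, not taken from the diff): with rotary_pos_emb = True there is no absolute positional embedding, so TransformerWrapper sets can_cache_kv_outside_max_seq_len and generate(..., cache_kv = True) may decode past max_seq_len without tripping the new assert.

import torch
from x_transformers import TransformerWrapper, Decoder, AutoregressiveWrapper

# rotary positions instead of absolute ones, so the wrapper records
# can_cache_kv_outside_max_seq_len = True
model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,                    # illustrative value
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        rotary_pos_emb = True
    )
)

model = AutoregressiveWrapper(model)

prompt = torch.randint(0, 256, (1, 1))     # hypothetical single-token prompt

# decode well past max_seq_len with the kv cache enabled; with absolute
# positional embeddings this combination would now raise the new assertion
sample = model.generate(
    prompts = prompt,
    seq_len = 2048,
    cache_kv = True
)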