Commit

fix unpacking error (#507)
In the MHA module, each key_value_memory_dict entry is a tuple of (kv_cache, conv1d state), so it has to be unpacked before kv_cache is sliced by batch.
chiennv2000 authored Aug 6, 2024
1 parent 49ddf83 commit a07ff1b
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion mamba_ssm/modules/mha.py
@@ -180,7 +180,8 @@ def _update_kvcache_attention(self, q, kv, inference_params):
             ).transpose(1, 2)
         else:
             batch = q.shape[0]
-            kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
+            kv_cache, _ = inference_params.key_value_memory_dict[self.layer_idx]
+            kv_cache = kv_cache[:batch]
             cache_seqlens = (
                 inference_params.lengths_per_sample[:batch]
                 if inference_params.lengths_per_sample is not None
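
To make the change concrete, here is a minimal sketch of why slicing the dictionary entry directly goes wrong when the entry is a tuple. It is not the library's code: nested lists stand in for tensors, and the sizes (max_batch, batch, row widths) and the name conv_state are assumptions chosen for illustration. Only key_value_memory_dict, layer_idx, kv_cache, and batch come from the diff.

```python
# Hypothetical stand-ins: nested lists play the role of the cached tensors.
max_batch = 4
kv_cache = [[0.0] * 8 for _ in range(max_batch)]    # one row per batch slot
conv_state = [[0.0] * 3 for _ in range(max_batch)]  # assumed second tuple member

layer_idx = 0
key_value_memory_dict = {layer_idx: (kv_cache, conv_state)}

batch = 3

# Before the fix: the slice is applied to the tuple itself, so the result is
# still a tuple of the cached objects, not a batch-sliced kv_cache.
old = key_value_memory_dict[layer_idx][:batch]
print(type(old).__name__, len(old))

# After the fix: unpack the tuple first, then slice kv_cache along the batch axis.
new_cache, _ = key_value_memory_dict[layer_idx]
new_cache = new_cache[:batch]
print(type(new_cache).__name__, len(new_cache))
```

In this sketch the pre-fix expression yields a tuple whose length depends on the tuple size rather than on batch, while the post-fix version yields the first batch rows of the cache. Any later code that expects a tensor would presumably fail on the tuple, which matches the "unpacking error" named in the commit title.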