From 03a38fb27e969750074e3cabc451f5f85fea2b17 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sun, 30 Jun 2024 17:32:13 -0700 Subject: [PATCH] Implement varlen generation --- README.md | 2 +- mamba_ssm/models/mixer_seq_simple.py | 2 +- mamba_ssm/modules/mamba2.py | 46 ++++++--- mamba_ssm/ops/triton/ssd_chunk_state.py | 120 ++++++++++++++++++++++++ mamba_ssm/ops/triton/ssd_combined.py | 32 +++++-- setup.py | 2 +- tests/ops/triton/test_ssd.py | 78 +++++++++++++++ tests/test_generation.py | 113 ++++++++++++++++++++++ 8 files changed, 372 insertions(+), 23 deletions(-) create mode 100644 tests/ops/triton/test_ssd.py create mode 100644 tests/test_generation.py diff --git a/README.md b/README.md index d880bf32..450e8730 100755 --- a/README.md +++ b/README.md @@ -19,7 +19,7 @@ with an efficient hardware-aware design and implementation in the spirit of [Fla ## Installation -- [Option] `pip install causal-conv1d>=1.2.0`: an efficient implementation of a simple causal Conv1d layer used inside the Mamba block. +- [Option] `pip install causal-conv1d>=1.4.0`: an efficient implementation of a simple causal Conv1d layer used inside the Mamba block. - `pip install mamba-ssm`: the core Mamba package. It can also be built from source with `pip install .` from this repository. diff --git a/mamba_ssm/models/mixer_seq_simple.py b/mamba_ssm/models/mixer_seq_simple.py index 4be57e08..fae2257a 100644 --- a/mamba_ssm/models/mixer_seq_simple.py +++ b/mamba_ssm/models/mixer_seq_simple.py @@ -192,7 +192,7 @@ def forward(self, input_ids, inference_params=None, **mixer_kwargs): residual = None for layer in self.layers: hidden_states, residual = layer( - hidden_states, residual, inference_params=inference_params + hidden_states, residual, inference_params=inference_params, **mixer_kwargs ) if not self.fused_add_norm: residual = (hidden_states + residual) if residual is not None else hidden_states diff --git a/mamba_ssm/modules/mamba2.py b/mamba_ssm/modules/mamba2.py index 15732f88..854ad0a8 100644 --- a/mamba_ssm/modules/mamba2.py +++ b/mamba_ssm/modules/mamba2.py @@ -13,6 +13,11 @@ except ImportError: causal_conv1d_fn, causal_conv1d_update = None, None +try: + from causal_conv1d.causal_conv1d_varlen import causal_conv1d_varlen_states +except ImportError: + causal_conv1d_varlen_states = None + try: from mamba_ssm.ops.triton.selective_state_update import selective_state_update except ImportError: @@ -144,7 +149,7 @@ def __init__( process_group=self.process_group, sequence_parallel=self.sequence_parallel, **factory_kwargs) - def forward(self, u, seqlen=None, seq_idx=None, inference_params=None): + def forward(self, u, seqlen=None, seq_idx=None, cu_seqlens=None, inference_params=None): """ u: (batch, seqlen, hidden_dim) if seqlen=None. If seqlen is not None, u is (batch * seqlen, hidden_dim). 
This is so that when we @@ -161,7 +166,8 @@ def forward(self, u, seqlen=None, seq_idx=None, inference_params=None): conv_state, ssm_state = None, None if inference_params is not None: - conv_state, ssm_state = self._get_states_from_cache(inference_params, batch) + inference_batch = cu_seqlens.shape[0] - 1 if cu_seqlens is not None else batch + conv_state, ssm_state = self._get_states_from_cache(inference_params, inference_batch) if inference_params.seqlen_offset > 0: # The states are updated inplace out, _, _ = self.step(u, conv_state, ssm_state) @@ -206,14 +212,22 @@ def forward(self, u, seqlen=None, seq_idx=None, inference_params=None): dim=-1 ) if conv_state is not None: - # If we just take xBC[:, :, -self.d_conv :], it will error if seqlen < self.d_conv - # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise. - xBC_t = rearrange(xBC, "b l d -> b d l") - conv_state.copy_(F.pad(xBC_t, (self.d_conv - xBC_t.shape[-1], 0))) # Update state (B D W) + if cu_seqlens is None: + # If we just take xBC[:, :, -self.d_conv :], it will error if seqlen < self.d_conv + # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise. + xBC_t = rearrange(xBC, "b l d -> b d l") + conv_state.copy_(F.pad(xBC_t, (self.d_conv - xBC_t.shape[-1], 0))) # Update state (B D W) + else: + assert causal_conv1d_varlen_states is not None, "varlen inference requires causal_conv1d package" + assert batch == 1, "varlen inference only supports batch dimension 1" + conv_varlen_states = causal_conv1d_varlen_states( + xBC.squeeze(0), cu_seqlens, state_len=conv_state.shape[-1] + ) + conv_state.copy_(conv_varlen_states) assert self.activation in ["silu", "swish"] if causal_conv1d_fn is None or self.activation not in ["silu", "swish"]: xBC = self.act( - self.conv1d(xBC.transpose(1, 2)).transpose(1, 2) + self.conv1d(xBC.transpose(1, 2)).transpose(1, 2)[:, :-(self.d_conv - 1)] ) # (B, L, self.d_ssm + 2 * ngroups * d_state) else: xBC = causal_conv1d_fn( @@ -235,12 +249,18 @@ def forward(self, u, seqlen=None, seq_idx=None, inference_params=None): dt_bias=self.dt_bias, dt_softplus=True, seq_idx=seq_idx, + cu_seqlens=cu_seqlens, **dt_limit_kwargs, return_final_states=ssm_state is not None, + return_varlen_states=cu_seqlens is not None and inference_params is not None, ) if ssm_state is not None: - y, last_state = y - ssm_state.copy_(last_state) + y, last_state, *rest = y + if cu_seqlens is None: + ssm_state.copy_(last_state) + else: + varlen_states = rest[0] + ssm_state.copy_(varlen_states) y = rearrange(y, "b l h p -> b l (h p)") if self.rmsnorm: y = self.norm(y, z) @@ -322,8 +342,8 @@ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs) device = self.out_proj.weight.device conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype conv_state = torch.zeros( - batch_size, self.conv1d.weight.shape[0], self.d_conv, device=device, dtype=conv_dtype - ) + batch_size, self.d_conv, self.conv1d.weight.shape[0], device=device, dtype=conv_dtype + ).transpose(1, 2) ssm_dtype = self.in_proj.weight.dtype if dtype is None else dtype ssm_state = torch.zeros( batch_size, self.nheads, self.headdim, self.d_state, device=device, dtype=ssm_dtype @@ -336,11 +356,11 @@ def _get_states_from_cache(self, inference_params, batch_size, initialize_states batch_shape = (batch_size,) conv_state = torch.zeros( batch_size, - self.conv1d.weight.shape[0], self.d_conv, + self.d_conv, + self.conv1d.weight.shape[0], device=self.conv1d.weight.device, dtype=self.conv1d.weight.dtype, - ) + ).transpose(1, 2) 
ssm_state = torch.zeros( batch_size, self.nheads, diff --git a/mamba_ssm/ops/triton/ssd_chunk_state.py b/mamba_ssm/ops/triton/ssd_chunk_state.py index 0c23f327..c4971c5f 100644 --- a/mamba_ssm/ops/triton/ssd_chunk_state.py +++ b/mamba_ssm/ops/triton/ssd_chunk_state.py @@ -571,6 +571,97 @@ def _chunk_state_bwd_ddAcs_stable_kernel( tl.atomic_add(ddA_cumsum_ptrs + stride_ddA_cs_csize, ddA_cs, mask=offs_m < chunk_size - 1) +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=2), + ], + key=['hdim', 'dstate', 'chunk_size'], +) +@triton.jit +def _chunk_state_varlen_kernel( + # Pointers to matrices + x_ptr, b_ptr, dt_ptr, dA_cumsum_ptr, chunk_states_ptr, cu_seqlens_ptr, states_ptr, + # Matrix dimensions + hdim, dstate, chunk_size, + seqlen, nheads_ngroups_ratio, + # Strides + stride_x_seqlen, stride_x_head, stride_x_hdim, + stride_b_seqlen, stride_b_head, stride_b_dstate, + stride_dt_chunk, stride_dt_head, stride_dt_csize, + stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize, + stride_chunk_states_chunk, stride_chunk_states_head, stride_chunk_states_hdim, stride_chunk_states_dstate, + stride_states_batch, stride_states_head, stride_states_hdim, stride_states_dstate, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, +): + pid_b = tl.program_id(axis=1) + pid_h = tl.program_id(axis=2) + num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N) + pid_m = tl.program_id(axis=0) // num_pid_n + pid_n = tl.program_id(axis=0) % num_pid_n + end_idx = tl.load(cu_seqlens_ptr + pid_b + 1) + pid_c = (end_idx - 1) // chunk_size + b_ptr += pid_c * chunk_size * stride_b_seqlen + (pid_h // nheads_ngroups_ratio) * stride_b_head + x_ptr += pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head + dt_ptr += pid_c * stride_dt_chunk + pid_h * stride_dt_head + dA_cumsum_ptr += pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head + chunk_states_ptr += pid_c * stride_chunk_states_chunk + pid_h * stride_chunk_states_head + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + x_ptrs = x_ptr + (offs_m[:, None] * stride_x_hdim + offs_k[None, :] * stride_x_seqlen) + b_ptrs = b_ptr + (offs_n[None, :] * stride_b_dstate + offs_k[:, None] * stride_b_seqlen) + dt_ptrs = dt_ptr + offs_k * stride_dt_csize + dA_cs_last = tl.load(dA_cumsum_ptr + (end_idx - pid_c * chunk_size - 1) * stride_dA_cs_csize).to(tl.float32) + dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize + + chunk_size_limit = end_idx - pid_c * chunk_size + start_idx = 
tl.load(cu_seqlens_ptr + pid_b) + start_idx_cur = tl.maximum(start_idx - pid_c * chunk_size, 0) + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, chunk_size_limit, BLOCK_SIZE_K): + x = tl.load(x_ptrs, mask=(offs_m[:, None] < hdim) & (offs_k[None, :] < chunk_size_limit - k) & (offs_k[None, :] >= start_idx_cur - k), other=0.0) + b = tl.load(b_ptrs, mask=(offs_k[:, None] < chunk_size_limit - k) & (offs_n[None, :] < dstate) & (offs_k[:, None] >= start_idx_cur - k), other=0.0).to(tl.float32) + dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(tl.float32) + dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(tl.float32) + scale = tl.where((offs_k >= start_idx_cur - k) & (offs_k < chunk_size_limit - k), + tl.exp((dA_cs_last - dA_cs_k)) * dt_k, 0.0) + b *= scale[:, None] + b = b.to(x_ptr.dtype.element_ty) + acc += tl.dot(x, b) + x_ptrs += BLOCK_SIZE_K * stride_x_seqlen + b_ptrs += BLOCK_SIZE_K * stride_b_seqlen + dt_ptrs += BLOCK_SIZE_K * stride_dt_csize + dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize + + # If the sequence starts after the last chunk idx, we don't need to add the contribution from the last chunk + if start_idx < pid_c * chunk_size: + chunk_states_ptrs = chunk_states_ptr + (offs_m[:, None] * stride_chunk_states_hdim + offs_n[None, :] * stride_chunk_states_dstate) + chunk_states = tl.load(chunk_states_ptrs, mask=(offs_m[:, None] < hdim) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32) + # scale = tl.where(start_idx < pid_c * chunk_size, tl.exp(dA_cs_last), 0.0) + scale = tl.exp(dA_cs_last) + acc += chunk_states * scale + + states = acc.to(states_ptr.dtype.element_ty) + + states_ptr += pid_b * stride_states_batch + pid_h * stride_states_head + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + states_ptrs = states_ptr + (offs_m[:, None] * stride_states_hdim + offs_n[None, :] * stride_states_dstate) + c_mask = (offs_m[:, None] < hdim) & (offs_n[None, :] < dstate) + tl.store(states_ptrs, states, mask=c_mask) + + def _chunk_cumsum_fwd(dt, A, chunk_size, dt_bias=None, dt_softplus=False, dt_limit=(0.0, float("inf"))): batch, seqlen, nheads = dt.shape assert A.shape == (nheads,) @@ -790,6 +881,35 @@ def _chunk_state_bwd_ddAcs_stable(B, x, dt, dA_cumsum, dstates, seq_idx=None): return ddA_cumsum +def chunk_state_varlen(B, x, dt, dA_cumsum, cu_seqlens, chunk_states): + total_seqlen, nheads, headdim = x.shape + _, nchunks, chunk_size = dt.shape + _, ngroups, dstate = B.shape + batch = cu_seqlens.shape[0] - 1 + cu_seqlens = cu_seqlens.contiguous() + assert nheads % ngroups == 0 + assert B.shape == (total_seqlen, ngroups, dstate) + assert dt.shape == (nheads, nchunks, chunk_size) + assert dA_cumsum.shape == dt.shape + assert chunk_states.shape == (nchunks, nheads, headdim, dstate) + states = torch.empty(batch, nheads, headdim, dstate, dtype=chunk_states.dtype, device=chunk_states.device) + grid = lambda META: (triton.cdiv(headdim, META['BLOCK_SIZE_M']) * triton.cdiv(dstate, META['BLOCK_SIZE_N']), + batch, nheads) + with torch.cuda.device(x.device.index): + _chunk_state_varlen_kernel[grid]( + x, B, dt, dA_cumsum, chunk_states, cu_seqlens, states, + headdim, dstate, chunk_size, + total_seqlen, nheads // ngroups, + x.stride(0), x.stride(1), x.stride(2), + B.stride(0), B.stride(1), B.stride(2), + dt.stride(1), dt.stride(0), dt.stride(2), + dA_cumsum.stride(1), dA_cumsum.stride(0), dA_cumsum.stride(2), + 
chunk_states.stride(0), chunk_states.stride(1), chunk_states.stride(2), chunk_states.stride(3), + states.stride(0), states.stride(1), states.stride(2), states.stride(3), + ) + return states + + class ChunkStateFn(torch.autograd.Function): @staticmethod diff --git a/mamba_ssm/ops/triton/ssd_combined.py b/mamba_ssm/ops/triton/ssd_combined.py index df218a73..1305cfb4 100644 --- a/mamba_ssm/ops/triton/ssd_combined.py +++ b/mamba_ssm/ops/triton/ssd_combined.py @@ -29,6 +29,7 @@ from mamba_ssm.ops.triton.ssd_chunk_state import _chunk_state_fwd, _chunk_state_bwd_db from mamba_ssm.ops.triton.ssd_chunk_state import _chunk_state_bwd_ddAcs_stable from mamba_ssm.ops.triton.ssd_chunk_state import chunk_state, chunk_state_ref +from mamba_ssm.ops.triton.ssd_chunk_state import chunk_state_varlen from mamba_ssm.ops.triton.ssd_state_passing import _state_passing_fwd, _state_passing_bwd from mamba_ssm.ops.triton.ssd_state_passing import state_passing, state_passing_ref from mamba_ssm.ops.triton.ssd_chunk_scan import _chunk_scan_fwd, _chunk_scan_bwd_dz, _chunk_scan_bwd_dstates @@ -277,7 +278,7 @@ def _chunk_scan_chunk_state_bwd_dx(x, dt, dA_cumsum, B, CB, dout, dstates, D=Non return dx, ddt.to(dtype=dt.dtype), dD -def _mamba_chunk_scan_combined_fwd(x, dt, A, B, C, chunk_size, D=None, z=None, dt_bias=None, initial_states=None, seq_idx=None, dt_softplus=False, dt_limit=(0.0, float("inf"))): +def _mamba_chunk_scan_combined_fwd(x, dt, A, B, C, chunk_size, D=None, z=None, dt_bias=None, initial_states=None, seq_idx=None, cu_seqlens=None, dt_softplus=False, dt_limit=(0.0, float("inf"))): batch, seqlen, nheads, headdim = x.shape _, _, ngroups, dstate = B.shape assert nheads % ngroups == 0 @@ -321,7 +322,13 @@ def _mamba_chunk_scan_combined_fwd(x, dt, A, B, C, chunk_size, D=None, z=None, d # states_tmp1 = rearrange(_state_passing_fwd(rearrange(states_tmp1, "... p n -> ... (p n)"), dA_cumsum_tmp1[:, :, :, -1], chunk_size=chunk_size), "... (p n) -> ... 
p n", n=dstate) CB = _bmm_chunk_fwd(C, B, chunk_size, seq_idx=seq_idx, output_dtype=torch.float32) out, out_x = _chunk_scan_fwd(CB, x, dt, dA_cumsum, C, states, D=D, z=z, seq_idx=seq_idx) - return out, out_x, dt, dA_cumsum, states, final_states + if cu_seqlens is None: + return out, out_x, dt, dA_cumsum, states, final_states + else: + assert batch == 1, "passing cu_seqlens to get the varlen states is only supported if batch dimension is 1" + varlen_states = chunk_state_varlen(B.squeeze(0), x.squeeze(0), dt.squeeze(0), dA_cumsum.squeeze(0), + cu_seqlens, states.squeeze(0)) + return out, out_x, dt, dA_cumsum, states, final_states, varlen_states def _mamba_chunk_scan_combined_bwd(dout, x, dt, A, B, C, out, chunk_size, D=None, z=None, @@ -524,25 +531,35 @@ def selective_scan_bwd(dout, x, dt, A, B, C, D=None, z=None): class MambaChunkScanCombinedFn(torch.autograd.Function): @staticmethod - def forward(ctx, x, dt, A, B, C, chunk_size, D=None, z=None, dt_bias=None, initial_states=None, seq_idx=None, dt_softplus=False, dt_limit=(0.0, float("inf")), return_final_states=False): + def forward(ctx, x, dt, A, B, C, chunk_size, D=None, z=None, dt_bias=None, initial_states=None, seq_idx=None, cu_seqlens=None, dt_softplus=False, dt_limit=(0.0, float("inf")), return_final_states=False, return_varlen_states=False): ctx.dt_dtype = dt.dtype - out, out_x, dt_out, dA_cumsum, states, final_states = _mamba_chunk_scan_combined_fwd(x, dt, A, B, C, chunk_size, D=D, z=z, dt_bias=dt_bias, initial_states=initial_states, seq_idx=seq_idx, dt_softplus=dt_softplus, dt_limit=dt_limit) + if not return_varlen_states: + cu_seqlens = None + else: + assert cu_seqlens is not None, "cu_seqlens must be provided if return_varlen_states is True" + out, out_x, dt_out, dA_cumsum, states, final_states, *rest = _mamba_chunk_scan_combined_fwd(x, dt, A, B, C, chunk_size, D=D, z=z, dt_bias=dt_bias, initial_states=initial_states, seq_idx=seq_idx, cu_seqlens=cu_seqlens, dt_softplus=dt_softplus, dt_limit=dt_limit) ctx.save_for_backward(out if z is None else out_x, x, dt, dA_cumsum, A, B, C, D, z, dt_bias, initial_states, seq_idx) ctx.dt_softplus = dt_softplus ctx.chunk_size = chunk_size ctx.dt_limit = dt_limit ctx.return_final_states = return_final_states - return out if not return_final_states else (out, final_states) + ctx.return_varlen_states = return_varlen_states + if not return_varlen_states: + return out if not return_final_states else (out, final_states) + else: + varlen_states = rest[0] + return (out, varlen_states) if not return_final_states else (out, final_states, varlen_states) @staticmethod def backward(ctx, dout, *args): out, x, dt, dA_cumsum, A, B, C, D, z, dt_bias, initial_states, seq_idx = ctx.saved_tensors + assert not ctx.return_varlen_states, "return_varlen_states is not supported in backward" dfinal_states = args[0] if ctx.return_final_states else None dx, ddt, dA, dB, dC, dD, dz, ddt_bias, dinitial_states = _mamba_chunk_scan_combined_bwd(dout, x, dt, A, B, C, out, ctx.chunk_size, D=D, z=z, dt_bias=dt_bias, initial_states=initial_states, dfinal_states=dfinal_states, seq_idx=seq_idx, dt_softplus=ctx.dt_softplus, dt_limit=ctx.dt_limit) return dx, ddt, dA, dB, dC, None, dD, dz, ddt_bias, dinitial_states, None, None, None, None -def mamba_chunk_scan_combined(x, dt, A, B, C, chunk_size, D=None, z=None, dt_bias=None, initial_states=None, seq_idx=None, dt_softplus=False, dt_limit=(0.0, float("inf")), return_final_states=False): +def mamba_chunk_scan_combined(x, dt, A, B, C, chunk_size, D=None, z=None, dt_bias=None, 
initial_states=None, seq_idx=None, cu_seqlens=None, dt_softplus=False, dt_limit=(0.0, float("inf")), return_final_states=False, return_varlen_states=False): """ Argument: x: (batch, seqlen, nheads, headdim) @@ -556,11 +573,12 @@ def mamba_chunk_scan_combined(x, dt, A, B, C, chunk_size, D=None, z=None, dt_bia dt_bias: (nheads,) initial_states: (batch, nheads, headdim, dstate) seq_idx: (batch, seqlen) + cu_seqlens: (num_sequences + 1) or None, only used if return_varlen_states is True dt_softplus: Whether to apply softplus to dt Return: out: (batch, seqlen, nheads, headdim) """ - return MambaChunkScanCombinedFn.apply(x, dt, A, B, C, chunk_size, D, z, dt_bias, initial_states, seq_idx, dt_softplus, dt_limit, return_final_states) + return MambaChunkScanCombinedFn.apply(x, dt, A, B, C, chunk_size, D, z, dt_bias, initial_states, seq_idx, cu_seqlens, dt_softplus, dt_limit, return_final_states, return_varlen_states) def mamba_chunk_scan(x, dt, A, B, C, chunk_size, D=None, z=None, dt_bias=None, dt_softplus=False): diff --git a/setup.py b/setup.py index 1f718bab..58772cb8 100755 --- a/setup.py +++ b/setup.py @@ -375,6 +375,6 @@ def run(self): "einops", "triton", "transformers", - # "causal_conv1d>=1.2.0", + # "causal_conv1d>=1.4.0", ], ) diff --git a/tests/ops/triton/test_ssd.py b/tests/ops/triton/test_ssd.py new file mode 100644 index 00000000..d45152d6 --- /dev/null +++ b/tests/ops/triton/test_ssd.py @@ -0,0 +1,78 @@ +import math + +import torch +import torch.nn.functional as F + +import pytest + +from einops import rearrange, repeat + +from mamba_ssm.ops.triton.ssd_chunk_state import chunk_state, chunk_state_ref +from mamba_ssm.ops.triton.ssd_chunk_state import _chunk_cumsum_fwd, _chunk_state_fwd +from mamba_ssm.ops.triton.ssd_chunk_state import chunk_state_varlen +from mamba_ssm.ops.triton.ssd_state_passing import state_passing, state_passing_ref +from mamba_ssm.ops.triton.ssd_state_passing import _state_passing_fwd +from mamba_ssm.ops.triton.ssd_chunk_scan import chunk_scan, chunk_scan_ref +from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined, mamba_chunk_scan, ssd_chunk_scan_combined_ref, ssd_selective_scan +from mamba_ssm.ops.triton.ssd_combined import mamba_split_conv1d_scan_combined, mamba_split_conv1d_scan_ref + + +def detach_clone(*args): + return tuple([arg.detach().clone().requires_grad_() if arg is not None else None for arg in args]) + + +@pytest.mark.parametrize('dtype', [torch.float32, torch.float16, torch.bfloat16]) +# @pytest.mark.parametrize('dtype', [torch.bfloat16]) +@pytest.mark.parametrize('ngroups', [1, 2, 8, "max"]) +# @pytest.mark.parametrize('ngroups', [1]) +@pytest.mark.parametrize('chunk_size', [64, 128]) +# @pytest.mark.parametrize('chunk_size', [128]) +def test_chunk_state_varlen(chunk_size, ngroups, dtype): + device = 'cuda' + rtol, atol = (1e-2, 3e-3) + # set seed + torch.random.manual_seed(chunk_size + (ngroups if ngroups != "max" else 64)) + batch = 300 + seqlens = torch.randint(1, 200, (batch,), device=device) + # batch = 3 + # seqlens = torch.tensor([201, 56, 5], device=device) + cu_seqlens = F.pad(seqlens.cumsum(0), (1, 0)) + total_seqlen = seqlens.sum().item() + seq_idx = torch.cat([torch.full((s,), i, dtype=torch.int32, device=device) for i, s in enumerate(seqlens)], dim=0).unsqueeze(0) + dim = 4096 + # dim = 64 + headdim = 64 + # dim = 32 + dstate = 32 + assert dim % headdim == 0 + nheads = dim // headdim + if ngroups == "max": + ngroups = nheads + assert nheads % ngroups == 0 + B = torch.randn(total_seqlen, ngroups, dstate, dtype=dtype, 
device=device) / 5 + x = torch.randn(total_seqlen, nheads, headdim, dtype=dtype, device=device) + A = -0.1 * (torch.rand(nheads, device=device)) + dt = F.softplus(torch.randn(total_seqlen, nheads, device=device, dtype=torch.float32) - 4) + dA_cumsum, dt_rounded = _chunk_cumsum_fwd(dt.unsqueeze(0), A, chunk_size) + chunk_states = _chunk_state_fwd(B.unsqueeze(0), x.unsqueeze(0), dt_rounded, dA_cumsum, seq_idx=seq_idx) + chunk_states, _ = _state_passing_fwd(rearrange(chunk_states, "... p n -> ... (p n)"), dA_cumsum[:, :, :, -1], + seq_idx=seq_idx, chunk_size=chunk_size) + chunk_states = rearrange(chunk_states, "... (p n) -> ... p n", n=dstate) + chunk_states = chunk_states.squeeze(0) + dA_cumsum = dA_cumsum.squeeze(0) + dt_rounded = dt_rounded.squeeze(0) + out = chunk_state_varlen(B, x, dt_rounded, dA_cumsum, cu_seqlens, chunk_states) + out_ref = [] + for b in range(batch): + x_s = x[cu_seqlens[b]:cu_seqlens[b + 1]].unsqueeze(0) + B_s = B[cu_seqlens[b]:cu_seqlens[b + 1]].unsqueeze(0) + dt_s = dt[cu_seqlens[b]:cu_seqlens[b + 1]].unsqueeze(0) + dA_cumsum_s, dt_rounded_s = _chunk_cumsum_fwd(dt_s, A, chunk_size) + states = chunk_state(B_s, x_s, dt_rounded_s, dA_cumsum_s) + _, final_states = _state_passing_fwd(rearrange(states, "... p n -> ... (p n)"), dA_cumsum_s[:, :, :, -1], + chunk_size=chunk_size) + final_states = rearrange(final_states, "... (p n) -> ... p n", n=dstate) + out_ref.append(final_states) + out_ref = torch.cat(out_ref, dim=0) + print(f"Max diff = {(out - out_ref).abs().max().item()}") + assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) diff --git a/tests/test_generation.py b/tests/test_generation.py new file mode 100644 index 00000000..16949023 --- /dev/null +++ b/tests/test_generation.py @@ -0,0 +1,113 @@ +import torch +import torch.nn.functional as F + +from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel +from mamba_ssm.models.config_mamba import MambaConfig +from mamba_ssm.utils.generation import InferenceParams + +import pytest + +from einops import rearrange, repeat + + +def test_generation(): + batch = 3 + seqlen = 20 + device = "cuda" + dtype = torch.float16 + + config = MambaConfig( + d_model=1024, + n_layer=4, + vocab_size=50277, + ssm_cfg=dict(layer="Mamba2"), + rms_norm=True, + residual_in_fp32=True, + fused_add_norm=True, + pad_vocab_size_multiple=16, + ) + torch.manual_seed(2357) + model = MambaLMHeadModel(config, device=device, dtype=dtype) + x = torch.randint(0, 1000, (batch, seqlen), device=device, dtype=torch.long) + out_ref = model(x).logits + prompt_len = seqlen // 2 + out = model.generate( + input_ids = x[:, :prompt_len], max_length=seqlen, output_scores=True, return_dict_in_generate=True, + cg=True, # Can turn off CUDA graph for easier debugging + # instead of sampling, we take output tokens from x, to get logits for testing + # For actual generation, don't pass in teacher_outputs + teacher_outputs=x, + ) + out_scores = torch.stack(out.scores, dim=1) + print(f"Max diff: {(out_scores - out_ref[:, prompt_len - 1: -1]).abs().max()}") + assert torch.allclose(out_scores, out_ref[:, prompt_len - 1: -1], rtol=1e-3, atol=1e-2) + + +def test_generation_varlen(): + seqlens = [170, 65, 100] + genlen = 20 + total_seqlen = sum(seqlens) + device = "cuda" + dtype = torch.float16 + + config = MambaConfig( + d_model=1024, + n_layer=4, + vocab_size=50277, + ssm_cfg=dict(layer="Mamba2"), + rms_norm=True, + residual_in_fp32=True, + fused_add_norm=True, + pad_vocab_size_multiple=16, + ) + torch.manual_seed(2357) + model = MambaLMHeadModel(config, device=device, 
dtype=dtype) + xs = [torch.randint(0, 1000, (1, seqlen), device=device, dtype=torch.long) for seqlen in seqlens] + + # Reference 1: Forward pass with seq_idx + x = torch.cat(xs, dim=1) + seq_idx = torch.cat([torch.full((ids.shape[1],), i, dtype=torch.int32, device=device) + for i, ids in enumerate(xs)], dim=0).unsqueeze(0) + cu_seqlens = F.pad(torch.tensor(seqlens, device=device, dtype=torch.int32).cumsum(dim=0), (1, 0)) + out_ref = model(x, seq_idx=seq_idx).logits + # Only take the last @genlen logits of each sequence + out_ref = torch.cat([out_ref[:, cu_seqlens[i + 1] - genlen - 1:cu_seqlens[i + 1] - 1] + for i in range(len(seqlens))], dim=0) + + # Reference 2: Generate the last @genlen tokens of each sequence in a for loop + out_loop = [] + for input_ids in xs: + out = model.generate( + input_ids=input_ids[:, :-genlen], max_length=input_ids.shape[1], output_scores=True, + return_dict_in_generate=True, cg=True, teacher_outputs=input_ids, + ).scores + out_loop.append(torch.stack(out, dim=1)) + out_loop = torch.cat(out_loop, dim=0) + print(f"Max diff between ref1 and ref2: {(out_loop - out_ref).abs().max()}") + + # Varlen generation + input_ids = torch.cat([ids[:, :-genlen] for ids in xs], dim=1) + prompt_seqlens = [seqlen - genlen for seqlen in seqlens] + cu_seqlens = F.pad(torch.tensor(prompt_seqlens, device=device, dtype=torch.int32).cumsum(dim=0), (1, 0)) + seq_idx = torch.cat([torch.full((seqlen,), i, dtype=torch.int32, device=device) + for i, seqlen in enumerate(prompt_seqlens)], dim=0).unsqueeze(0) + inference_params = InferenceParams(max_seqlen=2048, max_batch_size=len(seqlens)) + + scores, sequences = [], [] + # Both seq_idx and cu_seqlens must be passed in for varlen generation + logits = model(input_ids, inference_params=inference_params, seq_idx=seq_idx, cu_seqlens=cu_seqlens).logits + logits = rearrange(logits[0, cu_seqlens[1:] - 1], "b d -> b 1 d") + scores.append(logits) + # In practice we should sample. In this case we take from the teacher_output for testing + sampled_tokens = rearrange(torch.stack([ids[0, -genlen] for ids in xs], dim=0), "b -> b 1") + sequences.append(sampled_tokens) + for i in range(1, genlen): + inference_params.seqlen_offset += 1 + logits = model(sampled_tokens, inference_params=inference_params, num_last_tokens=1).logits + scores.append(logits) + # In practice we should sample. In this case we take from the teacher_output for testing + sampled_tokens = rearrange(torch.stack([ids[0, -genlen + i] for ids in xs], dim=0), "b -> b 1") + sequences.append(sampled_tokens) + out_varlen = torch.cat(scores, dim=1) + print(f"Max diff: {(out_varlen - out_ref).abs().max()}") + assert (out_varlen - out_ref).abs().max() < 5 * (out_loop - out_ref).abs().max()
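A note on what the new varlen states are: chunk_state_varlen (and the return_varlen_states=True path built on it) returns one final SSM state per packed sequence, with each sequence starting from a zero state; that is why the kernel masks out positions before start_idx and only folds in the carried chunk_states when the sequence began before the last chunk. The sketch below is a naive pure-PyTorch reference for that definition, written only for illustration: the helper name naive_varlen_final_states is hypothetical, and it assumes dt has already been softplus'd/limited the way _chunk_cumsum_fwd would, so it matches the kernel only up to that preprocessing.

import torch


def naive_varlen_final_states(x, B, dt, A, cu_seqlens):
    """Hypothetical reference (illustration only): final SSM state per packed sequence.

    x:  (total_seqlen, nheads, headdim)   packed tokens, batch dim already squeezed
    B:  (total_seqlen, ngroups, dstate)
    dt: (total_seqlen, nheads)            positive step sizes (already softplus'd / limited)
    A:  (nheads,)                         per-head decay, typically negative
    cu_seqlens: (num_sequences + 1,)      cumulative sequence lengths, starting at 0
    Returns: (num_sequences, nheads, headdim, dstate)
    """
    nheads, headdim = x.shape[1], x.shape[2]
    ngroups, dstate = B.shape[1], B.shape[2]
    nrep = nheads // ngroups  # heads per B group, same head -> group mapping as the kernels
    finals = []
    for s in range(cu_seqlens.numel() - 1):
        start, end = int(cu_seqlens[s]), int(cu_seqlens[s + 1])
        state = torch.zeros(nheads, headdim, dstate, dtype=torch.float32, device=x.device)
        for t in range(start, end):
            Bt = B[t].repeat_interleave(nrep, dim=0).float()   # (nheads, dstate)
            decay = torch.exp(A * dt[t])                       # (nheads,)
            # h_t = exp(dt_t * A) * h_{t-1} + dt_t * (x_t outer B_t), per head
            state = decay[:, None, None] * state + dt[t][:, None, None] * torch.einsum(
                "hp,hn->hpn", x[t].float(), Bt
            )
        finals.append(state)
    return torch.stack(finals)

Per head, this applies the recurrence over the tokens of one sequence and keeps the state at that sequence's last token, which is the same quantity the per-sequence reference loop in test_chunk_state_varlen builds out of chunk_state plus _state_passing_fwd.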
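The tests above exercise the low-level kernel and end-to-end generation; for the op-level entry point, here is a minimal self-contained usage sketch of the extended mamba_chunk_scan_combined signature. Shapes and hyperparameters are made up for illustration; as the patch requires, the packed input keeps batch dimension 1, and seq_idx together with cu_seqlens describes the packing.

import torch
import torch.nn.functional as F
from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined

device, dtype = "cuda", torch.bfloat16
seqlens = [170, 65, 100]                  # three prompts packed along the length dim
total = sum(seqlens)
nheads, headdim, ngroups, dstate, chunk_size = 8, 64, 1, 64, 128

x = torch.randn(1, total, nheads, headdim, device=device, dtype=dtype)
dt = F.softplus(torch.randn(1, total, nheads, device=device, dtype=torch.float32) - 4)
A = -torch.rand(nheads, device=device)
B = torch.randn(1, total, ngroups, dstate, device=device, dtype=dtype)
C = torch.randn(1, total, ngroups, dstate, device=device, dtype=dtype)

cu_seqlens = F.pad(torch.tensor(seqlens, device=device, dtype=torch.int32).cumsum(0), (1, 0))
seq_idx = torch.cat([torch.full((l,), i, dtype=torch.int32, device=device)
                     for i, l in enumerate(seqlens)]).unsqueeze(0)

# out: (1, total, nheads, headdim); varlen_states: (len(seqlens), nheads, headdim, dstate)
out, varlen_states = mamba_chunk_scan_combined(
    x, dt, A, B, C, chunk_size,
    seq_idx=seq_idx, cu_seqlens=cu_seqlens, return_varlen_states=True,
)

With return_final_states=False and return_varlen_states=True the call returns (out, varlen_states); each per-sequence state can then be copied into an ssm_state cache row, which is exactly what the Mamba2 module does during varlen prefill.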