initial blackwell support #677

Open · wants to merge 7 commits into main
Changes from all commits
18 changes: 9 additions & 9 deletions .github/workflows/publish.yaml
@@ -44,8 +44,8 @@ jobs:
         # manylinux docker image, but I haven't figured out how to install CUDA on manylinux.
         os: [ubuntu-20.04]
         python-version: ['3.9', '3.10', '3.11', '3.12', '3.13']
-        torch-version: ['2.1.2', '2.2.2', '2.3.1', '2.4.0', '2.5.1', '2.6.0.dev20241001']
-        cuda-version: ['11.8.0', '12.3.2']
+        torch-version: ['2.1.2', '2.2.2', '2.3.1', '2.4.0', '2.5.1', '2.6.0', '2.7.0.dev20250205']
+        cuda-version: ['11.8.0', '12.6.3', '12.8.0']
         # We need separate wheels that either uses C++11 ABI (-D_GLIBCXX_USE_CXX11_ABI) or not.
         # Pytorch wheels currently don't use it, but nvcr images have Pytorch compiled with C++11 ABI.
         # Without this we get import error (undefined symbol: _ZN3c105ErrorC2ENS_14SourceLocationESs)
@@ -93,13 +93,13 @@ jobs:

     - name: Set up swap space
       if: runner.os == 'Linux'
-      uses: pierotofy/set-swap-space@v1.0
+      uses: pierotofy/set-swap-space@master
       with:
         swap-size-gb: 10

     - name: Install CUDA ${{ matrix.cuda-version }}
       if: ${{ matrix.cuda-version != 'cpu' }}
-      uses: Jimver/[email protected]
+      uses: Jimver/[email protected]
       id: cuda-toolkit
       with:
         cuda: ${{ matrix.cuda-version }}
@@ -121,17 +121,17 @@ jobs:
         # e.g. we can have system CUDA version being 11.7 but if torch==1.12 then we need to download the wheel from cu116
         # This code is ugly, maybe there's a better way to do this.
         export TORCH_CUDA_VERSION=$(python -c "from os import environ as env; \
-          minv = {'2.1': 118, '2.2': 118, '2.3': 118, '2.4': 118, '2.5': 118, '2.6': 118}[env['MATRIX_TORCH_VERSION']]; \
-          maxv = {'2.1': 121, '2.2': 121, '2.3': 121, '2.4': 124, '2.5': 124, '2.6': 124}[env['MATRIX_TORCH_VERSION']]; \
+          minv = {'2.1': 118, '2.2': 118, '2.3': 118, '2.4': 118, '2.5': 118, '2.6': 118, '2.7': 124}[env['MATRIX_TORCH_VERSION']]; \
+          maxv = {'2.1': 121, '2.2': 121, '2.3': 121, '2.4': 124, '2.5': 124, '2.6': 126, '2.7': 128}[env['MATRIX_TORCH_VERSION']]; \
           print(minv if int(env['MATRIX_CUDA_VERSION']) < 120 else maxv)" \
         )
         if [[ ${{ matrix.torch-version }} == *"dev"* ]]; then
           # pip install --no-cache-dir --pre torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION}
           # Can't use --no-deps because we need cudnn etc.
-          # Hard-coding this version of pytorch-triton for torch 2.6.0.dev20241001
+          # Hard-coding this version of pytorch-triton for torch 2.7.0.dev20250205
           pip install jinja2
-          pip install https://download.pytorch.org/whl/nightly/pytorch_triton-3.1.0%2Bcf34004b8a-cp${MATRIX_PYTHON_VERSION}-cp${MATRIX_PYTHON_VERSION}-linux_x86_64.whl
-          pip install --no-cache-dir --pre https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION}/torch-${{ matrix.torch-version }}%2Bcu${TORCH_CUDA_VERSION}-cp${MATRIX_PYTHON_VERSION}-cp${MATRIX_PYTHON_VERSION}-linux_x86_64.whl
+          pip install https://download.pytorch.org/whl/nightly/pytorch_triton-3.2.0-cp${MATRIX_PYTHON_VERSION}-cp${MATRIX_PYTHON_VERSION}-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl
+          pip install --no-cache-dir --pre https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION}/torch-${{ matrix.torch-version }}%2Bcu${TORCH_CUDA_VERSION}-cp${MATRIX_PYTHON_VERSION}-cp${MATRIX_PYTHON_VERSION}-manylinux_2_28_x86_64.whl
         else
           pip install --no-cache-dir torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/cu${TORCH_CUDA_VERSION}
         fi
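As a reader aid (not part of the diff): the version-selection one-liner above maps each torch release to a minimum and maximum CUDA wheel tag and picks one based on the system CUDA version. A minimal Python sketch of that logic, with an illustrative function name, is:

def pick_torch_cuda_wheel(torch_version: str, cuda_version: int) -> int:
    # torch_version like "2.7"; cuda_version like 128 for CUDA 12.8 (major*10 + minor).
    minv = {'2.1': 118, '2.2': 118, '2.3': 118, '2.4': 118, '2.5': 118, '2.6': 118, '2.7': 124}[torch_version]
    maxv = {'2.1': 121, '2.2': 121, '2.3': 121, '2.4': 124, '2.5': 124, '2.6': 126, '2.7': 128}[torch_version]
    # CUDA 11.x toolchains get the oldest supported wheel index, CUDA 12.x the newest.
    return minv if cuda_version < 120 else maxv

# Example: the new CUDA 12.8.0 matrix entry with the torch 2.7 nightly resolves to cu128,
# while the CUDA 11.8.0 entry keeps resolving to cu118.
assert pick_torch_cuda_wheel('2.7', 128) == 128
assert pick_torch_cuda_wheel('2.5', 118) == 118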
2 changes: 1 addition & 1 deletion mamba_ssm/__init__.py
@@ -1,4 +1,4 @@
-__version__ = "2.2.4"
+__version__ = "2.2.5"
 
 from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
 from mamba_ssm.modules.mamba_simple import Mamba
2 changes: 1 addition & 1 deletion mamba_ssm/distributed/tensor_parallel.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2024, Tri Dao.
+# Copyright (c) 2025, Tri Dao.
 # The TensorParallel linear modules are inspired by https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/layers.py
 from typing import Optional
 
2 changes: 1 addition & 1 deletion mamba_ssm/modules/block.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2024, Tri Dao, Albert Gu.
+# Copyright (c) 2025, Tri Dao, Albert Gu.
 from typing import Optional
 
 import torch
2 changes: 1 addition & 1 deletion mamba_ssm/modules/mamba2.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2024, Tri Dao, Albert Gu.
+# Copyright (c) 2025, Tri Dao, Albert Gu.
 
 import math
 
2 changes: 1 addition & 1 deletion mamba_ssm/modules/mamba2_simple.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2024, Tri Dao, Albert Gu.
+# Copyright (c) 2025, Tri Dao, Albert Gu.
 
 import math
 import torch
2 changes: 1 addition & 1 deletion mamba_ssm/modules/mha.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2024, Tri Dao, Albert Gu.
+# Copyright (c) 2025, Tri Dao, Albert Gu.
 
 import math
 
2 changes: 1 addition & 1 deletion mamba_ssm/modules/mlp.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2024, Tri Dao, Albert Gu.
+# Copyright (c) 2025, Tri Dao, Albert Gu.
 from torch import nn
 from torch.nn import functional as F
 
2 changes: 1 addition & 1 deletion mamba_ssm/modules/ssd_minimal.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2024, Albert Gu and Tri Dao.
+# Copyright (c) 2025, Albert Gu and Tri Dao.
 """Minimal implementation of SSD.
 
 This is the same as Listing 1 from the paper.
2 changes: 1 addition & 1 deletion mamba_ssm/ops/triton/k_activations.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2024, Tri Dao, Albert Gu.
+# Copyright (c) 2025, Tri Dao, Albert Gu.
 
 import torch
 
2 changes: 1 addition & 1 deletion mamba_ssm/ops/triton/layer_norm.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2024, Tri Dao.
+# Copyright (c) 2025, Tri Dao.
 # Implement dropout + residual + layer_norm / rms_norm.
 
 # Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
2 changes: 1 addition & 1 deletion mamba_ssm/ops/triton/layernorm_gated.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2024, Tri Dao.
+# Copyright (c) 2025, Tri Dao.
 # Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
 # For the backward pass, we keep weight_grad and bias_grad in registers and accumulate.
 # This backward pass is faster for dimensions up to 8k, but after that it's much slower due to register spilling.
2 changes: 1 addition & 1 deletion mamba_ssm/ops/triton/selective_state_update.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2024, Tri Dao, Albert Gu.
+# Copyright (c) 2025, Tri Dao, Albert Gu.
 
 """We want triton==2.1.0 or triton==2.2.0 or triton==2.3.0 for this
 """
2 changes: 1 addition & 1 deletion mamba_ssm/ops/triton/ssd_bmm.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2024, Tri Dao, Albert Gu.
+# Copyright (c) 2025, Tri Dao, Albert Gu.
 
 """We want triton==2.1.0 or 2.2.0 for this
 """
2 changes: 1 addition & 1 deletion mamba_ssm/ops/triton/ssd_chunk_scan.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2024, Tri Dao, Albert Gu.
+# Copyright (c) 2025, Tri Dao, Albert Gu.
 
 """We want triton==2.1.0 or 2.2.0 for this
 """
2 changes: 1 addition & 1 deletion mamba_ssm/ops/triton/ssd_chunk_state.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2024, Tri Dao, Albert Gu.
+# Copyright (c) 2025, Tri Dao, Albert Gu.
 
 """We want triton==2.1.0 or 2.2.0 for this
 """
2 changes: 1 addition & 1 deletion mamba_ssm/ops/triton/ssd_combined.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2024, Tri Dao, Albert Gu.
+# Copyright (c) 2025, Tri Dao, Albert Gu.
 
 """We want triton==2.1.0 or 2.2.0 for this
 """
2 changes: 1 addition & 1 deletion mamba_ssm/ops/triton/ssd_state_passing.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2024, Tri Dao, Albert Gu.
+# Copyright (c) 2025, Tri Dao, Albert Gu.
 
 """We want triton==2.1.0 or 2.2.0 for this
 """
14 changes: 13 additions & 1 deletion setup.py
@@ -184,10 +184,22 @@ def append_nvcc_threads(nvcc_extra_args):
     cc_flag.append("arch=compute_80,code=sm_80")
     cc_flag.append("-gencode")
     cc_flag.append("arch=compute_87,code=sm_87")
-
+    cc_flag.append("-gencode")
+    cc_flag.append("arch=compute_89,code=sm_89")
     if bare_metal_version >= Version("11.8"):
         cc_flag.append("-gencode")
         cc_flag.append("arch=compute_90,code=sm_90")
+        cc_flag.append("-gencode")
+        cc_flag.append("arch=compute_90a,code=sm_90a")
+    if bare_metal_version >= Version("12.8"):
+        cc_flag.append("-gencode")
+        cc_flag.append("arch=compute_100,code=sm_100") # B100
+        cc_flag.append("-gencode")
+        cc_flag.append("arch=compute_101,code=sm_101") # Thor
+        cc_flag.append("-gencode")
+        cc_flag.append("arch=compute_120,code=sm_120") # RTX50
+        cc_flag.append("-gencode")
+        cc_flag.append("arch=compute_120a,code=sm_120a")
 
 
 # HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as
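For anyone testing these wheels on new hardware, a small sanity check like the one below (illustrative only, not part of this PR) reports whether the local GPU falls into the Blackwell targets added above; it only assumes a CUDA-enabled PyTorch install.

import torch

if torch.cuda.is_available():
    major, minor = torch.cuda.get_device_capability()
    print(f"Compute capability: {major}.{minor}")
    # Blackwell parts report capability 10.x (sm_100/sm_101) or 12.x (sm_120),
    # matching the new -gencode targets; older GPUs are covered by the existing
    # sm_80 / sm_87 / sm_89 / sm_90 entries.
    if major in (10, 12):
        print("Blackwell-class GPU: needs a build done with CUDA >= 12.8 and the new flags.")
    else:
        print("Pre-Blackwell GPU: the existing architecture list already applies.")
else:
    print("No CUDA device visible.")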