From 6e5a3a9c99a60586c36520f9d6ea60a1b1556440 Mon Sep 17 00:00:00 2001
From: Johnny
Date: Wed, 22 Jan 2025 00:26:04 +0100
Subject: [PATCH 1/7] initial blackwell support

---
 .github/workflows/publish.yaml                 |  4 ++--
 mamba_ssm/__init__.py                          |  2 +-
 mamba_ssm/distributed/tensor_parallel.py       |  2 +-
 mamba_ssm/modules/block.py                     |  2 +-
 mamba_ssm/modules/mamba2.py                    |  2 +-
 mamba_ssm/modules/mamba2_simple.py             |  2 +-
 mamba_ssm/modules/mha.py                       |  2 +-
 mamba_ssm/modules/mlp.py                       |  2 +-
 mamba_ssm/modules/ssd_minimal.py               |  2 +-
 mamba_ssm/ops/triton/k_activations.py          |  2 +-
 mamba_ssm/ops/triton/layer_norm.py             |  2 +-
 mamba_ssm/ops/triton/layernorm_gated.py        |  2 +-
 mamba_ssm/ops/triton/selective_state_update.py |  2 +-
 mamba_ssm/ops/triton/ssd_bmm.py                |  2 +-
 mamba_ssm/ops/triton/ssd_chunk_scan.py         |  2 +-
 mamba_ssm/ops/triton/ssd_chunk_state.py        |  2 +-
 mamba_ssm/ops/triton/ssd_combined.py           |  2 +-
 mamba_ssm/ops/triton/ssd_state_passing.py      |  2 +-
 setup.py                                       | 12 +++++++++++-
 19 files changed, 30 insertions(+), 20 deletions(-)

diff --git a/.github/workflows/publish.yaml b/.github/workflows/publish.yaml
index 192f9562..273ee818 100644
--- a/.github/workflows/publish.yaml
+++ b/.github/workflows/publish.yaml
@@ -44,8 +44,8 @@ jobs:
           # manylinux docker image, but I haven't figured out how to install CUDA on manylinux.
           os: [ubuntu-20.04]
           python-version: ['3.9', '3.10', '3.11', '3.12', '3.13']
-          torch-version: ['2.1.2', '2.2.2', '2.3.1', '2.4.0', '2.5.1', '2.6.0.dev20241001']
-          cuda-version: ['11.8.0', '12.3.2']
+          torch-version: ['2.1.2', '2.2.2', '2.3.1', '2.4.0', '2.5.1', '2.6.0']
+          cuda-version: ['11.8.0', '12.3.2', '12.6.3']
           # We need separate wheels that either uses C++11 ABI (-D_GLIBCXX_USE_CXX11_ABI) or not.
           # Pytorch wheels currently don't use it, but nvcr images have Pytorch compiled with C++11 ABI.
           # Without this we get import error (undefined symbol: _ZN3c105ErrorC2ENS_14SourceLocationESs)
diff --git a/mamba_ssm/__init__.py b/mamba_ssm/__init__.py
index ac4f6e31..6280931e 100644
--- a/mamba_ssm/__init__.py
+++ b/mamba_ssm/__init__.py
@@ -1,4 +1,4 @@
-__version__ = "2.2.4"
+__version__ = "2.2.5"
 
 from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
 from mamba_ssm.modules.mamba_simple import Mamba
diff --git a/mamba_ssm/distributed/tensor_parallel.py b/mamba_ssm/distributed/tensor_parallel.py
index 2d67b530..5d4f1000 100644
--- a/mamba_ssm/distributed/tensor_parallel.py
+++ b/mamba_ssm/distributed/tensor_parallel.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2024, Tri Dao.
+# Copyright (c) 2025, Tri Dao.
 # The TensorParallel linear modules are inspired by https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/layers.py
 
 from typing import Optional
diff --git a/mamba_ssm/modules/block.py b/mamba_ssm/modules/block.py
index 1bd968a0..8ebb8dd1 100644
--- a/mamba_ssm/modules/block.py
+++ b/mamba_ssm/modules/block.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2024, Tri Dao, Albert Gu.
+# Copyright (c) 2025, Tri Dao, Albert Gu.
 from typing import Optional
 
 import torch
diff --git a/mamba_ssm/modules/mamba2.py b/mamba_ssm/modules/mamba2.py
index 36b16d47..ceeb3d04 100644
--- a/mamba_ssm/modules/mamba2.py
+++ b/mamba_ssm/modules/mamba2.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2024, Tri Dao, Albert Gu.
+# Copyright (c) 2025, Tri Dao, Albert Gu.
 
 import math
 
diff --git a/mamba_ssm/modules/mamba2_simple.py b/mamba_ssm/modules/mamba2_simple.py
index 77a6af28..cc51be4f 100644
--- a/mamba_ssm/modules/mamba2_simple.py
+++ b/mamba_ssm/modules/mamba2_simple.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2024, Tri Dao, Albert Gu.
+# Copyright (c) 2025, Tri Dao, Albert Gu.
 
 import math
 import torch
diff --git a/mamba_ssm/modules/mha.py b/mamba_ssm/modules/mha.py
index 978f3ea4..0818394b 100644
--- a/mamba_ssm/modules/mha.py
+++ b/mamba_ssm/modules/mha.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2024, Tri Dao, Albert Gu.
+# Copyright (c) 2025, Tri Dao, Albert Gu.
 
 import math
 
diff --git a/mamba_ssm/modules/mlp.py b/mamba_ssm/modules/mlp.py
index 33bab5c7..7e6fb16e 100644
--- a/mamba_ssm/modules/mlp.py
+++ b/mamba_ssm/modules/mlp.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2024, Tri Dao, Albert Gu.
+# Copyright (c) 2025, Tri Dao, Albert Gu.
 
 from torch import nn
 from torch.nn import functional as F
diff --git a/mamba_ssm/modules/ssd_minimal.py b/mamba_ssm/modules/ssd_minimal.py
index 9632ebd4..6e8d5382 100644
--- a/mamba_ssm/modules/ssd_minimal.py
+++ b/mamba_ssm/modules/ssd_minimal.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2024, Albert Gu and Tri Dao.
+# Copyright (c) 2025, Albert Gu and Tri Dao.
 """Minimal implementation of SSD.
 
 This is the same as Listing 1 from the paper.
diff --git a/mamba_ssm/ops/triton/k_activations.py b/mamba_ssm/ops/triton/k_activations.py
index 79fa2cc6..1b0c2640 100644
--- a/mamba_ssm/ops/triton/k_activations.py
+++ b/mamba_ssm/ops/triton/k_activations.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2024, Tri Dao, Albert Gu.
+# Copyright (c) 2025, Tri Dao, Albert Gu.
 
 import torch
 
diff --git a/mamba_ssm/ops/triton/layer_norm.py b/mamba_ssm/ops/triton/layer_norm.py
index 200b415a..a2699c4b 100755
--- a/mamba_ssm/ops/triton/layer_norm.py
+++ b/mamba_ssm/ops/triton/layer_norm.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2024, Tri Dao.
+# Copyright (c) 2025, Tri Dao.
 # Implement dropout + residual + layer_norm / rms_norm.
 
 # Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
diff --git a/mamba_ssm/ops/triton/layernorm_gated.py b/mamba_ssm/ops/triton/layernorm_gated.py
index de4b2f48..33ccc0e1 100644
--- a/mamba_ssm/ops/triton/layernorm_gated.py
+++ b/mamba_ssm/ops/triton/layernorm_gated.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2024, Tri Dao.
+# Copyright (c) 2025, Tri Dao.
 # Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
 # For the backward pass, we keep weight_grad and bias_grad in registers and accumulate.
 # This backward pass is faster for dimensions up to 8k, but after that it's much slower due to register spilling.
diff --git a/mamba_ssm/ops/triton/selective_state_update.py b/mamba_ssm/ops/triton/selective_state_update.py
index d425bc72..a11c426c 100644
--- a/mamba_ssm/ops/triton/selective_state_update.py
+++ b/mamba_ssm/ops/triton/selective_state_update.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2024, Tri Dao, Albert Gu.
+# Copyright (c) 2025, Tri Dao, Albert Gu.
 
 """We want triton==2.1.0 or triton==2.2.0 or triton==2.3.0 for this
 """
diff --git a/mamba_ssm/ops/triton/ssd_bmm.py b/mamba_ssm/ops/triton/ssd_bmm.py
index 48fd4f06..4f505bcc 100644
--- a/mamba_ssm/ops/triton/ssd_bmm.py
+++ b/mamba_ssm/ops/triton/ssd_bmm.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2024, Tri Dao, Albert Gu.
+# Copyright (c) 2025, Tri Dao, Albert Gu.
 
 """We want triton==2.1.0 or 2.2.0 for this
 """
diff --git a/mamba_ssm/ops/triton/ssd_chunk_scan.py b/mamba_ssm/ops/triton/ssd_chunk_scan.py
index fa5b813a..b7b1d7e6 100644
--- a/mamba_ssm/ops/triton/ssd_chunk_scan.py
+++ b/mamba_ssm/ops/triton/ssd_chunk_scan.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2024, Tri Dao, Albert Gu.
+# Copyright (c) 2025, Tri Dao, Albert Gu.
 
"""We want triton==2.1.0 or 2.2.0 for this """ diff --git a/mamba_ssm/ops/triton/ssd_chunk_state.py b/mamba_ssm/ops/triton/ssd_chunk_state.py index bb49c9a9..04625490 100644 --- a/mamba_ssm/ops/triton/ssd_chunk_state.py +++ b/mamba_ssm/ops/triton/ssd_chunk_state.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024, Tri Dao, Albert Gu. +# Copyright (c) 2025, Tri Dao, Albert Gu. """We want triton==2.1.0 or 2.2.0 for this """ diff --git a/mamba_ssm/ops/triton/ssd_combined.py b/mamba_ssm/ops/triton/ssd_combined.py index 58a6e04a..54e7a3d9 100644 --- a/mamba_ssm/ops/triton/ssd_combined.py +++ b/mamba_ssm/ops/triton/ssd_combined.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024, Tri Dao, Albert Gu. +# Copyright (c) 2025, Tri Dao, Albert Gu. """We want triton==2.1.0 or 2.2.0 for this """ diff --git a/mamba_ssm/ops/triton/ssd_state_passing.py b/mamba_ssm/ops/triton/ssd_state_passing.py index 63863b82..ebf0176d 100644 --- a/mamba_ssm/ops/triton/ssd_state_passing.py +++ b/mamba_ssm/ops/triton/ssd_state_passing.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024, Tri Dao, Albert Gu. +# Copyright (c) 2025, Tri Dao, Albert Gu. """We want triton==2.1.0 or 2.2.0 for this """ diff --git a/setup.py b/setup.py index 7c6196d7..0614a2d8 100755 --- a/setup.py +++ b/setup.py @@ -184,10 +184,20 @@ def append_nvcc_threads(nvcc_extra_args): cc_flag.append("arch=compute_80,code=sm_80") cc_flag.append("-gencode") cc_flag.append("arch=compute_87,code=sm_87") - + cc_flag.append("-gencode") + cc_flag.append("arch=compute_89,code=sm_89") if bare_metal_version >= Version("11.8"): cc_flag.append("-gencode") cc_flag.append("arch=compute_90,code=sm_90") + cc_flag.append("-gencode") + cc_flag.append("arch=compute_90a,code=sm_90a") + if bare_metal_version >= Version("12.7"): + cc_flag.append("-gencode") + cc_flag.append("arch=compute_100,code=sm_100") # B100 + cc_flag.append("-gencode") + cc_flag.append("arch=compute_101,code=sm_101") # Thor + cc_flag.append("-gencode") + cc_flag.append("arch=compute_120,code=sm_100") # RTX50 # HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as From 932bc7fc370b8c798045b89bdf8c1afcef068121 Mon Sep 17 00:00:00 2001 From: Johnny Date: Thu, 23 Jan 2025 22:30:59 +0100 Subject: [PATCH 2/7] Update publish.yaml --- .github/workflows/publish.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/publish.yaml b/.github/workflows/publish.yaml index 273ee818..f8af10f4 100644 --- a/.github/workflows/publish.yaml +++ b/.github/workflows/publish.yaml @@ -45,7 +45,7 @@ jobs: os: [ubuntu-20.04] python-version: ['3.9', '3.10', '3.11', '3.12', '3.13'] torch-version: ['2.1.2', '2.2.2', '2.3.1', '2.4.0', '2.5.1', '2.6.0'] - cuda-version: ['11.8.0', '12.3.2', '12.6.3'] + cuda-version: ['11.8.0', '12.3.2', '12.6.3', '12.8.0'] # We need separate wheels that either uses C++11 ABI (-D_GLIBCXX_USE_CXX11_ABI) or not. # Pytorch wheels currently don't use it, but nvcr images have Pytorch compiled with C++11 ABI. 
           # Without this we get import error (undefined symbol: _ZN3c105ErrorC2ENS_14SourceLocationESs)

From a35151710abc04687038b807313a2086ffeba8b7 Mon Sep 17 00:00:00 2001
From: Johnny
Date: Thu, 23 Jan 2025 22:31:49 +0100
Subject: [PATCH 3/7] Update setup.py

---
 setup.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/setup.py b/setup.py
index 0614a2d8..c10044de 100755
--- a/setup.py
+++ b/setup.py
@@ -197,7 +197,9 @@ def append_nvcc_threads(nvcc_extra_args):
         cc_flag.append("-gencode")
         cc_flag.append("arch=compute_101,code=sm_101") # Thor
         cc_flag.append("-gencode")
-        cc_flag.append("arch=compute_120,code=sm_100") # RTX50
+        cc_flag.append("arch=compute_120,code=sm_120") # RTX50
+        cc_flag.append("-gencode")
+        cc_flag.append("arch=compute_120a,code=sm_120a")
 
 
 # HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as

From 25e883bafe447f1c17de796936d83df3ae202a76 Mon Sep 17 00:00:00 2001
From: Johnny
Date: Fri, 24 Jan 2025 23:54:27 +0100
Subject: [PATCH 4/7] .

---
 .github/workflows/publish.yaml | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/.github/workflows/publish.yaml b/.github/workflows/publish.yaml
index f8af10f4..8b140f72 100644
--- a/.github/workflows/publish.yaml
+++ b/.github/workflows/publish.yaml
@@ -44,8 +44,8 @@ jobs:
           # manylinux docker image, but I haven't figured out how to install CUDA on manylinux.
           os: [ubuntu-20.04]
           python-version: ['3.9', '3.10', '3.11', '3.12', '3.13']
-          torch-version: ['2.1.2', '2.2.2', '2.3.1', '2.4.0', '2.5.1', '2.6.0']
-          cuda-version: ['11.8.0', '12.3.2', '12.6.3', '12.8.0']
+          torch-version: ['2.1.2', '2.2.2', '2.3.1', '2.4.0', '2.5.1', '2.6.0', '2.7.0.dev20250130']
+          cuda-version: ['11.8.0', '12.6.3', '12.8.0']
           # We need separate wheels that either uses C++11 ABI (-D_GLIBCXX_USE_CXX11_ABI) or not.
           # Pytorch wheels currently don't use it, but nvcr images have Pytorch compiled with C++11 ABI.
           # Without this we get import error (undefined symbol: _ZN3c105ErrorC2ENS_14SourceLocationESs)
@@ -93,13 +93,13 @@ jobs:
 
       - name: Set up swap space
         if: runner.os == 'Linux'
-        uses: pierotofy/set-swap-space@v1.0
+        uses: pierotofy/set-swap-space@master
         with:
           swap-size-gb: 10
 
       - name: Install CUDA ${{ matrix.cuda-version }}
         if: ${{ matrix.cuda-version != 'cpu' }}
-        uses: Jimver/cuda-toolkit@v0.2.19
+        uses: Jimver/cuda-toolkit@v0.2.20
         id: cuda-toolkit
         with:
           cuda: ${{ matrix.cuda-version }}
@@ -121,16 +121,16 @@ jobs:
           # e.g. we can have system CUDA version being 11.7 but if torch==1.12 then we need to download the wheel from cu116
           # This code is ugly, maybe there's a better way to do this.
           export TORCH_CUDA_VERSION=$(python -c "from os import environ as env; \
-            minv = {'2.1': 118, '2.2': 118, '2.3': 118, '2.4': 118, '2.5': 118, '2.6': 118}[env['MATRIX_TORCH_VERSION']]; \
-            maxv = {'2.1': 121, '2.2': 121, '2.3': 121, '2.4': 124, '2.5': 124, '2.6': 124}[env['MATRIX_TORCH_VERSION']]; \
+            minv = {'2.1': 118, '2.2': 118, '2.3': 118, '2.4': 118, '2.5': 118, '2.6': 118, '2.7': 121 }[env['MATRIX_TORCH_VERSION']]; \
+            maxv = {'2.1': 121, '2.2': 121, '2.3': 121, '2.4': 124, '2.5': 124, '2.6': 126, '2.7': 128 }[env['MATRIX_TORCH_VERSION']]; \
             print(minv if int(env['MATRIX_CUDA_VERSION']) < 120 else maxv)" \
           )
           if [[ ${{ matrix.torch-version }} == *"dev"* ]]; then
            # pip install --no-cache-dir --pre torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION}
            # Can't use --no-deps because we need cudnn etc.
-            # Hard-coding this version of pytorch-triton for torch 2.6.0.dev20241001
+            # Hard-coding this version of pytorch-triton for torch 2.7.0.dev20250130
             pip install jinja2
-            pip install https://download.pytorch.org/whl/nightly/pytorch_triton-3.1.0%2Bcf34004b8a-cp${MATRIX_PYTHON_VERSION}-cp${MATRIX_PYTHON_VERSION}-linux_x86_64.whl
+            pip install https://download.pytorch.org/whl/nightly/pytorch_triton-3.2.0%2Bgit35c6c7c6-cp${MATRIX_PYTHON_VERSION}-cp${MATRIX_PYTHON_VERSION}-linux_x86_64.whl
             pip install --no-cache-dir --pre https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION}/torch-${{ matrix.torch-version }}%2Bcu${TORCH_CUDA_VERSION}-cp${MATRIX_PYTHON_VERSION}-cp${MATRIX_PYTHON_VERSION}-linux_x86_64.whl
           else
             pip install --no-cache-dir torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/cu${TORCH_CUDA_VERSION}

From f091680e8259073feafcdc8e59aa57b688c946b7 Mon Sep 17 00:00:00 2001
From: Johnny
Date: Wed, 5 Feb 2025 11:05:37 +0100
Subject: [PATCH 5/7] Update publish.yaml

---
 .github/workflows/publish.yaml | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/.github/workflows/publish.yaml b/.github/workflows/publish.yaml
index 8b140f72..2a7c2d54 100644
--- a/.github/workflows/publish.yaml
+++ b/.github/workflows/publish.yaml
@@ -44,7 +44,7 @@ jobs:
           # manylinux docker image, but I haven't figured out how to install CUDA on manylinux.
           os: [ubuntu-20.04]
           python-version: ['3.9', '3.10', '3.11', '3.12', '3.13']
-          torch-version: ['2.1.2', '2.2.2', '2.3.1', '2.4.0', '2.5.1', '2.6.0', '2.7.0.dev20250130']
+          torch-version: ['2.1.2', '2.2.2', '2.3.1', '2.4.0', '2.5.1', '2.6.0', '2.7.0.dev20250205']
           cuda-version: ['11.8.0', '12.6.3', '12.8.0']
           # We need separate wheels that either uses C++11 ABI (-D_GLIBCXX_USE_CXX11_ABI) or not.
           # Pytorch wheels currently don't use it, but nvcr images have Pytorch compiled with C++11 ABI.
@@ -99,7 +99,7 @@ jobs:
 
       - name: Install CUDA ${{ matrix.cuda-version }}
         if: ${{ matrix.cuda-version != 'cpu' }}
-        uses: Jimver/cuda-toolkit@v0.2.20
+        uses: Jimver/cuda-toolkit@v0.2.21
         id: cuda-toolkit
         with:
           cuda: ${{ matrix.cuda-version }}
@@ -121,16 +121,16 @@ jobs:
           # e.g. we can have system CUDA version being 11.7 but if torch==1.12 then we need to download the wheel from cu116
           # This code is ugly, maybe there's a better way to do this.
           export TORCH_CUDA_VERSION=$(python -c "from os import environ as env; \
-            minv = {'2.1': 118, '2.2': 118, '2.3': 118, '2.4': 118, '2.5': 118, '2.6': 118, '2.7': 121 }[env['MATRIX_TORCH_VERSION']]; \
+            minv = {'2.1': 118, '2.2': 118, '2.3': 118, '2.4': 118, '2.5': 118, '2.6': 118, '2.7': 124 }[env['MATRIX_TORCH_VERSION']]; \
             maxv = {'2.1': 121, '2.2': 121, '2.3': 121, '2.4': 124, '2.5': 124, '2.6': 126, '2.7': 128 }[env['MATRIX_TORCH_VERSION']]; \
             print(minv if int(env['MATRIX_CUDA_VERSION']) < 120 else maxv)" \
           )
           if [[ ${{ matrix.torch-version }} == *"dev"* ]]; then
            # pip install --no-cache-dir --pre torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION}
            # Can't use --no-deps because we need cudnn etc.
-            # Hard-coding this version of pytorch-triton for torch 2.7.0.dev20250130
+            # Hard-coding this version of pytorch-triton for torch 2.7.0.dev20250205
             pip install jinja2
-            pip install https://download.pytorch.org/whl/nightly/pytorch_triton-3.2.0%2Bgit35c6c7c6-cp${MATRIX_PYTHON_VERSION}-cp${MATRIX_PYTHON_VERSION}-linux_x86_64.whl
+            pip install https://download.pytorch.org/whl/nightly/pytorch_triton-3.2.0-cp${MATRIX_PYTHON_VERSION}-cp${MATRIX_PYTHON_VERSION}-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl
             pip install --no-cache-dir --pre https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION}/torch-${{ matrix.torch-version }}%2Bcu${TORCH_CUDA_VERSION}-cp${MATRIX_PYTHON_VERSION}-cp${MATRIX_PYTHON_VERSION}-linux_x86_64.whl
           else
             pip install --no-cache-dir torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/cu${TORCH_CUDA_VERSION}

From bd1113b1b5e69f61ffb411b4f29f03bbf059f63e Mon Sep 17 00:00:00 2001
From: Johnny
Date: Wed, 5 Feb 2025 11:11:10 +0100
Subject: [PATCH 6/7] Update publish.yaml

---
 .github/workflows/publish.yaml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/.github/workflows/publish.yaml b/.github/workflows/publish.yaml
index 2a7c2d54..8fa796e8 100644
--- a/.github/workflows/publish.yaml
+++ b/.github/workflows/publish.yaml
@@ -131,7 +131,7 @@ jobs:
             # Hard-coding this version of pytorch-triton for torch 2.7.0.dev20250205
             pip install jinja2
             pip install https://download.pytorch.org/whl/nightly/pytorch_triton-3.2.0-cp${MATRIX_PYTHON_VERSION}-cp${MATRIX_PYTHON_VERSION}-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl
-            pip install --no-cache-dir --pre https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION}/torch-${{ matrix.torch-version }}%2Bcu${TORCH_CUDA_VERSION}-cp${MATRIX_PYTHON_VERSION}-cp${MATRIX_PYTHON_VERSION}-linux_x86_64.whl
+            pip install --no-cache-dir --pre https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION}/torch-${{ matrix.torch-version }}%2Bcu${TORCH_CUDA_VERSION}-cp${MATRIX_PYTHON_VERSION}-cp${MATRIX_PYTHON_VERSION}-manylinux_2_28_x86_64.whl
           else
             pip install --no-cache-dir torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/cu${TORCH_CUDA_VERSION}
           fi

From b91b86d5e8c877108202da2c5c7166d8b5d12958 Mon Sep 17 00:00:00 2001
From: Johnny
Date: Wed, 5 Feb 2025 12:43:15 +0100
Subject: [PATCH 7/7] Update setup.py

---
 setup.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/setup.py b/setup.py
index c10044de..1127fed3 100755
--- a/setup.py
+++ b/setup.py
@@ -191,7 +191,7 @@ def append_nvcc_threads(nvcc_extra_args):
         cc_flag.append("arch=compute_90,code=sm_90")
         cc_flag.append("-gencode")
         cc_flag.append("arch=compute_90a,code=sm_90a")
-    if bare_metal_version >= Version("12.7"):
+    if bare_metal_version >= Version("12.8"):
         cc_flag.append("-gencode")
         cc_flag.append("arch=compute_100,code=sm_100") # B100
         cc_flag.append("-gencode")
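
Note on the series: PATCH 1 pairs `compute_120` with `code=sm_120`'s wrong value (`sm_100`) for RTX 50-series cards; PATCH 3 corrects it to `sm_120` and adds `compute_120a`/`sm_120a`, mirroring the existing `sm_90a` entry (the `a`-suffixed targets enable architecture-specific features and are not portable across GPU families). PATCH 7 then gates the whole Blackwell block on CUDA >= 12.8, the first release whose nvcc accepts these targets. A quick way to confirm that a wheel built from these patches actually covers a Blackwell device is sketched below; this snippet is not part of the patches and assumes a CUDA-enabled PyTorch plus an installed mamba-ssm wheel on a machine with a Blackwell GPU (sm_100 for B100/B200, sm_120 for the RTX 50 series). It uses only the `Mamba` module that the `__init__.py` hunk above shows being exported.

```python
# Post-build sanity check (a sketch, not part of the patch series).
import torch
from mamba_ssm import Mamba

# Compute capability of the local GPU, e.g. (12, 0) on an RTX 50-series card.
major, minor = torch.cuda.get_device_capability()
print(f"GPU compute capability: sm_{major}{minor}")

# Architectures torch's own binaries were compiled for; the mamba_ssm CUDA
# extension must also cover the device's arch (the -gencode flags above).
print("torch compiled for:", torch.cuda.get_arch_list())

# A forward pass exercises the compiled selective-scan extension; if the wheel
# has no SASS/PTX for this arch, it fails here with a "no kernel image" error.
layer = Mamba(d_model=256).cuda().half()
x = torch.randn(2, 128, 256, device="cuda", dtype=torch.float16)
print(layer(x).shape)  # expected: torch.Size([2, 128, 256])
```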