Better FLOP counting, compilable inference APIs (#189)
See #184

Only supports the forward pass for now, due to current limitations of
registering custom ops with torch compared to autograd functions. Some
of those limitations are:

* No stable interface for supporting autocasting to fp16/bf16,
* Gradient scaling doesn't seem to be supported either, leading to
training instability.

* Ops cannot indicate that they expect contiguous operands, so they need to
call `.contiguous()` internally; this incurs additional tensor copy
costs and brings down throughput (in some cases it's hard to even tell
the difference between compiled and eager).

- [x] Experimental torch ops (through torch.library; forward pass only)
- [x] FLOP counting per #184
- [x] Unit tests for FLOP counters
- [x] Confirm graph doesn't break (for some reason, when running with
`torch.no_grad`, the compiled graph isn't dumped to file with
`TORCH_COMPILE_DEBUG=1`, but logs and assertions confirm it's working)
- [x] Check earlier torch versions to make sure nothing breaks, and unit
tests pass
- [x] Unit test for torch.compile with fullgraph
alihassanijr authored Jan 3, 2025
1 parent ee45f9d commit 86e3b69
Showing 17 changed files with 1,778 additions and 380 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -14,6 +14,10 @@
* Bug fixes in FLOP counter (with fvcore)
* Bugs introduced since FNA backwards were fixed.
* Finally added unit tests.
* Better FLOP counting support
* Rename instances of FLOP counting with fvcore to MACs, since that's what fvcore reports.
* Add experimental support for torch's native flop counter
* Better documentation

## [0.17.3] - 2024-11-01
* Bug fix for torch < 2.4
1 change: 1 addition & 0 deletions docs/README.md
@@ -9,6 +9,7 @@ Welcome to NATTEN docs!
* Describes how to check NATTEN's compatibility with your system, and its available features.
* Describes how to use torch modules `NeighborhoodAttention1D`, `NeighborhoodAttention2D`, and `NeighborhoodAttention3D`.
* Describes when and how to use ops directly instead of modules.
* [FLOP counting](flops.md)
* [Fused Neighborhood Attention](fna/) (new)
* [Quick start guide](fna/fna-quickstart.md)
* [Fused vs unfused](fna/fused-vs-unfused.md): differences in supported features, layouts, etc.
188 changes: 188 additions & 0 deletions docs/flops.md
@@ -0,0 +1,188 @@
## FLOP counting with NATTEN

NATTEN offers multiple solutions for counting FLOPs/MACs.

However, given the variety of available solutions, it is important to use a consistent definition
of FLOPs and of how they are counted.

### Rules
Since NATTEN is limited to the dot product attention operator, here we note the general rules for
counting FLOPs/MACs in NATTEN:

Following standard practice, only Tensor Core (TC) FLOPs are counted.
Softmax, masking, attention scale, and positional biases are NOT included in the FLOP/MAC count.

For a GEMM problem with shape `(M, N, K)`, the MAC (Multiply-and-Accumulate operation) count is
simply `M * N * K`, while the FLOP (Floating Point Operation) count treats the multiply and the
accumulate as separate operations, and ends up being `2 * M * N * K`.
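
For example, a quick Python sketch (the GEMM shape here is an arbitrary example value):

```python
# Worked example: a single GEMM of shape (M, N, K) = (1024, 1024, 64).
M, N, K = 1024, 1024, 64

macs = M * N * K         # one multiply-and-accumulate per (m, n, k) triple
flops = 2 * M * N * K    # multiply and accumulate counted as separate operations

print(macs)   # 67108864
print(flops)  # 134217728
```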

Causal masking also does **not** affect FLOPs/MACs today, but this behavior may change in the
future. There are good reasons to factor causal masking into the count, but the actual FLOP
reduction is not so clear-cut in practice. More details will be added about this in the future.
Please refer to [#184](https://github.com/SHI-Labs/NATTEN/issues/184#issuecomment-2505022903) for
more information.

### Manual
All users can integrate NATTEN's FLOP/MAC counting logic into their own FLOP counting logic:

```python
from natten.flops import (
    fna_flop_count,
    na_qk_flop_count,
    na_av_flop_count,
)
```

If you're using [Fused Neighborhood Attention (FNA)](fna/), or just prefer to compute FLOPs/MACs for
the entire attention operator all at once, use `fna_flop_count`:

```python
fna_flop_count(
    q_shape: torch.Size | Sequence[int],
    k_shape: torch.Size | Sequence[int],
    v_shape: torch.Size | Sequence[int],
    kernel_size: Sequence[int],
    dilation: Sequence[int],
    is_causal: Sequence[bool],
    is_heads_last: bool,
    return_macs: bool = False,
) -> int
```

Simply pass the shapes of your Q, K, and V tensors, the NA-related arguments (`kernel_size`,
`dilation`, and `is_causal`, as tuples), and finally, two key arguments:
* `is_heads_last`: determines your QKV tensor layout. If `True`, your QKV tensors are expected to
assume the layout `[batch, *, heads, dim_per_head]`; otherwise, `[batch, heads, *, dim_per_head]`.
NOTE: Fused ops in NATTEN (FNA) currently use the "heads last" layout, so be sure to pass `True`
for them. Unfused ops in NATTEN use the "heads first" layout.
* `return_macs`: If `True`, returns MACs instead of FLOPs. Please be very careful when setting this.
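
For instance, a minimal sketch for a hypothetical 2D problem (the shapes, kernel size, and dilation
below are made-up example values; only the argument names come from the signature above):

```python
from natten.flops import fna_flop_count

# Hypothetical shapes: batch=1, a 64x64 feature map, 8 heads, 32 dim per head,
# in the "heads last" layout used by FNA: [batch, *spatial, heads, dim_per_head].
q_shape = (1, 64, 64, 8, 32)
k_shape = (1, 64, 64, 8, 32)
v_shape = (1, 64, 64, 8, 32)

flops = fna_flop_count(
    q_shape=q_shape,
    k_shape=k_shape,
    v_shape=v_shape,
    kernel_size=(7, 7),
    dilation=(1, 1),
    is_causal=(False, False),
    is_heads_last=True,   # FNA uses the "heads last" layout
    return_macs=False,    # count FLOPs, not MACs
)
```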

If you're using unfused implementations, or prefer to compute FLOPs/MACs for the dot product and
weighting operations separately, use `na_qk_flop_count` and `na_av_flop_count`:

```python
na_qk_flop_count(
    q_shape: torch.Size | Sequence[int],
    k_shape: torch.Size | Sequence[int],
    kernel_size: Sequence[int],
    dilation: Sequence[int],
    is_causal: Sequence[bool],
    is_heads_last: bool,
    return_macs: bool = False,
) -> int

na_av_flop_count(
    a_shape: torch.Size | Sequence[int],
    v_shape: torch.Size | Sequence[int],
    kernel_size: Sequence[int],
    dilation: Sequence[int],
    is_causal: Sequence[bool],
    is_heads_last: bool,
    return_macs: bool = False,
) -> int
```

The APIs are mostly similar to the fused variant; again, pay extra attention to setting
`is_heads_last` and `return_macs` correctly.
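
As a minimal sketch for a hypothetical unfused 2D problem (the shapes are made-up example values;
the attention-weights shape is assumed to be `[batch, heads, height, width, prod(kernel_size)]`, so
verify it against your actual tensor's shape):

```python
from natten.flops import na_av_flop_count, na_qk_flop_count

# Hypothetical shapes in the "heads first" layout: [batch, heads, height, width, dim_per_head].
q_shape = (1, 8, 64, 64, 32)
k_shape = (1, 8, 64, 64, 32)
v_shape = (1, 8, 64, 64, 32)
# Assumed attention-weights shape for kernel_size=(7, 7): 7 * 7 = 49 neighbors per query.
a_shape = (1, 8, 64, 64, 49)

common = dict(
    kernel_size=(7, 7),
    dilation=(1, 1),
    is_causal=(False, False),
    is_heads_last=False,  # unfused ops use the "heads first" layout
    return_macs=False,
)

qk_flops = na_qk_flop_count(q_shape=q_shape, k_shape=k_shape, **common)
av_flops = na_av_flop_count(a_shape=a_shape, v_shape=v_shape, **common)
total_flops = qk_flops + av_flops
```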

### PyTorch (>= 2.4.0)

PyTorch itself now supports counting model FLOPs, but models with custom operators require extra
work, as well as `torch >= 2.4.0`.

**NOTE:** PyTorch reports FLOPs, not MACs!

PyTorch only counts FLOPs based on operations registered with torch.library. This excludes autograd
functions, which historically have been the only choice for binding custom operators to PyTorch and
getting autograd support.
NATTEN will eventually migrate to registering with torch.library, but that process is on hold until
certain features available in autograd functions are supported in the new API.
More details in #170.

In the meantime, we offer an experimental interface for NATTEN ops registered with torch.library,
which only supports inference, and only implements FNA operations.
Instead of using the `na*d` operations from `natten.functional`, you will need to use the `na*d`
operations from `natten.experimental`:


```python
# from natten.functional import na1d, na2d, na3d
from natten.experimental import na1d, na2d, na3d
```

or if you're using NATTEN's torch modules (`NeighborhoodAttention*D`), you need to pass an
additional argument:

```python
from natten import (
    NeighborhoodAttention1D,
    NeighborhoodAttention2D,
    NeighborhoodAttention3D,
)

m_1d = NeighborhoodAttention1D(dim=128, num_heads=4, kernel_size=7, use_experimental_ops=True)
m_2d = NeighborhoodAttention2D(dim=128, num_heads=4, kernel_size=7, use_experimental_ops=True)
m_3d = NeighborhoodAttention3D(dim=128, num_heads=4, kernel_size=7, use_experimental_ops=True)
```

However, note that the forward pass will fail if experimental ops are enabled while fused NA is
disabled or unsupported.


As long as you use these experimental ops, you can:
* Successfully build NA-based models with `torch.compile` WITHOUT graph breaks (see the sketch below),
* Count FLOPs with `torch.utils.flop_counter`.
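
For example, a minimal end-to-end sketch (the module configuration, input shape, and device are
made-up example values; it assumes a CUDA device with FNA support and the modules' channels-last
`[batch, *spatial, dim]` input layout):

```python
import torch

from natten import NeighborhoodAttention2D

# Hypothetical setup: a single NA2D layer over a 64x64 feature map.
# use_experimental_ops=True routes to the torch.library-registered, inference-only ops.
model = NeighborhoodAttention2D(
    dim=128, num_heads=4, kernel_size=7, use_experimental_ops=True
).cuda().eval()

x = torch.randn(1, 64, 64, 128, device="cuda")

compiled_model = torch.compile(model, fullgraph=True)  # should not graph-break

with torch.no_grad():
    y = compiled_model(x)
```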

For more information on the `na*d` interface, please refer to our [frontend docs](frontend.md) or
[Fused Neighborhood Attention docs](fna/).

To count FLOPs with PyTorch's native FLOP counter:

```python
from torch.utils.flop_counter import FlopCounterMode

flop_counter = FlopCounterMode()

with flop_counter:
    y = model_with_natten_ops(x)

total_flops = flop_counter.get_total_flops()
```
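
A per-module / per-operator breakdown is also available; for instance, continuing the snippet above
(a sketch; `get_flop_counts` is a method of `FlopCounterMode`):

```python
# Nested dict: module name -> op -> FLOPs. Pass display=False to the
# FlopCounterMode constructor if you don't want the summary table printed on exit.
flops_by_module_and_op = flop_counter.get_flop_counts()
```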


### FVCore

NATTEN can be paired with [fvcore](https://github.com/facebookresearch/fvcore)'s counter, as long as you correctly
import and add NATTEN's handles to the counter.

The only requirement is having fvcore installed.

**NOTE:** fvcore reports MACs, not FLOPs!

```python
from fvcore.nn import FlopCountAnalysis
from natten.flops import add_natten_handle

# ...

flop_ctr = FlopCountAnalysis(model, input)
flop_ctr = add_natten_handle(flop_ctr)

# ...
```
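
Once the handle is added, totals can be read from the `FlopCountAnalysis` object as usual; for
instance, continuing the snippet above (and remember: these are MACs, not FLOPs):

```python
total_macs = flop_ctr.total()          # aggregate count over the entire model
macs_by_module = flop_ctr.by_module()  # per-module breakdown
```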

Alternatively, you can use our `get_flops` interface and count FLOPs for a model that may contain
NATTEN ops/modules:
```python
from natten.flops import get_flops

flops = get_flops(model, input)
```

#### Installing fvcore
fvcore is available through PyPI:

```shell
pip install fvcore
```
33 changes: 2 additions & 31 deletions docs/frontend.md
@@ -417,34 +417,5 @@ disable_tiled_na()
enable_tiled_na()
```

### FLOP counting with FVCore

NATTEN can be paired with [fvcore](https://github.com/facebookresearch/fvcore)'s FLOP counter, as long as you correctly
import and add NATTEN's handles to the flop counter.

```python
from fvcore.nn import FlopCountAnalysis
from natten.flops import add_natten_handle

# ...

flop_ctr = FlopCountAnalysis(model, input)
flop_ctr = add_natten_handle(flop_ctr)

# ...
```

Alternatively, you can use our `get_flops` interface and count FLOPs for a model that may contain
NATTEN ops/modules:
```python
from natten.flops import get_flops

flops = get_flops(model, input)
```

#### Installing fvcore
fvcore is available through PyPI:

```shell
pip install fvcore
```
### Counting FLOPs
Please refer to our [FLOP counting guide](flops.md) for more information.