Support for torch FlopCounterMode (register torch.library.custom_op) #184

Birch-san opened this issue Nov 19, 2024 · 5 comments

Birch-san commented Nov 19, 2024

would it be possible to register operations such as na2d using torch.library.custom_op,
or otherwise ensure that they participate in operation dispatch?

torch's built-in flop counter, FlopCounterMode, hooks into __torch_dispatch__,
https://dev-discuss.pytorch.org/t/the-ideal-pytorch-flop-counter-with-torch-dispatch/505
but na2d() doesn't seem to trigger __torch_dispatch__.

Rather, if I manually modify torch.utils.flop_counter.FlopCounterMode._count_flops with an else branch that prints any time it sees an operation it doesn't have a counter for, it never prints an na2d operation. So even if I register a flop-counting algorithm for the na2d operation, FlopCounterMode would not be able to use that counter.
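
(a minimal way to check which ops actually reach __torch_dispatch__, using torch.utils._python_dispatch.TorchDispatchMode — nothing NATTEN-specific, just a logger sketch:)

import torch
from torch.utils._python_dispatch import TorchDispatchMode

class DispatchLogger(TorchDispatchMode):
    # prints every op that reaches __torch_dispatch__; na2d never appears here
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        print(func)  # e.g. aten.randn.default, aten.bmm.default, ...
        return func(*args, **(kwargs or {}))

with DispatchLogger():
    torch.bmm(torch.randn(1, 2, 3), torch.randn(1, 3, 4))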

This is the workaround I currently use to wrap na2d so that it's eligible for flop counting, along with my guess at how to implement the flop-counting algorithm:

from typing import Protocol

import torch
from torch import FloatTensor, Size, Tensor
from torch._ops import OpOverloadPacket
from torch.utils.flop_counter import FlopCounterMode, bmm_flop

class FlopCountFn(Protocol):
    @staticmethod
    def __call__(*args, **kwargs) -> int: ...

@torch.library.custom_op("natten::na2d", mutates_args=())
def floppy_na2d(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    scale: float,
    kernel_size: int,
) -> torch.Tensor:
    # the reason I'm using FusedNeighborhoodAttention2D instead of na_2d is for torch.jit.trace
    # script-mode support, i.e. to trace model with stable-fast but activate NATTEN conditionally
    # in a script-mode block based on whether canvas size has enough sparsity
    # (not shown in this example; it's not actually relevant to FLOP-counting)
    from natten.functional import FusedNeighborhoodAttention2D
    # kv_splits for backward may be suboptimal, but torch.jit.trace can't compile the true heuristic (isinstance(int) fails due to tensorization)
    tiling_config_backward = ((8, 8), (8, 8), (1, 1), False)
    # tiling_config_forward = get_default_tiling_config_for_fna_forward(2, q, 1)
    tiling_config_forward = ((8, 8), (8, 8))
    return FusedNeighborhoodAttention2D.apply(
        q,
        k,
        v,
        None,
        kernel_size,
        1,
        False,
        scale,
        tiling_config_forward,
        tiling_config_backward,
    )

def na2d_flop_count(
    q: Size,
    k: Size,
    v: Size,
    kernel_size: int,
):
    """
    Count flops for na2d.

    NB: We can assume that v == k
    """
    b, s_qh, s_qw, h, d_q = q
    _b2, s_kh, s_kw, _h2, _d2 = k
    _b3, s_vh, s_vw, _h3, d_v = v
    assert b == _b2 == _b3 and h == _h2 == _h3 and d_q == _d2 and s_qh == s_kh == s_vh and s_qw == s_kw == s_vw
    s_q = s_qh * s_qw
    s_kv = kernel_size**2
    total_flops = 0
    # q: [b, h, s_q, d_q] @ k: [b, h, d_q, s_k] -> scores: [b, h, s_q, s_k]
    total_flops += bmm_flop((b * h, s_q, d_q), (b * h, d_q, s_kv))
    # scores: [b, h, s_q, s_k] @ v: [b, h, s_k, d_v] -> out: [b, h, s_q, d_v]
    total_flops += bmm_flop((b * h, s_q, s_kv), (b * h, s_kv, d_v))
    return total_flops

def na2d_flop(
    q: Size,
    k: Size,
    v: Size,
    scale: float,
    kernel_size: int,
    *args,
    out_shape=None,
    **kwargs,
) -> int:
    return na2d_flop_count(q, k, v, kernel_size)

# torch.library.custom_op("natten::na2d", ...) registered the op under torch.ops.natten.na2d
na2d_op: OpOverloadPacket = torch.ops.natten.na2d

custom_mapping: dict[OpOverloadPacket, FlopCountFn] = {
    na2d_op: na2d_flop,
}

# pseudocode
q: FloatTensor
k: FloatTensor
v: FloatTensor
with FlopCounterMode(display=True, custom_mapping=custom_mapping):
    floppy_na2d(q, k, v, 1, 3)

it would be nice if this example could be simplified to something like:

from typing import Protocol

from natten.functional import na2d
from natten.flop_count import na2d_flops
from torch.utils.flop_counter import FlopCounterMode
from torch._ops import OpOverloadPacket

class FlopCountFn(Protocol):
    @staticmethod
    def __call__(*args, **kwargs) -> int: ...

custom_mapping: dict[OpOverloadPacket, FlopCountFn] = {
    na2d.OPERATOR: na2d_flops,
}

with FlopCounterMode(display=True, custom_mapping=custom_mapping):
    na2d(q, k, v, 1, 3)

For reference, xformers custom ops dispatch correctly, and they also provide flop counting functions. I was able to register their custom mappings into FlopCounterMode like so:
https://github.com/Birch-san/booru-embed/blob/main/script/attn_bench.py

Birch-san (Author) commented Nov 19, 2024

there's also this mysterious @register_flop_formula; I wonder if that does something like registering a flop count for an operation into the global registry, so that users don't have to register it manually via custom_mapping.

if that's the case, then perhaps the example could be simplified to just:

from natten.functional import na2d
from torch.utils.flop_counter import FlopCounterMode

with FlopCounterMode(display=True):
    na2d(q, k, v, 1, 3)
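
if it does work that way, I'd guess the registration on NATTEN's side could look something like this (a sketch only — it assumes torch.ops.natten.na2d exists as a registered custom op, and reuses the na2d_flop_count helper from my first comment):

import torch
from torch.utils.flop_counter import register_flop_formula

# hypothetical: registers the formula into torch's global flop registry at import time,
# so FlopCounterMode picks it up without any custom_mapping
@register_flop_formula(torch.ops.natten.na2d)
def na2d_flop(q_shape, k_shape, v_shape, scale, kernel_size, *args, out_shape=None, **kwargs) -> int:
    return na2d_flop_count(q_shape, k_shape, v_shape, kernel_size)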

alihassanijr (Member) commented

@Birch-san thanks for bringing this to our attention; I didn't realize this feature existed.

Unfortunately, registering NATTEN ops with torch is still pretty buggy (see #170). We tried doing it last month to get torch.compile support without NATTEN ops breaking the graph, and that part of it works as of torch 2.4.

However, support for FP16/BF16 in custom ops gets a little tricky, which led to pytorch/pytorch#137437. As I understand it, it's still not possible to register custom ops and get FP16 scaling the way you would with autograd functions, and this leads to unstable mixed-precision training (i.e. the loss collapses to NaN within an epoch).
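
For context, the autograd-function route being contrasted here handles autocast roughly like this (a generic sketch, not our actual implementation), and gradient scaling composes with it as expected:

import torch
from torch.amp import custom_fwd, custom_bwd  # torch.cuda.amp on older releases

class ExampleAttnFn(torch.autograd.Function):
    @staticmethod
    @custom_fwd(device_type="cuda", cast_inputs=torch.float16)
    def forward(ctx, q, k, v):
        # inputs arrive already cast under autocast; do the real work here
        ctx.save_for_backward(q, k, v)
        return q  # placeholder output for the sketch

    @staticmethod
    @custom_bwd(device_type="cuda")
    def backward(ctx, grad_out):
        q, k, v = ctx.saved_tensors
        return grad_out, None, None  # placeholder gradients for the sketch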

Let me look into the alternatives you shared and see if we can work them into NATTEN for now without changing the way NATTEN ops are registered with torch/autograd.

alihassanijr (Member) commented

@Birch-san so far it looks like pretty much all of these assume the underlying ops are registered as torch ops directly. Given that doing so is still buggy when it comes to training with mixed precision, which is blocking #170, here's what I propose:

Assuming users are limited to doing inference only, we could expose an "experimental" op interface (e.g. natten.experimental.functional), register those ops with torch, and that way inference users can enjoy both torch.compile without graph breaks and PyTorch's own FlopCounterMode.

Let me know if that works for you.

Birch-san (Author) commented

woah, getting breakless torch.compile for inference would be a nice bonus.

the flop count this provides for the forward pass would be helpful; the main benchmark we care about is inference speed.
in fact, instrumenting forward flops probably provides a workaround for counting backward flops too:
it gives you the full accounting of forward calls, so when you compute the backward there'll be some missing ops, but you can manually add a figure based on the natten ops you counted in the forward pass.
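
something like this (pseudocode; it assumes the common convention that an attention backward costs roughly 2.5× the forward matmul FLOPs, as in the flash-attention benchmarks):

# pseudocode: top up the counted backward with an estimate for the uncounted na2d ops
fwd_na2d_flops = na2d_flop_count(q.shape, k.shape, v.shape, kernel_size)
bwd_na2d_flops_estimate = int(2.5 * fwd_na2d_flops)
total_flops = flop_counter.get_total_flops() + bwd_na2d_flops_estimate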

there's some debate over "should I deduct FLOPs from my count when I exploit sparsity?".
the torch flop counter does not halve the FLOP count of causal attention.
the flash attention flop counter does.
keeping the FLOP count constant lets you use FLOP/s to compare latencies,
but reducing the FLOP count as sparsity grows helps you understand algorithmic complexity and perhaps energy usage.
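
to make the two conventions concrete (a generic causal self-attention example, not NATTEN-specific; counting 2 FLOPs per multiply-accumulate):

def causal_sdpa_matmul_flops(b: int, h: int, s: int, d: int, deduct_sparsity: bool) -> int:
    # QK^T and PV: two batched matmuls per head, each 2*s*s*d FLOPs
    full = 2 * (2 * b * h * s * s * d)
    # causal masking zeroes out roughly half the score matrix,
    # so the "deduct" convention halves the count
    return full // 2 if deduct_sparsity else full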

so yes, let's go for that experimental op interface.

as for backward… I wonder if there's a way to participate in torch operation dispatch without registering as a custom op? I don't know how the mechanism works, but maybe there's some specific dispatch-related API that could be invoked? or is the whole point of dispatch that it requires ops to be registered?

alihassanijr (Member) commented

Yeah, the breakless torch.compile is possible, although there are multiple issues with that experimental interface. One is the pytorch bug I referenced, which forces custom ops like the ones in NATTEN to move tensors around in certain settings, just because there's no guarantee that all operands will be handed off to the custom op already contiguous. Another, which mostly affects training, is that gradient scaling support is flaky right now. Both of these are issues with the "new" way of registering ops. To be fair though, it is a tough thing for PyTorch to work out, given how many assumptions just can't be made about how custom ops behave.
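
i.e. a custom op ends up having to do something like this (a sketch only, not our actual op):

import torch

@torch.library.custom_op("example::needs_contiguous", mutates_args=())
def needs_contiguous(q: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    # there's no way to tell the dispatcher that this op requires contiguous operands,
    # so it has to copy defensively; that copy is the overhead mentioned above
    q = q.contiguous()
    k = k.contiguous()
    return torch.bmm(q, k)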

Regarding counting flops without registering, I'll take another look, but last I checked I don't think there was a way around registering custom ops. I'm not at all sure, but I think you're right in that that's the whole point of the dispatcher, so it kind of goes hand in hand with registering ops.

Regarding counting FLOPs when doing causal masking, that actually is a problem in NATTEN as well, but I'd say it's not as significant unless we're using a really big window size along that axis. I'll put that on my todo list.

I'll try and get that experimental interface in soon.

alihassanijr added a commit to alihassanijr/NATTEN-Torch that referenced this issue Dec 10, 2024
See SHI-Labs#184

Only supports forward pass for now, due to current limitations of
registering custom ops with torch compared to autograd functions. Some
of those limitations are:

* No stable interface for supporting autocasting to fp16/bf16,
  * Gradient scaling doesn't seem to be supported either, leading to
    training instability.

* Ops cannot indicate that they expect contiguous operands, and need to
  call `.contiguous()` within, and this incurs additional tensor copy
  costs, and brings down throughput (in some cases it's hard to even
  tell the difference between compiled and eager.)
alihassanijr added a commit that referenced this issue Jan 3, 2025
See #184

Only supports forward pass for now, due to current limitations of
registering custom ops with torch compared to autograd functions. Some
of those limitations are:

* No stable interface for supporting autocasting to fp16/bf16,
  * Gradient scaling doesn't seem to be supported either, leading to
    training instability.

* Ops cannot indicate that they expect contiguous operands, and need to
  call `.contiguous()` within, and this incurs additional tensor copy
  costs, and brings down throughput (in some cases it's hard to even
  tell the difference between compiled and eager.)

- [x] Experimental torch ops (through torch.library; forward pass only)
- [x] FLOP counting per #184
- [x] Unit tests for FLOP counters
- [x] Confirm graph doesn't break -- (for some reason, when running with
`torch.no_grad`, compiled graph isn't dumped to file with
`TORCH_COMPILE_DEBUG=1`, but logs and assertions confirm it's working)
- [x] Check earlier torch versions to make sure nothing breaks, and unit
tests pass
- [x] Unit test for torch.compile with fullgraph