Support for torch FlopCounterMode (register torch.library.custom_op) #184
there's also this mysterious …

if that's the case, then perhaps the example could be simplified to just:

```python
from natten.functional import na2d
from torch.utils.flop_counter import FlopCounterMode

with FlopCounterMode(display=True):
    na2d(q, k, v, 1, 3)
```
@Birch-san thanks for bringing this to our attention; I didn't realize this feature existed. Unfortunately, registering NATTEN ops with torch is still pretty buggy (see #170). We tried doing it last month to get torch.compile support without NATTEN ops breaking the graph, and that part of it works as of torch 2.4. However, support for FP16/BF16 in custom ops gets a little tricky, which led to pytorch/pytorch#137437. As I understand it, it's still not possible to register custom ops and get FP16 scaling the way you would with autograd functions, and this leads to unstable mixed-precision training (i.e. collapses to nan loss within an epoch). Let me look into the alternatives you shared and see if we can work them into NATTEN for now without changing the way NATTEN ops are registered with torch/autograd.
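For context, "FP16 scaling the way you would with autograd functions" refers to a pattern roughly like the sketch below. The kernel calls are stubs, not NATTEN's actual symbols; this just illustrates the autograd.Function route.

```python
import torch
from torch.cuda.amp import custom_fwd, custom_bwd


def na2d_forward_kernel(q, k, v, kernel_size, dilation):
    raise NotImplementedError  # stand-in for the real forward kernel


def na2d_backward_kernel(grad_out, q, k, v, kernel_size, dilation):
    raise NotImplementedError  # stand-in for the real backward kernel


class NeighborhoodAttention2DFn(torch.autograd.Function):
    @staticmethod
    @custom_fwd(cast_inputs=torch.float16)  # cast inputs and run in fp16 when autocast is active
    def forward(ctx, q, k, v, kernel_size, dilation):
        out = na2d_forward_kernel(q, k, v, kernel_size, dilation)
        ctx.save_for_backward(q, k, v)
        ctx.kernel_size, ctx.dilation = kernel_size, dilation
        return out

    @staticmethod
    @custom_bwd  # run backward under the same autocast state as forward
    def backward(ctx, grad_out):
        q, k, v = ctx.saved_tensors
        dq, dk, dv = na2d_backward_kernel(grad_out, q, k, v, ctx.kernel_size, ctx.dilation)
        # GradScaler's loss scaling flows through here like any other autograd node.
        return dq, dk, dv, None, None
```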
@Birch-san so far it doesn't look like it; pretty much all of these assume the underlying ops are being registered as torch ops directly. Given that doing so is still buggy when it comes to training with mixed precision, which is blocking #170, here's what I propose: assuming users are limited to doing inference only, we could expose an "experimental" op interface. Let me know if that works for you.
woah, getting breakless torch.compile for inference would be a nice bonus. the flop count this provides for the forward pass would be helpful; the main benchmark we care about is inference speed. there's some debate over "should I deduct flops from my count when I exploit sparsity". so yes, let's go for that experimental op interface.

as for backward… I wonder if there's a way to participate in torch operation dispatch without registering as a custom op? I don't know how the mechanism works, but maybe there's some specific dispatch-related API that could be invoked? or is the whole point of dispatch that it requires registration of ops?
Yeah, the breakless torch.compile is possible, although there are multiple issues with that experimental interface. One is the pytorch bug I referenced, which forces custom ops like the ones in NATTEN to move tensors around in certain settings, just because there's no guarantee that all operands will be handed off to the custom op already contiguous. Another, which mostly affects training, is that gradient scaling support is flaky right now. Both of these are issues with the "new" way of registering ops. To be fair though, it is a tough thing for PyTorch to work out, given how many assumptions just can't be made about how custom ops behave.

Regarding counting flops without registering, I'll take another look, but last I checked I don't think there was a way around registering custom ops. I'm not at all sure, but I think you're right: that's kind of the whole point of the dispatcher, so it goes hand in hand with registering ops.

Regarding counting FLOPs when doing causal masking, that actually is a problem in NATTEN as well, but I'd say not as significant unless we're using a really big window size across that axis. I'll put that on my todo list.

I'll try and get that experimental interface in soon.
See SHI-Labs#184

Only supports forward pass for now, due to current limitations of registering custom ops with torch compared to autograd functions. Some of those limitations are:

* No stable interface for supporting autocasting to fp16/bf16.
* Gradient scaling doesn't seem to be supported either, leading to training instability.
* Ops cannot indicate that they expect contiguous operands, and need to call `.contiguous()` within, which incurs additional tensor copy costs and brings down throughput (in some cases it's hard to even tell the difference between compiled and eager).
See #184

Only supports forward pass for now, due to current limitations of registering custom ops with torch compared to autograd functions. Some of those limitations are:

* No stable interface for supporting autocasting to fp16/bf16.
* Gradient scaling doesn't seem to be supported either, leading to training instability.
* Ops cannot indicate that they expect contiguous operands, and need to call `.contiguous()` within, which incurs additional tensor copy costs and brings down throughput (in some cases it's hard to even tell the difference between compiled and eager).

- [x] Experimental torch ops (through torch.library; forward pass only)
- [x] FLOP counting per #184
- [x] Unit tests for FLOP counters
- [x] Confirm graph doesn't break -- (for some reason, when running with `torch.no_grad`, the compiled graph isn't dumped to file with `TORCH_COMPILE_DEBUG=1`, but logs and assertions confirm it's working)
- [x] Check earlier torch versions to make sure nothing breaks, and unit tests pass
- [x] Unit test for torch.compile with fullgraph
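For reference, the general pattern behind the FLOP-counter unit tests and the fullgraph check could look like the sketch below. It uses a plain matmul-based attention as a stand-in op, not NATTEN's actual test code.

```python
import torch
from torch.utils.flop_counter import FlopCounterMode


def toy_attention(q, k, v):
    # Stand-in op: dense attention built from matmuls, which
    # FlopCounterMode already knows how to count.
    attn = torch.softmax(q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5, dim=-1)
    return attn @ v


def test_flop_counter_sees_op():
    q = k = v = torch.randn(1, 4, 64, 32)
    with FlopCounterMode(display=False) as counter:
        toy_attention(q, k, v)
    # The op participated in dispatch, so a nonzero count was recorded.
    assert counter.get_total_flops() > 0


def test_compile_fullgraph_does_not_break():
    q = k = v = torch.randn(1, 4, 64, 32)
    # fullgraph=True raises if the traced graph breaks anywhere.
    compiled = torch.compile(toy_attention, fullgraph=True)
    torch.testing.assert_close(compiled(q, k, v), toy_attention(q, k, v))
```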
would it be possible to register operations such as `na2d` using `torch.library.custom_op`, or otherwise ensure that they participate in operation dispatch?

torch's built-in flop counter, `FlopCounterMode`, hooks into `__torch_dispatch__` (https://dev-discuss.pytorch.org/t/the-ideal-pytorch-flop-counter-with-torch-dispatch/505), but `na2d()` doesn't seem to trigger `__torch_dispatch__`.
Rather, if I manually modify `torch::utils::flop_counter::FlopCounterMode#_count_flops` with an `else` branch to print any time it sees an operation it doesn't have a counter for, it never prints an `na2d` operation. So even if I register a flop-counting algorithm for the `na2d` operation, `FlopCounterMode` would not be able to utilize that counter.
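Concretely, that modification is equivalent to a monkeypatch along these lines (it relies on private internals, so the exact method name and signature may vary across torch versions):

```python
from torch.utils.flop_counter import FlopCounterMode

# Wrap the (private) _count_flops hook so any op that reaches the dispatcher
# without a registered FLOP formula gets printed.
_orig_count_flops = FlopCounterMode._count_flops

def _count_flops_verbose(self, func_packet, out, args, kwargs):
    if func_packet not in self.flop_registry:
        print(f"no FLOP formula registered for {func_packet}")
    return _orig_count_flops(self, func_packet, out, args, kwargs)

FlopCounterMode._count_flops = _count_flops_verbose

# Running na2d under FlopCounterMode now prints every unhandled aten op it
# dispatches -- but never an na2d op, since na2d bypasses __torch_dispatch__.
```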
This is the workaround that I currently do to wrap `na2d` such that it's eligible for flop counting, along with my guess at how to implement the flop-counting algorithm:
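In outline, the workaround wraps the functional call in a torch custom op and attaches a FLOP formula to it. The sketch below is simplified: the `natten_ext` namespace is arbitrary, `torch.library.custom_op` needs torch ≥ 2.4, and the formula only counts the two matmul-like stages (ignoring softmax), so the exact numbers are a guess.

```python
import math

import torch
from natten.functional import na2d
from torch.utils.flop_counter import register_flop_formula


# Wrap the functional op so calls to it reach __torch_dispatch__.
@torch.library.custom_op("natten_ext::na2d", mutates_args=())
def na2d_op(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
            kernel_size: int, dilation: int) -> torch.Tensor:
    return na2d(q, k, v, kernel_size, dilation)


@na2d_op.register_fake
def _(q, k, v, kernel_size, dilation):
    # Output has the same shape as the query.
    return torch.empty_like(q)


# Attach a FLOP formula. Tensor arguments arrive as shapes; non-tensor
# arguments (kernel_size, dilation) arrive as-is.
@register_flop_formula(torch.ops.natten_ext.na2d)
def na2d_flop(q_shape, k_shape, v_shape, kernel_size, dilation,
              out_shape=None, **kwargs) -> int:
    head_dim = q_shape[-1]
    num_queries = math.prod(q_shape[:-1])  # batch * spatial extent * heads, whatever the layout
    window = kernel_size * kernel_size     # each query attends to kernel_size**2 keys
    # QK^T and PV each cost ~2 * num_queries * window * head_dim FLOPs.
    return 2 * 2 * num_queries * window * head_dim


# With that in place, torch.utils.flop_counter.FlopCounterMode sees
# natten_ext::na2d and applies the registered formula, e.g.:
#     with FlopCounterMode(display=True):
#         na2d_op(q, k, v, 7, 1)
```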
it would be nice if this example could be simplified to just calling `na2d` directly inside a `FlopCounterMode` context.
For reference, xformers custom ops dispatch correctly, and they also provide flop counting functions. I was able to register their custom mappings into FlopCounterMode like so:
https://github.com/Birch-san/booru-embed/blob/main/script/attn_bench.py
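The core of that registration is `FlopCounterMode`'s `custom_mapping` argument; schematically (with a placeholder op and formula standing in for the actual xformers mappings):

```python
import torch
from torch.utils.flop_counter import FlopCounterMode

# custom_mapping values receive the *shapes* of tensor arguments plus an
# out_shape kwarg, and return a FLOP count. aten.mm is just a self-contained
# stand-in for a custom op here.
def mm_flops(a_shape, b_shape, *args, out_shape=None, **kwargs) -> int:
    m, k = a_shape
    _, n = b_shape
    return 2 * m * k * n

with FlopCounterMode(display=True, custom_mapping={torch.ops.aten.mm: mm_flops}):
    torch.mm(torch.randn(64, 128), torch.randn(128, 32))
```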