
[tensor wrapper subclass] Add support for torchao.float8 mlp #1585

Draft · wants to merge 1 commit into base: tensor_subclass_2 from tensor_subclass_3

Conversation

crcrpar (Collaborator) commented on Dec 23, 2024

What does this PR do?

Multiple changes to thunder.jit to support a torchao.float8 MLP (see the test; a sketch of the setup follows the list):

  • Add support for torch._scaled_mm
  • Update _general_jit_torch_autograd_function_apply_lookaside
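A minimal sketch of the kind of setup this targets, assuming torchao's `convert_to_float8_training` API and an fp8-capable GPU; the module, shapes, and names below are illustrative, not the PR's actual test:

```python
import torch
import torch.nn as nn
import thunder
from torchao.float8 import convert_to_float8_training

# Illustrative MLP; the PR's test may use different shapes/modules.
# Dimensions are kept multiples of 16 to satisfy fp8 matmul constraints.
model = nn.Sequential(
    nn.Linear(64, 128, bias=False),
    nn.GELU(),
    nn.Linear(128, 64, bias=False),
).to("cuda")

# Swap nn.Linear for Float8Linear so matmuls lower to torch._scaled_mm.
convert_to_float8_training(model)

jitted = thunder.jit(model)
x = torch.randn(32, 64, device="cuda", requires_grad=True)
out = jitted(x)  # the float8 linears now route through torch._scaled_mm
```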

- Add `scaled_mm`
- Change how the lookaside of `torch.autograd.Function.apply` applies DCE,
  taking the failure of the APEX fused RMSNorm into consideration.

```python
@torch.no_grad()
@no_autocast
def FusedRMSNormAffineMixedDtypesFunction(t_0, t_1, tup11, f12, b13):
  # t_0: "cuda:0 f32[4, 5, 3, 2]"
  # t_1: "cuda:0 f32[3, 2]"

  # /usr/local/lib/python3.12/dist-packages/apex/normalization/fused_layer_norm.py:127:                 input_ = input.contiguous()
  t5 = ltorch.contiguous(t_0, memory_format=_torch_memory_format_0)  # t5: "cuda:0 f32[4, 5, 3, 2]"
    # t5 = prims.stride_order(t_0, (3, 2, 1, 0))  # t5: "cuda:0 f32[4, 5, 3, 2]"

  # /usr/local/lib/python3.12/dist-packages/apex/normalization/fused_layer_norm.py:128:                 weight_ = weight.contiguous()
  t6 = ltorch.contiguous(t_1, memory_format=_torch_memory_format_0)  # t6: "cuda:0 f32[3, 2]"
    # t6 = prims.stride_order(t_1, (1, 0))  # t6: "cuda:0 f32[3, 2]"
  (t10, t9) = apex_fused_rms_norm_forward_affine_mixed_dtypes(t5, (3, 2), t6, 1e-05)
  return t10
```
For this trace, `thunder.core.transforms.dce` replaces `t9` with `_`, so the
augmented forward trace loses access to it. By reusing the augmented forward
trace as the basic forward trace, `dce` no longer drops `t9`.
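To make the hazard concrete, here is a self-contained toy illustration (hypothetical names, not Thunder's actual trace machinery): when DCE runs over the basic forward trace alone, a residual that only the backward needs looks dead; reusing the augmented forward trace keeps it alive.

```python
def fused_op(x):
    # Stand-in for apex_fused_rms_norm_forward_affine_mixed_dtypes:
    # returns the result plus a residual that only the backward needs.
    out = x * 2.0
    saved = x + 1.0
    return out, saved

def basic_forward(x):
    # DCE over this trace alone sees `saved` as unused and rewrites it
    # to `_` (as happened to `t9` above), so the backward loses it.
    out, _ = fused_op(x)
    return out

def augmented_forward(x):
    # Returning the residual alongside the output keeps it live under
    # DCE; reusing this trace as the basic forward avoids the problem.
    out, saved = fused_op(x)
    return out, (saved,)
```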

Signed-off-by: Masaki Kozuki <[email protected]>
crcrpar force-pushed the tensor_subclass_3 branch from d9ed305 to 8435406 on Jan 2, 2025
github-actions bot removed the documentation label on Jan 2, 2025
crcrpar (Collaborator, Author) commented on Jan 14, 2025

The backward of torchao.float8 in the tests still needs fixing. The cause seems to be a mismatch between the row-major and column-major layouts of the inputs to torch._scaled_mm. This could be sidestepped if we had a decomposition and let nvFuser or another fusion executor take care of it.
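For reference, a hedged sketch of the layout constraint (assuming the PyTorch ≥ 2.5 signature of `torch._scaled_mm` and an fp8-capable GPU): the first operand must be row-major and the second column-major, which is where a mismatch can creep in.

```python
import torch

# torch._scaled_mm expects mat1 row-major and mat2 column-major, both fp8.
a = torch.randn(16, 32, device="cuda").to(torch.float8_e4m3fn)       # row-major
b = torch.randn(64, 32, device="cuda").to(torch.float8_e4m3fn).t()   # column-major view
scale_a = torch.tensor(1.0, device="cuda")
scale_b = torch.tensor(1.0, device="cuda")

out = torch._scaled_mm(a, b, scale_a=scale_a, scale_b=scale_b,
                       out_dtype=torch.bfloat16)  # errors if layouts are swapped
```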
