
Fix tensor core instruction shape #19

Merged: 2 commits, Mar 9, 2025

Conversation

@nlaanait (Contributor) commented Mar 8, 2025

Major Changes

  1. Changed the MMA_K instruction shape for the "tensor_core" benchmark to conform with the shape and stride of the inputs
    ref: https://github.com/NVIDIA/cutlass/blob/df18f5e4f5de76bed8be1de8e4c245f2f5ec3020/include/cute/atom/mma_traits_sm80.hpp#L147-L160

Minor Changes

  1. Registered the "tensor_core" benchmark in the benchmark list

@@ -968,14 +968,15 @@ struct MatrixMultiplication[algorithm: StringLiteral]:
block_dim=(NUM_THREADS),
)
elif algorithm == "tensor_core":
constrained[""]
Collaborator:
Is this left over from debugging?

Contributor Author:

Correct. I was planning to constrain the input dtype to float32 (to make the tensor_core code path happy), but realized it's probably more robust to come up with a mapping from dtype to supported MMA instructions, to generalize the example to other input types.
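The dtype-to-MMA-instruction mapping proposed above might look roughly like this. This is only an illustrative Python sketch of the logic (the benchmark itself is written in Mojo), and the instruction shapes below are representative SM80 shapes as listed in cute's mma_traits_sm80.hpp and the PTX ISA; they should be double-checked against the target architecture before use:

```python
# Illustrative sketch: map input dtype to a representative SM80
# tensor-core MMA instruction shape (M, N, K). Shapes taken from
# cute's mma_traits_sm80.hpp / the PTX ISA; some dtypes support
# additional K variants not listed here.
SM80_MMA_SHAPES = {
    "tf32": (16, 8, 8),   # mma.sync.aligned.m16n8k8  ... tf32
    "f16":  (16, 8, 16),  # mma.sync.aligned.m16n8k16 ... f16
    "bf16": (16, 8, 16),  # mma.sync.aligned.m16n8k16 ... bf16
    "f64":  (8, 8, 4),    # mma.sync.aligned.m8n8k4   ... f64
}

def valid_tile(dtype: str, bm: int, bn: int, bk: int) -> bool:
    """A block tile is only consistent with the MMA atom if each
    tile dimension is a multiple of the instruction shape for
    that dtype -- the kind of mismatch this PR's MMA_K fix addresses."""
    m, n, k = SM80_MMA_SHAPES[dtype]
    return bm % m == 0 and bn % n == 0 and bk % k == 0
```

For example, a K-tile of 16 is valid for tf32 (16 is a multiple of the tf32 MMA_K of 8), while a K-tile of 8 would not be valid for f16, whose instruction K is 16.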

Contributor Author:

Removed. LMK what you think of the mapping above; happy to contribute it in a separate PR.

Collaborator:

Thanks, let me check and see what we want to do to generalize this over multiple datatypes.

@BradLarson (Collaborator) left a comment:

Thanks for the addition and fix! Had one comment on a possibly extraneous precondition, but otherwise looks good. I can update the corresponding example in the MAX repo to match.

@BradLarson BradLarson merged commit 5039ede into modular:main Mar 9, 2025
1 check passed