
[Example] Add GQA Example #118

Merged: 14 commits into tile-ai:main on Feb 25, 2025

Conversation

LeiWang1999
Contributor

This pull request introduces a new example script for flash attention with grouped-query attention (GQA) in the examples/flash_attention/example_gqa_fwd_bshd.py file. The script includes functions for configuring and running the flash attention kernel, as well as for comparing its performance against a reference implementation.

Key changes include:

  • New example script for flash attention with GQA:
    • Added examples/flash_attention/example_gqa_fwd_bshd.py, which implements the flash attention kernel with grouped-query attention.
  • Kernel function and configuration:
    • Defined the flashattn function that sets up the kernel, including macros for matrix multiplication (MMA), softmax, and rescaling operations.
    • Implemented the get_configs function to generate candidate configurations for autotuning the kernel.
  • Reference implementation:
    • Added the ref_program function to provide a reference implementation of the attention mechanism using PyTorch operations (see the sketch after this list).
  • Command-line interface:
    • Added command-line arguments for batch size, number of heads, sequence length, and head dimension.
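
For orientation, here is a minimal PyTorch-level sketch of what a GQA reference along the lines of ref_program computes. The BSHD tensor layout (matching the file name) and the groups argument are assumptions for illustration, not the file's exact signature:

```python
import torch
import torch.nn.functional as F

def ref_program(Q, K, V, groups):
    # Q: [batch, seq_len, heads, dim] (BSHD layout, per the file name)
    # K, V: [batch, seq_len, heads // groups, dim] -- fewer KV heads than Q heads
    # Repeat each KV head so every query head in its group can attend to it.
    K = K.repeat_interleave(groups, dim=2)
    V = V.repeat_interleave(groups, dim=2)
    scale = Q.size(-1) ** -0.5
    scores = torch.einsum("bqhd,bkhd->bhqk", Q, K) * scale
    attn = F.softmax(scores.float(), dim=-1).to(Q.dtype)
    return torch.einsum("bhqk,bkhd->bqhd", attn, V)
```

Repeating the KV heads makes GQA numerically identical to standard multi-head attention over the expanded heads, which is what lets the kernel output be checked against a dense reference.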

Commits

This commit introduces two new example scripts demonstrating advanced GEMM (general matrix multiplication) techniques:
- `example_tilelang_gemm_splitk.py`: Implements a Split-K GEMM kernel using TileLang
- `example_tilelang_gemm_streamk.py`: Implements a Stream-K GEMM kernel using TileLang

Both examples showcase different parallel work-partitioning strategies for matrix multiplication, with comprehensive testing against PyTorch reference implementations (a conceptual sketch of the Split-K decomposition follows).
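
In the Split-K scheme, the reduction (K) dimension is partitioned into chunks; each chunk produces a partial M x N product, and the partials are reduced into the output, which is the role the kernel's atomic adds play. A hedged PyTorch sketch of the decomposition, not the TileLang kernel itself (splitk_gemm_reference is a made-up name):

```python
import torch

def splitk_gemm_reference(A, B, split_k=4):
    # Partition the K (reduction) dimension into split_k chunks. Each
    # chunk yields a partial M x N product; the real kernel accumulates
    # these with atomic adds, while here we simply sum the partials.
    M, K = A.shape
    chunk = K // split_k  # assumes split_k divides K evenly
    partials = [
        A[:, i * chunk:(i + 1) * chunk] @ B[i * chunk:(i + 1) * chunk, :]
        for i in range(split_k)
    ]
    return torch.stack(partials).sum(dim=0)
```

Stream-K takes the same idea further by assigning each thread block a contiguous range of reduction iterations that may span output-tile boundaries, which balances work when the tile count does not divide evenly across the SMs.
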
Clean up and improve code formatting for the SplitK and StreamK GEMM example scripts:
- Remove unused import (Profiler) in splitk example
- Simplify line breaks and improve code readability
- Standardize indentation and remove unnecessary whitespace
- Simplify atomic add and copy operations for better clarity

This commit introduces comprehensive block sparse attention benchmarks across several libraries:
- TileLang block sparse FMHA implementation
- Triton block sparse FMHA implementation
- PyTorch reference block sparse FMHA implementation
- FlashAttention dense FMHA reference implementation

The benchmarks include:
- Configurable benchmark parameters (batch size, heads, sequence length, etc.)
- Sparse mask generation using top-k and threshold methods (both sketched after this list)
- Performance measurement for different sparse attention configurations
- Utility functions for mask generation and benchmarking
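
A hedged sketch of the two mask-generation methods named above; the function names and the block-score layout are illustrative, not the benchmark's actual helpers:

```python
import torch

def topk_block_mask(block_scores, k):
    # block_scores: [num_q_blocks, num_kv_blocks] importance estimates,
    # e.g. pooled attention scores per block pair.
    # Keep the k highest-scoring KV blocks for each query block.
    idx = block_scores.topk(k, dim=-1).indices
    mask = torch.zeros_like(block_scores, dtype=torch.bool)
    mask.scatter_(-1, idx, True)
    return mask

def threshold_block_mask(block_scores, tau):
    # Keep every block whose score exceeds the threshold tau.
    return block_scores > tau
```
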
- Add Ruff linter ignore comments to benchmark files
- Improve code formatting and line breaks
- Remove unused imports
- Standardize print statement formatting
- Enhance code readability across multiple library benchmarks

- Implement AtomicAdd functions for BFLOAT16 and BFLOAT16x2 in CUDA common header
- Rename existing atomic add functions to use PascalCase (atomicAdd -> AtomicAdd)
- Add a new __pack_nv_bfloat162 function for packing BFLOAT16 values (emulated in the sketch below)
- Update kernel and language customization to use new function names
- Add return type annotations in profiler module
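
The packing helper exists so that a pair of BFLOAT16 values can be updated with a single 32-bit atomic operation. A PyTorch emulation of the bit layout; the low/high ordering is an assumption about __pack_nv_bfloat162, not taken from the header:

```python
import torch

def pack_bf16x2(lo, hi):
    # Reinterpret each bfloat16 as its raw 16-bit pattern and place the
    # two patterns in one 32-bit word (lo in bits 0-15, hi in bits 16-31).
    lo_bits = lo.to(torch.bfloat16).view(torch.int16).to(torch.int32) & 0xFFFF
    hi_bits = hi.to(torch.bfloat16).view(torch.int16).to(torch.int32) & 0xFFFF
    return lo_bits | (hi_bits << 16)
```
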
…Attention in TileLang

This commit introduces a new example script `example_gqa_fwd_bshd.py` that demonstrates:
- Group Query Attention (GQA) implementation
- Flash Attention forward pass
- Performance benchmarking
- Configurable parameters for batch, heads, sequence length, and dimension
- Autotuning support (see the config sketch below)
- Reference implementation comparison
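
The autotuning support rests on enumerating candidate kernel configurations, the job of get_configs in the PR description. A minimal sketch of that pattern; the parameter names and value ranges here are assumptions, not the file's actual search space:

```python
import itertools

def get_configs():
    # Cartesian product of candidate tile sizes, pipeline depths and
    # thread counts; the autotuner compiles and times each candidate.
    block_M = [64, 128]
    block_N = [64, 128]
    num_stages = [1, 2]
    threads = [128, 256]
    return [
        dict(block_M=m, block_N=n, num_stages=s, threads=t)
        for m, n, s, t in itertools.product(block_M, block_N, num_stages, threads)
    ]
```
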
This commit introduces a new module `phase.py` to modularize the IR lowering process by splitting the complex lowering pipeline into two distinct phases:
- `LowerAndLegalize`: Handles initial IR legalization and transformation
- `OptimizeForTarget`: Applies target-specific optimizations

The changes simplify the lowering logic in multiple files by extracting the transformation steps into reusable functions, improving code readability and maintainability.
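
Structurally, the split means callers chain the two phases in order. A toy sketch of the pipeline's shape, with placeholder pass lists rather than the real transforms:

```python
def apply_passes(mod, passes):
    # Apply a list of IRModule -> IRModule transforms in order.
    for p in passes:
        mod = p(mod)
    return mod

def LowerAndLegalize(mod, target):
    # Phase 1: initial, target-independent legalization and transformation.
    return apply_passes(mod, [])  # placeholder pass list

def OptimizeForTarget(mod, target):
    # Phase 2: optimizations specific to the chosen backend target.
    return apply_passes(mod, [])  # placeholder pass list

def lower(mod, target):
    mod = LowerAndLegalize(mod, target)
    return OptimizeForTarget(mod, target)
```
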
LeiWang1999 merged commit e93c7c4 into tile-ai:main on Feb 25, 2025
2 of 3 checks passed