Initial draft of AI operations recipe. (#14)
BradLarson authored Mar 4, 2025
1 parent a4fc268 commit 1387c37
Showing 10 changed files with 1,118 additions and 0 deletions.
8 changes: 8 additions & 0 deletions custom-ops-ai-applications/.gitignore
@@ -0,0 +1,8 @@
# pixi environments
.pixi
*.egg-info
# magic environments
.magic
.env
# build products
operations.mojopkg
131 changes: 131 additions & 0 deletions custom-ops-ai-applications/README.md
@@ -0,0 +1,131 @@
# Custom Operations: Applications in AI models

In this recipe, we will cover:

* Building a top-K token sampler for GPU and CPU.
* Implementing FlashAttention-2 as a fused custom operation.
* Accessing GPU hardware features, like Tensor Cores, from MAX.

We'll walk through two examples that

* illustrate real-world applications of custom MAX Graph operations in AI
  models,
* apply seven optimizations that sequentially improve GPU performance,
* and demonstrate the performance benefits with benchmarks of each step.

Let's get started.

## Requirements

Please make sure your system meets our
[system requirements](https://docs.modular.com/max/get-started).

To proceed, ensure you have the `magic` CLI installed:

```bash
curl -ssL https://magic.modular.com/ | bash
```

or update it via:

```bash
magic self-update
```

### GPU requirements

These examples can all be run on either a CPU or GPU. To run them on a GPU,
ensure your system meets
[these GPU requirements](https://docs.modular.com/max/faq/#gpu-requirements):

* Officially supported GPUs: NVIDIA Ampere A-series (A100/A10), or Ada
L4-series (L4/L40) data center GPUs. Unofficially, RTX 30XX and 40XX series
GPUs have been reported to work well with MAX.
* NVIDIA GPU driver version 555 or higher. [Installation guide here](https://www.nvidia.com/download/index.aspx).

## Quick start

1. Download the code for this recipe using git:

    ```bash
    git clone https://github.com/modular/max-recipes.git
    cd max-recipes/custom-ops-ai-applications
    ```

2. Run the examples:

    ```bash
    magic run top_k
    magic run fused_attention
    ```

3. Run the benchmarks of the top-K token sampler to see how much faster it
   runs on GPU:

    ```bash
    magic run benchmarks
    ```

## Optimizing top-K token sampling for GPUs in MAX

AI models in [MAX](https://docs.modular.com/max/intro) are built as
computational graphs using the
[MAX Graph API](https://docs.modular.com/max/tutorials/get-started-with-max-graph-in-python).
MAX includes a powerful graph compiler that optimizes these graphs for the
best performance on a wide range of hardware.

Each node in a MAX Graph is defined by an operation that performs a calculation
on zero or more inputs and produces one or more outputs. These inputs and
outputs tend to be in the form of tensors, and the operations are usually
data-parallel calculations that are accelerated on CPUs or GPUs. In MAX,
these operations are written using [Mojo](https://docs.modular.com/mojo/manual/),
a Python-family language built for high-performance computation.
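
To make this concrete, here is a minimal sketch of how a graph invokes a custom
operation from Python. This is illustrative only and not part of this recipe's
code: the operation name `my_custom_op` and the tensor shapes are placeholders,
and the real pattern (including loading the compiled `operations.mojopkg`)
appears in `fused_attention.py` later in this recipe.

```python
from max.dtype import DType
from max.graph import Graph, TensorType, ops

# Hypothetical illustration: a one-node graph that calls a custom Mojo
# operation. The name "my_custom_op" would need to match an operation
# registered with @compiler.register and packaged into a .mojopkg.
dtype = DType.float32

with Graph(
    "custom_op_example",
    input_types=[TensorType(dtype, shape=[16, 16])],
) as graph:
    (x,) = graph.inputs
    # Each ops.custom(...) call becomes a node in the computational graph;
    # MAX's graph compiler optimizes these nodes for the target hardware.
    results = ops.custom(
        name="my_custom_op",
        values=[x],
        out_types=[TensorType(dtype, shape=[16, 16])],
    )
    graph.output(*results)
```

Executing such a graph follows the same `InferenceSession` pattern shown in the
FlashAttention-2 example below.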

Large language models rely on token samplers to improve the quality of the text
generated by the model, as well as to add interesting variability to the output.
One sampling technique is top-K token sampling, and this example provides both
CPU and GPU implementations of this algorithm. The GPU implementation
demonstrates how to accelerate the sampling via hardware features.
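
To illustrate the algorithm itself, here is a small NumPy sketch of top-K
sampling for a single vector of logits. This is purely illustrative and
separate from the Mojo implementation in this recipe:

```python
import numpy as np


def top_k_sample(logits: np.ndarray, k: int, rng: np.random.Generator) -> int:
    """Sample a token id from the top-k entries of a 1-D logits vector."""
    # Indices of the k largest logits (their relative order doesn't matter here).
    top_idx = np.argpartition(logits, -k)[-k:]
    top_logits = logits[top_idx]

    # Softmax restricted to the top-k logits (subtract the max for stability).
    probs = np.exp(top_logits - top_logits.max())
    probs /= probs.sum()

    # Draw one of the k candidate tokens from the renormalized distribution.
    return int(rng.choice(top_idx, p=probs))


rng = np.random.default_rng(0)
vocab_logits = rng.normal(size=1000).astype(np.float32)
print(top_k_sample(vocab_logits, k=32, rng=rng))
```

The custom operation in this recipe handles the selection step of this process
for a whole batch of values at once, producing both the top-K values and their
indices.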

Run the top-K token sampling demo with:

```bash
magic run top_k
```

## Implementing a fused operation for FlashAttention-2

Modern Transformer-based language models are constructed around the attention
mechanism. Optimizing how attention is performed is a key driver in improving
large language model performance. One such optimization is the
[FlashAttention-2](https://arxiv.org/abs/2307.08691) layer. In this example,
you'll see how to implement FlashAttention-2 as a fused operation that runs on
the GPU in MAX using Mojo.
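
For reference, the standard form of the attention computation being fused is
softmax(Q·Kᵀ/√D)·V. The NumPy sketch below is illustrative only (it is not the
Mojo kernel and ignores its tiling parameters), but it shows the result that a
FlashAttention-style kernel computes while processing K and V in blocks with an
online softmax, avoiding materializing the full N×N score matrix:

```python
import numpy as np


def attention_reference(Q: np.ndarray, K: np.ndarray, V: np.ndarray) -> np.ndarray:
    """Plain scaled dot-product attention: softmax(Q @ K.T / sqrt(D)) @ V."""
    d = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d)                  # (N, N) attention scores
    scores -= scores.max(axis=-1, keepdims=True)   # stabilize the softmax
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ V                             # (N, D) output


rng = np.random.default_rng(123)
N, D = 32, 32
Q, K, V = (rng.standard_normal((N, D), dtype=np.float32) for _ in range(3))
print(attention_reference(Q, K, V).shape)  # (32, 32)
```

Implementing this as one fused custom operation lets the entire attention
computation run as a single graph node, rather than as separate
matrix-multiplication and softmax nodes.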

To run the example, use the following command:

```bash
magic run fused_attention
```

## Conclusion

In this recipe, we've demonstrated how to create custom MAX Graph operations
that perform functions important in modern AI models: top-K token sampling and
the FlashAttention-2 attention layer. Each shows how complex calculations can
be constructed using MAX and Mojo and targeted at GPU hardware features.

## Next Steps

* Follow [our tutorial for building a custom operation from scratch](https://docs.modular.com/max/tutorials/build-custom-ops).

* Explore MAX's [documentation](https://docs.modular.com/max/) for additional
features. The [`gpu`](https://docs.modular.com/mojo/stdlib/gpu/) module has
details on Mojo's GPU programming functions and types, and the documentation
on [`@compiler.register`](https://docs.modular.com/max/api/mojo-decorators/compiler-register/)
shows how to register custom graph operations.

* Join our [Modular Forum](https://forum.modular.com/) and [Discord community](https://discord.gg/modular) to share your experiences and get support.

We're excited to see what you'll build with MAX! Share your projects and experiences with us using `#ModularAI` on social media.
135 changes: 135 additions & 0 deletions custom-ops-ai-applications/benchmarks.mojo
@@ -0,0 +1,135 @@
# ===----------------------------------------------------------------------=== #
# Copyright (c) 2025, Modular Inc. All rights reserved.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions:
# https://llvm.org/LICENSE.txt
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===----------------------------------------------------------------------=== #

from operations.top_k import TopK
from gpu.host import DeviceContext
from utils import IndexList
from max.driver import cpu
from max.tensor import (
    ManagedTensorSlice,
    InputTensor,
    OutputTensor,
    StaticTensorSpec,
)
from random import rand
from memory import UnsafePointer
from runtime.asyncrt import DeviceContextPtr
from benchmark import ThroughputMeasure, BenchId, BenchMetric, Bench, Bencher
from bit import log2_floor
from sys import sizeof, has_nvidia_gpu_accelerator
from memory import AddressSpace


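# Benchmark the custom top-K operation on CPU and, when an NVIDIA GPU is
# available, on GPU, reporting element and estimated FLOP throughput for each.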
def top_k():
    alias batch_size = 30_000
    alias K = 32
    alias els = batch_size * K
    alias rank = 2
    alias shape = IndexList[rank](batch_size, K)
    alias val_dtype = DType.float32
    alias idx_dtype = DType.int32

    # Slightly better performance compared to `create_unknown`. Using global
    # address space doesn't improve perf for GPU.
    alias val_spec = StaticTensorSpec[val_dtype, rank](
        shape=(batch_size, K),
        strides=(K, 1),
        alignment=sizeof[val_dtype](),
        address_space=AddressSpace.GENERIC,
        exclusive=True,
        in_lambda=None,
        out_lambda=None,
    )
    alias idx_spec = StaticTensorSpec[idx_dtype, rank](
        shape=(batch_size, K),
        strides=(K, 1),
        alignment=sizeof[idx_dtype](),
        address_space=AddressSpace.GENERIC,
        exclusive=True,
        in_lambda=None,
        out_lambda=None,
    )

    var in_vals = InputTensor[static_spec=val_spec].rand()
    var out_vals = OutputTensor[static_spec=val_spec].rand()
    var out_idxs = OutputTensor[static_spec=idx_spec].rand()

    var cpu_ctx = DeviceContext(api="cpu")

    @parameter
    @always_inline
    fn bench_cpu(mut b: Bencher) raises:
        @parameter
        @always_inline
        fn run_bench() raises:
            TopK.execute[K=K, target="cpu"](
                out_vals, out_idxs, in_vals, cpu_ctx
            )

        b.iter[run_bench]()

    var flops = ThroughputMeasure(BenchMetric.flops, els * log2_floor(K))
    var elements = ThroughputMeasure(BenchMetric.elements, els)

    var b = Bench()
    b.bench_function[bench_cpu](BenchId("top_k_custom", "cpu"), flops, elements)

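    # Benchmark the GPU path only when an NVIDIA accelerator is present.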
    @parameter
    if has_nvidia_gpu_accelerator():
        var gpu_ctx = DeviceContext()

        var in_vals_dev_buff = gpu_ctx.enqueue_create_buffer[val_dtype](els)
        var out_vals_dev_buff = gpu_ctx.enqueue_create_buffer[val_dtype](els)
        var out_idxs_dev_buff = gpu_ctx.enqueue_create_buffer[idx_dtype](els)

        gpu_ctx.enqueue_copy(in_vals_dev_buff, in_vals.unsafe_ptr())

        var out_vals_dev = OutputTensor[static_spec=val_spec](
            out_vals_dev_buff.unsafe_ptr(), shape
        )
        var out_idxs_dev = OutputTensor[static_spec=idx_spec](
            out_idxs_dev_buff.unsafe_ptr(), shape
        )
        var in_vals_dev = InputTensor[static_spec=val_spec](
            in_vals_dev_buff.unsafe_ptr(), shape
        )

        @parameter
        @always_inline
        fn bench_gpu(mut b: Bencher) raises:
            @parameter
            @always_inline
            fn kernel_launch(gpu_ctx: DeviceContext) raises:
                TopK.execute[K=K, target="gpu"](
                    out_vals_dev, out_idxs_dev, in_vals_dev, gpu_ctx
                )

            b.iter_custom[kernel_launch](gpu_ctx)

        b.bench_function[bench_gpu](
            BenchId("top_k_custom", "gpu"), flops, elements
        )
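        # Keep the device buffers alive until the GPU benchmark completes.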
        _ = in_vals_dev_buff
        _ = out_vals_dev_buff
        _ = out_idxs_dev_buff

    b.config.verbose_metric_names = False
    print(b)

    _ = in_vals
    _ = out_vals
    _ = out_idxs


def main():
    top_k()
68 changes: 68 additions & 0 deletions custom-ops-ai-applications/fused_attention.py
@@ -0,0 +1,68 @@
# ===----------------------------------------------------------------------=== #
# Copyright (c) 2025, Modular Inc. All rights reserved.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions:
# https://llvm.org/LICENSE.txt
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===----------------------------------------------------------------------=== #

import os
from pathlib import Path

import numpy as np
from max.driver import CPU, Accelerator, Tensor, accelerator_count
from max.dtype import DType
from max.engine.api import InferenceSession
from max.graph import Graph, TensorType, ops


def main():
    path = Path(__file__).parent / "operations.mojopkg"

    dtype = DType.float32
    N = 32
    D = 32
    BD = 8
    BN = 16
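    # N and D are the input sequence length and head dimension; BD and BN are
    # passed to the kernel as parameters (presumably its block tile sizes).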
    with Graph(
        "fused_attention",
        input_types=[
            TensorType(dtype, shape=[N, D]),
            TensorType(dtype, shape=[N, D]),
            TensorType(dtype, shape=[N, D]),
        ],
    ) as graph:
        q, k, v, *_ = graph.inputs
        results = ops.custom(
            name="fused_attention_custom",
            parameters={"N": N, "D": D, "BD": BD, "BN": BN},
            values=[q, k, v],
            out_types=[TensorType(dtype, shape=[N, D])],
        )
        graph.output(*results)

    # Place the graph on a GPU, if available. Fall back to CPU if not.
    device = CPU() if accelerator_count() == 0 else Accelerator()

    # Set up an inference session for running the graph.
    session = InferenceSession(devices=[device], custom_extensions=path)

    # Compile the graph.
    model = session.load(graph)

    np.random.seed(123)
    Q = Tensor.from_numpy(np.random.randn(N, D).astype("f")).to(device)
    K = Tensor.from_numpy(np.random.randn(N, D).astype("f")).to(device)
    V = Tensor.from_numpy(np.random.randn(N, D).astype("f")).to(device)

    output = model.execute(Q, K, V)
    print(output)


if __name__ == "__main__":
    main()
18 changes: 18 additions & 0 deletions custom-ops-ai-applications/metadata.yaml
@@ -0,0 +1,18 @@
version: 1.0
long_title: "Custom Operations: Applications in AI Models"
short_title: "Custom Operations: Applications in AI Models"
author: "Brad Larson"
author_image: "author/bradlarson.jpg"
author_url: "https://www.linkedin.com/in/brad-larson-3549a5291/"
github_repo: "https://github.com/modular/max-recipes/tree/main/custom-ops-ai-applications"
date: "23-02-2025"
difficulty: "advanced"
tags:
- max-graph
- gpu-programming
- mojo

tasks:
- magic run top_k
- magic run fused_attention
- magic run benchmarks
12 changes: 12 additions & 0 deletions custom-ops-ai-applications/operations/__init__.mojo
@@ -0,0 +1,12 @@
# ===----------------------------------------------------------------------=== #
# Copyright (c) 2025, Modular Inc. All rights reserved.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions:
# https://llvm.org/LICENSE.txt
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===----------------------------------------------------------------------=== #