-
Notifications
You must be signed in to change notification settings - Fork 23
/
Copy pathfragment_mma_load_a.py
116 lines (97 loc) · 3.83 KB
/
fragment_mma_load_a.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import tilelang.language as T
from typing import Literal, Callable
from tvm import DataType
from tvm.tir import IndexMap
from tilelang.intrinsics.utils import get_mma_micro_size
def make_mma_load_base_layout(dtype: str = "float16",
                              matrix: Literal["A", "B"] = "A",
                              transposed: bool = False) -> T.Fragment:
    """
    Build the base fragment layout used when loading one MMA operand tile.

    The returned fragment maps every element ``(i, j)`` of a single operand
    micro-tile to a lane id and a per-lane local index. Both values come from
    the shared-to-MMA index transform selected by ``dtype`` and ``matrix``.

    Parameters
    ----------
    dtype : str
        Element data type of the operand. Only types that are 16 or 8 bits
        wide are supported.
    matrix : Literal["A", "B"]
        Which operand the layout is for. With ``transposed=False``, "A" uses
        the spatial+reduction ("sr") transform and "B" uses the
        reduction+spatial ("rs") transform.
    transposed : bool
        Must be ``False``; transposed operands are not supported yet.

    Returns
    -------
    T.Fragment
        A fragment whose shape follows the selected axis order
        (``[s, r]`` for "sr", ``[r, s]`` for "rs") and whose thread/index
        functions are derived from the selected transform.

    Raises
    ------
    AssertionError
        If ``matrix`` is not "A" or "B", or if ``transposed`` is not ``False``.
    ValueError
        If ``dtype`` is neither 16 nor 8 bits wide.
    """
    from tilelang.intrinsics.mma_layout import (
        shared_16x16_to_mma_32x8_layout_sr,
        shared_16x16_to_mma_32x8_layout_rs,
        shared_16x32_to_mma_32x16_layout,
        shared_32x16_to_mma_32x16_layout,
    )

    assert matrix in ["A", "B"], "matrix should be either A or B"
    dtype_bits = DataType(dtype).bits
    assert transposed is False, "transposed is not supported yet"

    # s represents spatial axis
    # r represents reduction axis
    # sr represents the two dims are spatial + reduction
    # rs represents the two dims are reduction + spatial
    transform_func_sr: Callable
    transform_func_rs: Callable
    if dtype_bits == 16:
        transform_func_sr = shared_16x16_to_mma_32x8_layout_sr
        transform_func_rs = shared_16x16_to_mma_32x8_layout_rs
    elif dtype_bits == 8:
        transform_func_sr = shared_16x32_to_mma_32x16_layout
        transform_func_rs = shared_32x16_to_mma_32x16_layout
    else:
        raise ValueError(f"Unsupported dtype {dtype}")

    # The operand is in "sr" order for a non-transposed A or a transposed B.
    is_sr_axis_order = (matrix == "A" and not transposed) or (matrix == "B" and transposed)
    transform_func: Callable = transform_func_sr if is_sr_axis_order else transform_func_rs

    # The first and third returned sizes are used as the spatial and reduction
    # extents of one micro-tile; the middle value is not needed here.
    micro_size_s, _, micro_size_r = get_mma_micro_size(dtype)

    inverse_mma_load_layout = IndexMap.from_func(transform_func, index_dtype="int32")

    def forward_thread(i: int, j: int) -> int:
        """
        Given the row index `i` and column index `j` in the fragment,
        return the lane id that the selected transform assigns to it.
        """
        lane_id, _ = inverse_mma_load_layout.map_indices([i, j])
        return lane_id

    def forward_index(i: int, j: int) -> int:
        """
        Given the row index `i` and column index `j` in the fragment,
        return the local index that the selected transform assigns to it.
        """
        _, local_id = inverse_mma_load_layout.map_indices([i, j])
        return local_id

    # The fragment shape has to follow the same axis order as the transform:
    # [s, r] for the "sr" transform and [r, s] for the "rs" transform. Always
    # using [r, s] is only correct when both extents happen to be equal.
    if is_sr_axis_order:
        fragment_shape = [micro_size_s, micro_size_r]
    else:
        fragment_shape = [micro_size_r, micro_size_s]

    base_fragment = T.Fragment(
        fragment_shape,
        forward_thread_fn=forward_thread,
        forward_index_fn=forward_index,
    )
    return base_fragment
from tilelang.tools import plot_layout

# Repeat/replicate factors referenced only by the disabled warp- and
# block-level examples at the bottom of this script.
block_rows, block_cols = 2, 2
warp_rows, warp_cols = 4, 4
chunk = 2

# Base layout for a non-transposed float16 "A" operand: build it, print it,
# and render it under the name "base_layout".
base_layout = make_mma_load_base_layout(
    dtype="float16",
    matrix="A",
    transposed=False,
)
print(base_layout)
plot_layout(base_layout, name="base_layout")

# Disabled examples, kept for reference.
#
# Warp-level layout: repeat the base layout over block_rows on the thread
# axis, then replicate it block_cols times.
# warp_layout = base_layout.repeat([block_rows, 1],
#                                  repeat_on_thread=True).replicate(block_cols)
# print(warp_layout)
# plot_layout(warp_layout, name="warp_layout")
#
# Block-level layout: repeat the warp layout over [warp_rows, chunk] without
# repeating on the thread axis.
# block_layout = warp_layout.repeat([warp_rows, chunk], repeat_on_thread=False, lower_dim_first=False)
# plot_layout(block_layout, name="block_layout")