-
Notifications
You must be signed in to change notification settings - Fork 367
/
Copy pathtest_core.py
158 lines (127 loc) · 7.88 KB
/
test_core.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
import random
import torch
from typing import Tuple
import deep_gemm
from deep_gemm import bench_kineto, calc_diff, ceil_div, get_col_major_tma_aligned_tensor
def per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
assert x.dim() == 2 and x.size(1) % 128 == 0
m, n = x.shape
x_view = x.view(m, -1, 128)
x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
return (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, n), (x_amax / 448.0).view(m, -1)
def per_block_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Quantize a 2-D tensor to FP8 (e4m3fn) with one scale per 128x128 block.

    The input is copied into a zero-filled buffer whose sides are
    ``ceil_div(dim, 128) * 128`` (presumably the next multiple of 128), cut
    into 128x128 blocks, and each block is rescaled so that its largest
    magnitude (floored at 1e-4) maps to 448.0 before the cast.

    Args:
        x: 2-D tensor of any height and width.

    Returns:
        A pair ``(x_fp8, scales)``: the contiguous quantized tensor cropped
        back to the shape of ``x``, and the float32 per-block factors
        ``amax / 448.0`` with shape ``(row blocks, column blocks)``.

    Raises:
        AssertionError: if ``x`` is not 2-D.
    """
    assert x.dim() == 2
    num_rows, num_cols = x.shape

    padded_rows = ceil_div(num_rows, 128) * 128
    padded_cols = ceil_div(num_cols, 128) * 128
    padded = torch.zeros((padded_rows, padded_cols), dtype=x.dtype, device=x.device)
    padded[:num_rows, :num_cols] = x

    # Axes: (row block, row inside block, column block, column inside block).
    blocks = padded.view(-1, 128, padded.size(1) // 128, 128)
    block_amax = blocks.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
    quantized = (blocks * (448.0 / block_amax)).to(torch.float8_e4m3fn)

    # Drop the padding again; the crop is a strided view, hence .contiguous().
    cropped = quantized.view_as(padded)[:num_rows, :num_cols].contiguous()
    scales = (block_amax / 448.0).view(blocks.size(0), blocks.size(2))
    return cropped, scales
def construct(m: int, k: int, n: int) -> \
        Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor]:
    """Create random inputs, an output buffer and a BF16 reference for one GEMM.

    Args:
        m: number of rows of the left operand (and of the output).
        k: shared inner dimension.
        n: number of rows of the right operand (columns of the output).

    Returns:
        ``(x_fp8, y_fp8, out, ref_out)`` where ``x_fp8`` and ``y_fp8`` are
        ``(fp8 data, scales)`` pairs produced by the per-token and per-block
        casts respectively, ``out`` is an uninitialized ``(m, n)`` BF16
        buffer and ``ref_out`` is ``x @ y.t()`` computed in BF16.
    """
    lhs = torch.randn((m, k), device='cuda', dtype=torch.bfloat16)
    rhs = torch.randn((n, k), device='cuda', dtype=torch.bfloat16)
    out = torch.empty((m, n), device='cuda', dtype=torch.bfloat16)
    ref_out = lhs @ rhs.t()

    lhs_data, lhs_scales = per_token_cast_to_fp8(lhs)
    y_fp8 = per_block_cast_to_fp8(rhs)

    # Lay the scales out up front so the code under test does not have to
    # launch a transposing kernel itself.
    x_fp8 = (lhs_data, get_col_major_tma_aligned_tensor(lhs_scales))
    return x_fp8, y_fp8, out, ref_out
def construct_grouped(num_groups: int, m: int, k: int, n: int, is_masked: bool) -> \
        Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor]:
    """Create random inputs, an output buffer and a reference for a grouped GEMM.

    Args:
        num_groups: number of independent GEMM problems.
        m: rows of the left operand per group; must be a multiple of 4.
        k: shared inner dimension; must be a multiple of 128.
        n: rows of the right operand per group.
        is_masked: when ``False`` the group and M dimensions of the left
            operand, the output and the reference are merged into one
            leading dimension of size ``num_groups * m``; when ``True`` they
            stay three-dimensional.

    Returns:
        ``(x_fp8, y_fp8, out, ref_out)`` where ``x_fp8`` and ``y_fp8`` are
        ``(fp8 data, scales)`` pairs, ``out`` is an uninitialized BF16
        buffer and ``ref_out`` is the per-group ``x @ y.t()`` in BF16.

    Raises:
        AssertionError: if ``m`` is not a multiple of 4 or ``k`` is not a
            multiple of 128.
    """
    # Reject bad shapes before any device memory is allocated or the
    # reference einsum is launched.
    assert m % 4 == 0, f'TMA alignment error: {m}'
    assert k % 128 == 0, f'k must be a multiple of 128: {k}'

    x = torch.randn((num_groups, m, k), device='cuda', dtype=torch.bfloat16)
    y = torch.randn((num_groups, n, k), device='cuda', dtype=torch.bfloat16)
    out = torch.empty((num_groups, m, n), device='cuda', dtype=torch.bfloat16)
    ref_out = torch.einsum('gmk,gnk->gmn', x, y)

    x_fp8 = (torch.empty_like(x, dtype=torch.float8_e4m3fn),
             torch.empty((num_groups, m, k // 128), device='cuda', dtype=torch.float))
    # ceil_div mirrors how per_block_cast_to_fp8 sizes its own scale tensor.
    y_fp8 = (torch.empty_like(y, dtype=torch.float8_e4m3fn),
             torch.empty((num_groups, ceil_div(n, 128), k // 128), device='cuda', dtype=torch.float))
    for i in range(num_groups):
        x_fp8[0][i], x_fp8[1][i] = per_token_cast_to_fp8(x[i])
        y_fp8[0][i], y_fp8[1][i] = per_block_cast_to_fp8(y[i])

    # For non-masked input, we must merge the group and M dims.
    if not is_masked:
        # The per-token scales are computed row by row, so flattening the
        # ones already stored equals re-quantizing x.view(-1, k).
        x_fp8 = (x_fp8[0].view(-1, k), x_fp8[1].view(-1, k // 128))
        out, ref_out = out.view(-1, n), ref_out.view(-1, n)

    # Transpose earlier so that the testing will not trigger transposing kernels.
    x_fp8 = (x_fp8[0], get_col_major_tma_aligned_tensor(x_fp8[1]))
    return x_fp8, y_fp8, out, ref_out
def test_gemm() -> None:
    """Check and benchmark ``deep_gemm.gemm_fp8_fp8_bf16_nt`` over fixed shapes.

    For every ``(m, k, n)`` combination the kernel output is compared against
    the BF16 reference with ``calc_diff`` (must be below 0.001), then the
    call is timed through ``bench_kineto`` and one performance line is
    printed.

    Raises:
        AssertionError: if any shape exceeds the allowed difference.
    """
    print('Testing GEMM:')
    kn_shapes = [(7168, 2112), (1536, 24576), (512, 32768), (16384, 7168), (7168, 4096), (2048, 7168)]
    for m in (64, 128, 4096):
        for k, n in kn_shapes:
            # Correctness first.
            x_fp8, y_fp8, out, ref_out = construct(m, k, n)
            deep_gemm.gemm_fp8_fp8_bf16_nt(x_fp8, y_fp8, out)
            diff = calc_diff(out, ref_out)
            assert diff < 0.001, f'{m=}, {k=}, {n=}, {diff:.5f}'

            # noinspection PyShadowingNames
            def test_func():
                # Fresh tensors on every call so cached data in L2 cannot
                # flatter the measurement.
                lhs, rhs, dst, _ = construct(m, k, n)
                deep_gemm.gemm_fp8_fp8_bf16_nt(lhs, rhs, dst)

            t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True)
            elapsed_us = t * 1e6
            tflops = 2 * m * n * k / t / 1e12
            # Traffic model taken from the formula: one unit per input element,
            # two per output element (presumably bytes for FP8 in / BF16 out).
            gb_per_s = (m * k + k * n + m * n * 2) / 1e9 / t
            print(f' > Performance (m={m:5}, n={n:5}, k={k:5}): {elapsed_us:4.0f} us | '
                  f'throughput: {tflops:4.0f} TFLOPS, '
                  f'{gb_per_s:4.0f} GB/s')
    print()
def test_m_grouped_gemm_contiguous() -> None:
    """Check and benchmark the contiguous-layout M-grouped FP8 GEMM.

    Each configuration builds ``num_groups`` problems whose rows are merged
    into one leading dimension, tags every row with its group id, compares
    the kernel output with the reference via ``calc_diff`` (must be below
    0.001), then times the call through ``bench_kineto`` and prints one
    performance line.

    Raises:
        AssertionError: if any configuration exceeds the allowed difference.
    """
    print('Testing grouped contiguous GEMM:')

    def build_m_indices(group_count: int, rows_per_group: int) -> torch.Tensor:
        # Row r of the merged input belongs to group r // rows_per_group.
        group_ids = torch.arange(0, group_count, device='cuda', dtype=torch.int)
        return group_ids.unsqueeze(-1).expand(group_count, rows_per_group).contiguous().view(-1)

    configs = ((4, 8192, 7168, 4096), (4, 8192, 2048, 7168), (8, 4096, 7168, 4096), (8, 4096, 2048, 7168))
    for num_groups, m, k, n in configs:
        # TODO: make a stronger test
        x_fp8, y_fp8, out, ref_out = construct_grouped(num_groups, m, k, n, is_masked=False)
        m_indices = build_m_indices(num_groups, m)
        deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(x_fp8, y_fp8, out, m_indices)
        diff = calc_diff(out, ref_out)
        assert diff < 0.001, f'm={m * num_groups}, {k=}, {n=}, {diff:.5f}'

        # noinspection PyShadowingNames
        def test_func():
            # Fresh tensors on every call so cached data in L2 cannot
            # flatter the measurement.
            lhs, rhs, dst, _ = construct_grouped(num_groups, m, k, n, is_masked=False)
            row_groups = build_m_indices(num_groups, m)
            deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs, rhs, dst, row_groups)

        t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True)
        elapsed_us = t * 1e6
        tflops = 2 * num_groups * m * n * k / t / 1e12
        gb_per_s = (num_groups * (m * k + k * n + m * n * 2)) / 1e9 / t
        print(f' > Performance ({num_groups=}, m_per_group={m:4}, n={n:4}, k={k:4}): {elapsed_us:4.0f} us | '
              f'throughput: {tflops:4.0f} TFLOPS, '
              f'{gb_per_s:4.0f} GB/s')
    print()
def test_m_grouped_gemm_masked() -> None:
    """Check and benchmark the masked-layout M-grouped FP8 GEMM.

    Correctness: for every shape, ten rounds draw a random valid-row count
    per group from the candidates that fit in ``m``, run the kernel, and
    compare only the valid rows of each group against the reference with
    ``calc_diff`` (must be below 0.001).

    Performance: the call is then timed through ``bench_kineto`` with every
    group fully populated (``masked_m == m``) and one line is printed.

    Raises:
        AssertionError: if any group exceeds the allowed difference.
    """
    print('Testing grouped masked GEMM:')

    for num_groups, m in ((1, 1024), (2, 512), (4, 256)):
        for k, n in ((7168, 4096), (2048, 7168), ):
            # Test correctness
            masked_m_candidates = [c for c in (64, 128, 192, 256, 320, 384) if c <= m]
            for _ in range(10):
                x_fp8, y_fp8, out, ref_out = construct_grouped(num_groups, m, k, n, is_masked=True)
                # Draw the row counts on the host and upload them in a single
                # copy; the host list is reused below, so the slice bounds
                # need no per-group device read-back.
                masked_m_host = [random.choice(masked_m_candidates) for _ in range(num_groups)]
                masked_m = torch.tensor(masked_m_host, device='cuda', dtype=torch.int)
                expected_m = min(int(masked_m.float().mean()) + 1, m)
                deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(x_fp8, y_fp8, out, masked_m, expected_m)
                for j, valid_m in enumerate(masked_m_host):
                    # Rows past valid_m are not compared.
                    diff = calc_diff(out[j, :valid_m], ref_out[j, :valid_m])
                    # Report the plain integer rather than a tensor repr.
                    assert diff < 0.001, f'{m=}, {k=}, {n=}, {j=}, masked_m={valid_m}, {num_groups=}, {diff:.5f}'

            # noinspection PyShadowingNames
            def test_func():
                # Construct new tensors every time to avoid L2 cache acceleration
                x_fp8, y_fp8, out, ref_out = construct_grouped(num_groups, m, k, n, is_masked=True)
                masked_m = torch.ones((num_groups, ), device='cuda', dtype=torch.int) * m
                deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(x_fp8, y_fp8, out, masked_m, m)

            # Test performance with fixed shapes
            t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True)
            print(f' > Performance ({num_groups=}, m_per_group={m:4}, n={n:4}, k={k:4}): {t * 1e6:4.0f} us | '
                  f'throughput: {2 * num_groups * m * n * k / t / 1e12:4.0f} TFLOPS, '
                  f'{(num_groups * (m * k + k * n + m * n * 2)) / 1e9 / t:4.0f} GB/s')
    print()
if __name__ == '__main__':
    # Script entry point: configure PyTorch, fix the seeds, report which
    # deep_gemm build is loaded, then run every test in order.

    # Permit TF32 for CUDA matmuls and cuDNN operations.
    torch.backends.cudnn.allow_tf32 = True
    torch.backends.cuda.matmul.allow_tf32 = True

    # Seed both random number generators used by the tests.
    random.seed(0)
    torch.manual_seed(0)

    print('Library path:')
    print(f' > {deep_gemm.__path__}\n')

    for run_test in (test_gemm, test_m_grouped_gemm_contiguous, test_m_grouped_gemm_masked):
        run_test()