Commit

Working for identical sizes
Signed-off-by: ElizaWszola <[email protected]>
ElizaWszola committed Dec 11, 2024
1 parent f4a6788 commit 1ba5104
Showing 2 changed files with 118 additions and 111 deletions.
167 changes: 84 additions & 83 deletions csrc/quantization/cutlass_w8a8/grouped_gemm_test.cu
@@ -129,6 +129,31 @@ struct ItemDeleter {
}
};

template <typename T>
cutlass::platform::unique_ptr<T, ItemDeleter<T>> make_device_ptr(
std::vector<T>& data_host) {
T* data_device;
int count = data_host.size();
cudaMalloc(&data_device, count * sizeof(T));
cudaMemcpy(data_device, data_host.data(), count * sizeof(T),
cudaMemcpyHostToDevice);
return cutlass::platform::unique_ptr<T, ItemDeleter<T>>(data_device);
}

///////////////
template <class TupType, size_t... I>
void print(const TupType& _tup, std::index_sequence<I...>) {
std::cout << "(";
(..., (std::cout << (I == 0 ? "" : ", ") << std::get<I>(_tup)));
std::cout << ")\n";
}

template <class... T>
void print(const std::tuple<T...>& _tup) {
print(_tup, std::make_index_sequence<sizeof...(T)>());
}
////////////
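The two print overloads above are a small debugging aid: the index_sequence overload walks the tuple with a C++17 fold expression, and the variadic overload builds that index sequence from the tuple's size. A hypothetical call, not present in this commit, would be:

std::tuple<int, double, const char*> t{64, 0.5, "row-major"};
print(t);  // prints: (64, 0.5, row-major)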

template <typename Gemm, typename... EpilogueArgs>
void cutlass_group_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
@@ -142,46 +167,67 @@ void cutlass_group_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
using ElementAcc = float;

int groups = problem_sizes.size(0);
std::vector<ElementAB*> a_ptrs_host(groups);
std::vector<ElementAB*> b_ptrs_host(groups);
std::vector<ElementC*> c_ptrs_host(groups);
std::vector<const ElementAB*> a_ptrs_host(groups);
std::vector<const ElementAB*> b_ptrs_host(groups);
std::vector<const ElementC*> c_ptrs_host(groups);
std::vector<ElementC*> d_ptrs_host(groups);

for (int g = 0; g < groups; ++g) {
a_ptrs_host.at(g) =
static_cast<ElementAB*>(a.data_ptr()) + a_offsets[g].item<int32_t>();
b_ptrs_host.at(g) =
static_cast<ElementAB*>(b.data_ptr()) + b_offsets[g].item<int32_t>();
c_ptrs_host.at(g) =
static_cast<ElementC*>(out.data_ptr()) + out_offsets[g].item<int32_t>();
a_ptrs_host.at(g) = static_cast<const ElementAB*>(a.data_ptr()) +
a_offsets[g].item<int32_t>();
b_ptrs_host.at(g) = static_cast<const ElementAB*>(b.data_ptr()) +
b_offsets[g].item<int32_t>();
c_ptrs_host.at(g) = static_cast<const ElementC*>(out.data_ptr()) +
out_offsets[g].item<int32_t>();
d_ptrs_host.at(g) =
static_cast<ElementC*>(out.data_ptr()) + out_offsets[g].item<int32_t>();
printf("%d %d %d\n", a_offsets[g].item<int32_t>(),
printf("off: %d %d %d\n", a_offsets[g].item<int32_t>(),
b_offsets[g].item<int32_t>(), out_offsets[g].item<int32_t>());
}

using GemmKernel = typename Gemm::GemmKernel;

using StrideA = typename GemmKernel::InternalStrideA;
using StrideB = typename GemmKernel::InternalStrideB;
using StrideC = typename GemmKernel::InternalStrideC;
// using StrideD = typename GemmKernel::InternalStrideD;
// using StrideA = typename GemmKernel::InternalStrideA;
// using StrideB = typename GemmKernel::InternalStrideB;
// using StrideC = typename GemmKernel::InternalStrideC;
// // using StrideD = typename GemmKernel::InternalStrideD;

std::vector<StrideA> a_stride_host(groups);
std::vector<StrideB> b_stride_host(groups);
std::vector<StrideC> c_stride_host(groups);
int64_t lda = a.stride(0);
int64_t ldb = b.stride(1);
int64_t ldc = out.stride(0);

for (int g = 0; g < groups; ++g) {
int32_t m = problem_sizes[g][0].item<int32_t>();
int32_t n = problem_sizes[g][1].item<int32_t>();
int32_t k = problem_sizes[g][2].item<int32_t>();
a_stride_host[g] = StrideA{k, cute::Int<1>{}, cute::Int<0>{}}; // m x k,
// row
b_stride_host[g] = StrideB{k, cute::Int<1>{}, cute::Int<0>{}}; // k x n,
// col
c_stride_host[g] = StrideC{n, cute::Int<1>{}, cute::Int<0>{}}; // m x n,
// row
}
using StrideA = Stride<int64_t, Int<1>, Int<0>>;
using StrideB = Stride<int64_t, Int<1>, Int<0>>;
using StrideC =
typename GemmKernel::InternalStrideC; // typename Gemm::StrideC;

// StrideA a_stride{lda, Int<1>{}, Int<0>{}};
// StrideB b_stride{ldb, Int<1>{}, Int<0>{}};
// StrideC c_stride{ldc, Int<1>{}, Int<0>{}};

std::vector<StrideA> a_stride_host(groups, StrideA{lda, Int<1>{}, Int<0>{}});
std::vector<StrideB> b_stride_host(groups, StrideB{ldb, Int<1>{}, Int<0>{}});
std::vector<StrideC> c_stride_host(groups, StrideC{ldc, Int<1>{}, Int<0>{}});

printf("a: ");
print(a_stride_host[0]);
printf("\nb: ");
print(b_stride_host[0]);
printf("\nc: ");
print(c_stride_host[0]);
printf("\n");

// for (int g = 0; g < groups; ++g) {
// int32_t m = problem_sizes[g][0].item<int32_t>();
// int32_t n = problem_sizes[g][1].item<int32_t>();
// int32_t k = problem_sizes[g][2].item<int32_t>();
// a_stride_host[g] = StrideA{k, cute::Int<1>{}, cute::Int<0>{}}; // m x k,
// // row
// b_stride_host[g] = StrideB{k, cute::Int<1>{}, cute::Int<0>{}}; // k x n,
// // col
// c_stride_host[g] = StrideC{n, cute::Int<1>{}, cute::Int<0>{}}; // m x n,
// // row
// }

cutlass::KernelHardwareInfo hw_info;
// Change device_id to another value if you are running on a machine with
@@ -200,16 +246,11 @@ void cutlass_group_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
int32_t n = problem_sizes[g][1].item<int32_t>();
int32_t k = problem_sizes[g][2].item<int32_t>();
problem_sizes_host.push_back({m, n, k});
printf("mnk: %d, %d, %d\n", m, n, k);
}

SingleProblemShape* problem_sizes_device;
int32_t problem_sizes_size = groups * sizeof(SingleProblemShape);
cudaMalloc(&problem_sizes_device, problem_sizes_size);
cudaMemcpy(problem_sizes_device, problem_sizes_host.data(),
groups * sizeof(SingleProblemShape), cudaMemcpyHostToDevice);
cutlass::platform::unique_ptr<SingleProblemShape,
ItemDeleter<SingleProblemShape>>
problem_sizes_ptr(problem_sizes_device);
auto problem_sizes_ptr =
make_device_ptr<SingleProblemShape>(problem_sizes_host);
ProblemShape prob_shape{groups, problem_sizes_ptr.get(),
problem_sizes_host.data()};

@@ -221,54 +262,14 @@ void cutlass_group_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
// cudaMemcpy(static_cast<ElementAB*>(a.data_ptr()), a_host_print, numel*
// sizeof(ElementAB), cudaMemcpyHostToDevice); cudaFree(a_host_print);

const ElementAB** a_ptrs_device;
cudaMalloc(&a_ptrs_device, groups * sizeof(ElementAB*));
cudaMemcpy(a_ptrs_device, a_ptrs_host.data(), groups * sizeof(ElementAB*),
cudaMemcpyHostToDevice);
cutlass::platform::unique_ptr<const ElementAB*, ItemDeleter<const ElementAB*>>
a_ptrs_ptr(a_ptrs_device);

const ElementAB** b_ptrs_device;
cudaMalloc(&b_ptrs_device, groups * sizeof(ElementAB*));
cudaMemcpy(b_ptrs_device, b_ptrs_host.data(), groups * sizeof(ElementAB*),
cudaMemcpyHostToDevice);
cutlass::platform::unique_ptr<const ElementAB*, ItemDeleter<const ElementAB*>>
b_ptrs_ptr(b_ptrs_device);

const ElementC** c_ptrs_device;
cudaMalloc(&c_ptrs_device, groups * sizeof(ElementC*));
cudaMemcpy(c_ptrs_device, c_ptrs_host.data(), groups * sizeof(ElementC*),
cudaMemcpyHostToDevice);
cutlass::platform::unique_ptr<const ElementC*, ItemDeleter<const ElementC*>>
c_ptrs_ptr(c_ptrs_device);

ElementC** d_ptrs_device;
cudaMalloc(&d_ptrs_device, groups * sizeof(ElementC*));
cudaMemcpy(d_ptrs_device, d_ptrs_host.data(), groups * sizeof(ElementC*),
cudaMemcpyHostToDevice);
cutlass::platform::unique_ptr<ElementC*, ItemDeleter<ElementC*>> d_ptrs_ptr(
d_ptrs_device);

StrideA* a_stride_device;
cudaMalloc(&a_stride_device, groups * sizeof(StrideA));
cudaMemcpy(a_stride_device, a_stride_host.data(), groups * sizeof(StrideA),
cudaMemcpyHostToDevice);
cutlass::platform::unique_ptr<StrideA, ItemDeleter<StrideA>> a_stride_ptr(
a_stride_device);
auto a_ptrs_ptr = make_device_ptr<const ElementAB*>(a_ptrs_host);
auto b_ptrs_ptr = make_device_ptr<const ElementAB*>(b_ptrs_host);
auto c_ptrs_ptr = make_device_ptr<const ElementC*>(c_ptrs_host);
auto d_ptrs_ptr = make_device_ptr<ElementC*>(d_ptrs_host);

StrideB* b_stride_device;
cudaMalloc(&b_stride_device, groups * sizeof(StrideB));
cudaMemcpy(b_stride_device, b_stride_host.data(), groups * sizeof(StrideB),
cudaMemcpyHostToDevice);
cutlass::platform::unique_ptr<StrideB, ItemDeleter<StrideB>> b_stride_ptr(
b_stride_device);

StrideC* c_stride_device;
cudaMalloc(&c_stride_device, groups * sizeof(StrideC));
cudaMemcpy(c_stride_device, c_stride_host.data(), groups * sizeof(StrideC),
cudaMemcpyHostToDevice);
cutlass::platform::unique_ptr<StrideC, ItemDeleter<StrideC>> c_stride_ptr(
c_stride_device);
auto a_stride_ptr = make_device_ptr<StrideA>(a_stride_host);
auto b_stride_ptr = make_device_ptr<StrideB>(b_stride_host);
auto c_stride_ptr = make_device_ptr<StrideC>(c_stride_host);

typename GemmKernel::MainloopArguments mainloop_args{
a_ptrs_ptr.get(), a_stride_ptr.get(), b_ptrs_ptr.get(),
62 changes: 34 additions & 28 deletions tests/kernels/test_cutlass.py
@@ -62,6 +62,7 @@ def baseline_scaled_mm(a: torch.Tensor,
scale_b: torch.Tensor,
out_dtype: Type[torch.dtype],
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
print(a.shape, b.shape, scale_a.shape, scale_b.shape)
output = (scale_a * (scale_b * (torch.mm(
a.to(dtype=torch.float32), b.to(dtype=torch.float32))))).to(out_dtype)
if bias is not None:
@@ -458,9 +459,9 @@ def test_cutlass_support_opcheck():

# TODO fix scales
@pytest.mark.parametrize("m,n,k", [(2048, 2048, 2048)])
@pytest.mark.parametrize("num_groups", [10])
@pytest.mark.parametrize("per_act_token", [False]) # [True, False])
@pytest.mark.parametrize("per_out_ch", [True]) # [True, False])
@pytest.mark.parametrize("num_groups", [1, 4, 10])
@pytest.mark.parametrize("per_act_token", [True, False]) # [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False]) # [True, False])
@pytest.mark.parametrize("use_bias", [False]) # [True, False])
@pytest.mark.skipif(not current_platform.has_device_capability(89),
reason="FP8 is not supported on this GPU type.")
@@ -486,7 +487,7 @@ def test_cutlass_fp8_group_gemm(m: int, n: int, k: int, num_groups: int,
k = alignment * random.randint(1, 64)
for g in range(num_groups):
tot_a += m
tot_b += k
tot_b += n
tot_c += m
print(m, n, k)
offsets_a[g] = g * m * k
@@ -497,7 +498,13 @@ def test_cutlass_fp8_group_gemm(m: int, n: int, k: int, num_groups: int,
problem_sizes[g][2] = k

a = to_fp8(torch.randn((tot_a, k), device=device))
b = to_fp8(torch.randn((tot_b, n), device=device).t())

b_float = torch.randn((tot_b, k), device=device)
# for g in range(num_groups):
# b_float[g * k:(g + 1) * k] = torch.full((k, n), g + 1)
# print(b_float)

b = to_fp8(b_float.t())
c = torch.zeros((tot_c, n), device=device).to(out_dtype)
baseline = torch.zeros((tot_c, n), device=device).to(out_dtype)

@@ -511,29 +518,19 @@ def test_cutlass_fp8_group_gemm(m: int, n: int, k: int, num_groups: int,

# print(a.stride(), b.stride(), c.stride())

# m_a_scales = m if per_act_token else 1
# n_b_scales = n if per_out_ch else 1

# scale_a = (torch.randn((tot_a if per_act_token else num_groups),
# device=device,
# dtype=torch.float32))
# scale_b = (torch.randn((tot_b if per_act_token else num_groups),
# device=device,
# dtype=torch.float32))

scale_a = (torch.ones((tot_a if per_act_token else num_groups),
device=device,
dtype=torch.float32))
scale_b = (torch.ones((tot_b if per_act_token else num_groups),
device=device,
dtype=torch.float32))
scale_a = (torch.randn(((m, 1) if per_act_token else (1, 1)),
device=device,
dtype=torch.float32))
scale_b = (torch.randn(((1, n) if per_out_ch else (1, 1)),
device=device,
dtype=torch.float32))

# if use_bias:
# bias = torch.rand((n, 1), device=device, dtype=out_dtype) * 10
# else:
# bias = None

print(a)
# print(a)

# TODO strides we can get later the same way as in scaled_mm_c3x.cu
torch.ops._C.cutlass_grouped_mm(c, a, b, scale_a, scale_b, problem_sizes,
@@ -547,20 +544,29 @@ def test_cutlass_fp8_group_gemm(m: int, n: int, k: int, num_groups: int,
# # print(c[2*m:3*m])
# print(torch.max(c, dim=1))
# print(torch.max(c, dim=0))
print(c)
# print(c)

for g in range(num_groups):
print(a[g * m:(g + 1) * m].shape, b[:, g * n:(g + 1) * n].shape)
baseline[g * m:(g + 1) * m] = baseline_scaled_mm(
a[g * m:(g + 1) * m],
b.t()[g * k:(g + 1) * k],
scale_a[g * m:(g + 1) * m] if per_act_token else scale_a[g],
scale_b[g * k:(g + 1) * k] if per_act_token else scale_b[g],
out_dtype, None)
b[:, g * n:(g + 1) * n],
# scale_a[g * m:(g + 1) * m] if per_act_token else scale_a[g],
# # scale_b[:, g * n:(g + 1) * n] if per_out_ch else scale_b[:, g],
# scale_b[g],
scale_a,
scale_b,
out_dtype,
None)
print(baseline[g * m:(g + 1) * m])
print(c[g * m:(g + 1) * m])
print("*")

# torch.testing.assert_close(c, baseline, rtol=1e-2, atol=5e-2)
# baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype, None)
# print(baseline)
# print(c)

torch.testing.assert_close(c, baseline, rtol=1e-2, atol=5e-2)

# opcheck(torch.ops._C.cutlass_scaled_mm,
# (out, a, b, scale_a, scale_b, bias))
