Skip to content

Commit

Permalink
full gemm test
Browse files Browse the repository at this point in the history
  • Loading branch information
rejunity committed Mar 11, 2024
1 parent bca64ac commit 6f99a79
Show file tree
Hide file tree
Showing 2 changed files with 151 additions and 23 deletions.
8 changes: 2 additions & 6 deletions src/1_58bit_mul.v
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ module systolic_array (

output wire [7:0] out
);
localparam SLICES = 2;
localparam SLICES = 1;
localparam W = 1 * SLICES;
localparam H = 4 * SLICES;

Expand Down Expand Up @@ -109,10 +109,6 @@ module systolic_array (
arg_left_sign <= 0;
arg_top <= 0;
end else begin
// arg_left_zero <= in_left_zero;
// arg_left_sign <= in_left_sign;
// arg_top <= in_top;

arg_left_zero[slice_counter*4 +: 4] <= in_left_zero;
arg_left_sign[slice_counter*4 +: 4] <= in_left_sign;
arg_top[slice_counter*8 +: 8] <= in_top;
Expand Down Expand Up @@ -146,5 +142,5 @@ module systolic_array (
end
endgenerate

assign out = out_queue[out_queue_index] >> 8;
assign out = out_queue[out_queue_index][7:0];// >> 8;
endmodule
166 changes: 149 additions & 17 deletions test/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@
from cocotb.triggers import ClockCycles
from utils import *

def OUT(v):
    """Convert an expected accumulator value into the form compared with the DUT output.

    Only the low 8 bits of ``v`` are kept; they are then handed to
    ``s8_to_i32`` (presumably a signed 8-bit -> int conversion; confirm in
    ``utils``).
    """
    low_byte = v & 0xFF
    return s8_to_i32(low_byte)

@cocotb.test()
async def test_1(dut):
dut._log.info("Start")
Expand All @@ -30,21 +34,21 @@ async def test_1(dut):
dut.ui_in.value = 0b01_11_01_00
dut.uio_in.value = 127

await ClockCycles(dut.clk, 5)
await ClockCycles(dut.clk, 6)
dut.ena.value = 0
await ClockCycles(dut.clk, 1)
dut.ena.value = 1

# Validate
dut._log.info("Validate")
await ClockCycles(dut.clk, 1)
assert s8_to_i32(dut.uo_out.value) == ( 1 * 127 * 6) >> 8
assert s8_to_i32(dut.uo_out.value) == OUT( 1 * 127 * 6)
await ClockCycles(dut.clk, 1)
assert s8_to_i32(dut.uo_out.value) == (-1 * 127 * 6) >> 8
assert s8_to_i32(dut.uo_out.value) == OUT(-1 * 127 * 6)
await ClockCycles(dut.clk, 1)
assert s8_to_i32(dut.uo_out.value) == ( 1 * 127 * 6) >> 8
assert s8_to_i32(dut.uo_out.value) == OUT( 1 * 127 * 6)
await ClockCycles(dut.clk, 1)
assert s8_to_i32(dut.uo_out.value) == ( 0 * 127 * 6) >> 8
assert s8_to_i32(dut.uo_out.value) == OUT( 0 * 127 * 6)

@cocotb.test()
async def test_2(dut):
Expand Down Expand Up @@ -81,7 +85,7 @@ async def test_2(dut):

for w in weights:
await ClockCycles(dut.clk, 1)
assert s8_to_i32(dut.uo_out.value) == dot(w, inputs) >> 8
assert s8_to_i32(dut.uo_out.value) == OUT(dot(w, inputs))

@cocotb.test()
async def test_3(dut):
Expand Down Expand Up @@ -118,20 +122,16 @@ async def test_3(dut):

for w in weights:
await ClockCycles(dut.clk, 1)
assert s8_to_i32(dut.uo_out.value) == dot(w, inputs) >> 8
assert s8_to_i32(dut.uo_out.value) == OUT(dot(w, inputs))

@cocotb.test()
async def test_4(dut):
random.seed(3)
N = 128
weights = random_matrix(-1, 1, (N, 4))
packed_weights = pack_weights(flatten(weights))
# inputs = [127] * N
inputs = range(N)

print (weights)
print (inputs)

dut._log.info("Start")
clock = Clock(dut.clk, 10, units="us")
cocotb.start_soon(clock.start())
Expand All @@ -144,6 +144,7 @@ async def test_4(dut):

# Compute
dut._log.info("Compute")
# for x, w in zip(inputs, packed_weights):
for x in reversed(inputs):
dut.uio_in.value = x
dut.ui_in.value = packed_weights & 255
Expand All @@ -160,12 +161,143 @@ async def test_4(dut):
# Validate
dut._log.info("Validate")

for w in transpose(weights):
print (dot(w, inputs), dot(w, inputs) >> 8, s8_to_i32(dot(w, inputs) & 255))
# for w in transpose(weights):
# print (dot(w, inputs), OUT(dot(w, inputs)), s8_to_i32(dot(w, inputs) & 255))

for w in transpose(weights):
await ClockCycles(dut.clk, 1)
print (dut.uo_out.value, int(dut.uo_out.value), s8_to_i32(dut.uo_out.value))
# print (w.shape, np.array(inputs).shape, sum(w * inputs), sum(w * inputs) >> 8)
# assert s8_to_i32(dut.uo_out.value) == (w @ np.array(inputs)) >> 8
assert s8_to_i32(dut.uo_out.value) == dot(w, inputs) >> 8
# print (dut.uo_out.value, int(dut.uo_out.value), s8_to_i32(dut.uo_out.value))
assert s8_to_i32(dut.uo_out.value) == OUT(dot(w, inputs))


@cocotb.test()
async def test_5(dut):
    """Stream one 4 x 128 weight block and a 128 x 1 input column, then check the outputs.

    The reference result comes from ``matrix_mul``; every value read from
    ``uo_out`` is compared with ``OUT`` applied to the matching reference entry.
    """
    random.seed(3)
    depth = 128
    # Weights are drawn before inputs so the seeded random sequence is consumed
    # in a fixed order.
    weights = random_matrix(-1, 1, (4, depth))
    inputs = random_matrix(-127, 127, (depth, 1))
    expected = matrix_mul(weights, inputs)
    weight_stream = pack_weights_as_u8_array(zigzag_h(weights, 4))
    input_stream = zigzag_w(inputs, 1)

    dut._log.info("Start")
    cocotb.start_soon(Clock(dut.clk, 10, units="us").start())

    # Hold reset low for four clocks, then release it.
    dut._log.info("Reset")
    dut.rst_n.value = 0
    await ClockCycles(dut.clk, 4)
    dut.rst_n.value = 1

    # Feed one (input, packed weight) pair per clock.
    dut._log.info("Compute")
    for input_byte, weight_byte in zip(input_stream, weight_stream):
        dut.uio_in.value = input_byte
        dut.ui_in.value = weight_byte
        await ClockCycles(dut.clk, 1)

    # Drop `ena` for a single clock with zeroed inputs to move the
    # accumulators to the output queue.
    dut.ena.value = 0
    dut.ui_in.value = 0
    dut.uio_in.value = 0
    await ClockCycles(dut.clk, 1)
    dut.ena.value = 1

    # One output value is read per clock, in the order of the flattened
    # reference result.
    dut._log.info("Validate")
    for reference in flatten(expected):
        await ClockCycles(dut.clk, 1)
        observed = s8_to_i32(dut.uo_out.value)
        assert observed == OUT(reference)

async def gemm(dut, weights, inputs, compute_block_width=1, compute_block_height=4):
    """Run a blocked matrix multiply on the DUT and return the values it outputs.

    ``weights`` is split into row blocks of ``compute_block_height`` rows and
    ``inputs`` into column blocks of ``compute_block_width`` columns.  For each
    (column block, row block) pair the matching slices of the zig-zag ordered
    streams are driven onto ``uio_in`` / ``ui_in`` one pair per clock, ``ena``
    is dropped for one clock, and ``compute_block_width * compute_block_height``
    values are read back from ``uo_out``.

    Returns a list of rows whose ``shape`` is asserted to be
    ``(row blocks * compute_block_height, column blocks * compute_block_width)``.
    Each entry is ``s8_to_i32`` of the raw ``uo_out`` value.
    """
    row_blocks = len(weights) // compute_block_height
    col_blocks = len(inputs[0]) // compute_block_width
    assert len(weights[0]) == len(inputs)
    depth = len(weights[0])

    weight_stream = pack_weights_as_u8_array(zigzag_h(weights, compute_block_height))
    input_stream = zigzag_w(inputs, compute_block_width)

    assert len(weight_stream) == row_blocks * depth
    assert len(input_stream) == depth * col_blocks

    outputs_per_block = compute_block_width * compute_block_height
    collected = []
    for col_block in range(col_blocks):
        for row_block in range(row_blocks):
            # Select the slice of each stream that belongs to this block pair.
            w_begin = row_block * depth
            x_begin = col_block * depth
            block_weights = weight_stream[w_begin:w_begin + depth]
            block_inputs = input_stream[x_begin:x_begin + depth]

            # Drive one (input, packed weight) pair per clock.
            for input_byte, weight_byte in zip(block_inputs, block_weights):
                dut.uio_in.value = input_byte
                dut.ui_in.value = weight_byte
                await ClockCycles(dut.clk, 1)

            # Drop `ena` for one clock with zeroed inputs to move the
            # accumulators to the output queue.
            dut.ena.value = 0
            dut.ui_in.value = 0
            dut.uio_in.value = 0
            await ClockCycles(dut.clk, 1)
            dut.ena.value = 1

            # Read the block's outputs, one per clock.
            for _ in range(outputs_per_block):
                await ClockCycles(dut.clk, 1)
                collected.append(s8_to_i32(dut.uo_out.value))

    assert len(collected) == row_blocks * col_blocks * compute_block_height * compute_block_width

    # The values were collected column block by column block; transpose them
    # into row-major order and cut the flat list into rows.
    total_rows = row_blocks * compute_block_height
    total_cols = col_blocks * compute_block_width
    flat = []
    for row in range(total_rows):
        flat.extend(collected[offset + row] for offset in range(0, len(collected), total_rows))
    matrix = [flat[start:start + total_cols] for start in range(0, len(flat), total_cols)]
    assert shape(matrix) == (total_rows, total_cols)
    return matrix

@cocotb.test()
async def test_gemm(dut):
    """Multiply a 16 x 128 weight matrix by a 128 x 3 input matrix through ``gemm``.

    The matrices are seeded-random; the DUT result must have the same shape as
    the ``matrix_mul`` reference and match it element-wise after ``OUT``.
    """
    random.seed(3)
    block_w = 1
    block_h = 4

    row_blocks = 4
    depth = 128
    col_blocks = 3
    # Weights are drawn before inputs so the seeded random sequence is consumed
    # in a fixed order.
    weights = random_matrix(-1, 1, (row_blocks * block_h, depth))
    inputs = random_matrix(-127, 127, (depth, col_blocks * block_w))
    expected = matrix_mul(weights, inputs)

    dut._log.info("Start")
    cocotb.start_soon(Clock(dut.clk, 10, units="us").start())

    # Hold reset low for four clocks, then release it.
    dut._log.info("Reset")
    dut.rst_n.value = 0
    await ClockCycles(dut.clk, 4)
    dut.rst_n.value = 1

    dut._log.info("Compute")
    dut_result = await gemm(dut, weights, inputs, block_w, block_h)

    dut._log.info("Validate")
    assert shape(expected) == shape(dut_result)
    # Dump the reference, the reference after OUT, and the DUT values for debugging.
    print("O =", expected, "shape =", shape(expected))
    print("O'=", [OUT(v) for v in flatten(expected)])
    print("R =", flatten(dut_result), "shape =", shape(dut_result))

    for reference, observed in zip(flatten(expected), flatten(dut_result)):
        assert OUT(reference) == observed

0 comments on commit 6f99a79

Please sign in to comment.