Skip to content

Commit

Permalink
update loop tracking
Browse files Browse the repository at this point in the history
  • Loading branch information
richardyrh committed Jan 28, 2025
1 parent 40498de commit 041342d
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 3 deletions.
8 changes: 7 additions & 1 deletion src/main/scala/gemmini/Controller.scala
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ class GemminiModule[T <: Data: Arithmetic, U <: Data, V <: Data]
new ComputeRs(mvin_rows_bits, mvin_cols_bits, local_addr_t), new ComputeRs(mvin_rows_bits, mvin_cols_bits, local_addr_t),
has_training_convs, has_max_pool, has_first_layer_optimizations, has_dw_convs) }

val (loop_cmd, loop_matmul_unroller_busy) = withClock (gated_clock) { LoopMatmul(conv_cmd, reservation_station.io.matmul_ld_completed, reservation_station.io.matmul_st_completed, reservation_station.io.matmul_ex_completed,
val (loop_cmd, loop_matmul_unroller_busy, loop_completed) = withClock (gated_clock) { LoopMatmul(conv_cmd, reservation_station.io.matmul_ld_completed, reservation_station.io.matmul_st_completed, reservation_station.io.matmul_ex_completed,
meshRows*tileRows, coreMaxAddrBits, reservation_station_entries, max_lds, max_exs, max_sts, sp_banks * sp_bank_entries, acc_banks * acc_bank_entries,
inputType.getWidth, accType.getWidth, dma_maxbytes, new MvinRs2(mvin_rows_bits, mvin_cols_bits, local_addr_t),
new PreloadRs(mvin_rows_bits, mvin_cols_bits, local_addr_t), new PreloadRs(mvout_rows_bits, mvout_cols_bits, local_addr_t),
Expand All @@ -276,6 +276,12 @@ class GemminiModule[T <: Data: Arithmetic, U <: Data, V <: Data]
reservation_station.io.alloc.valid := false.B
reservation_station.io.alloc.bits := unrolled_cmd.bits

val completion_io = IO(new Bundle {
val completed = Output(loop_completed.cloneType)
})

completion_io.completed := loop_completed

/*
//-------------------------------------------------------------------------
// finish muxing control signals to rob (risc) or tiler (cisc)
Expand Down
8 changes: 6 additions & 2 deletions src/main/scala/gemmini/LoopMatmul.scala
Original file line number Diff line number Diff line change
Expand Up @@ -905,6 +905,7 @@ class LoopMatmul(block_size: Int, coreMaxAddrBits: Int, reservation_station_size
val st_completed = Input(UInt(log2Up(reservation_station_size+1).W))
val ex_completed = Input(UInt(log2Up(reservation_station_size+1).W))
val busy = Output(Bool())
val completed = Output(Vec(2, Bool()))
})

// Create states
Expand Down Expand Up @@ -936,6 +937,8 @@ class LoopMatmul(block_size: Int, coreMaxAddrBits: Int, reservation_station_size

io.busy := cmd.valid || loop_configured

io.completed := 0.U.asTypeOf(io.completed.cloneType)

// Create ld arbiters
val ldab_arb = Module(new WeightedArbiter(new RoCCCommand(), maxWeightA=255, staticWeightAEnabled=true)) // TODO magic numbers
ldab_arb.io.inA <> ldA.io.cmd
Expand Down Expand Up @@ -1282,6 +1285,7 @@ class LoopMatmul(block_size: Int, coreMaxAddrBits: Int, reservation_station_size

when (head_loop.running && head_loop.all_completed()) {
head_loop.reset()
io.completed(head_loop_id) := true.B
head_loop_id := ~head_loop_id
}

Expand All @@ -1302,15 +1306,15 @@ object LoopMatmul {
max_addr: Int, max_acc_addr: Int, input_w: Int, acc_w: Int, dma_max_bytes: Int,
mvin_rs2_t: MvinRs2, preload_rs1_t: PreloadRs, preload_rs2_t: PreloadRs,
compute_rs1_t: ComputeRs, compute_rs2_t: ComputeRs, mvout_spad_rs1_t: MvoutSpadRs1, mvout_rs2_t: MvoutRs2)
(implicit p: Parameters): (DecoupledIO[GemminiCmd], Bool) = {
(implicit p: Parameters): (DecoupledIO[GemminiCmd], Bool, Vec[Bool]) = {
val mod = Module(new LoopMatmul(block_size, coreMaxAddrBits, rob_size, max_lds, max_exs, max_sts,
max_addr, max_acc_addr, input_w, acc_w, dma_max_bytes,
mvin_rs2_t, preload_rs1_t, preload_rs2_t, compute_rs1_t, compute_rs2_t, mvout_spad_rs1_t, mvout_rs2_t))
mod.io.in <> in
mod.io.ld_completed := ld_completed
mod.io.st_completed := st_completed
mod.io.ex_completed := ex_completed
(mod.io.out, mod.io.busy)
(mod.io.out, mod.io.busy, mod.io.completed)
}

def castDramOffset(dram_offset: UInt): UInt = {
Expand Down

0 comments on commit 041342d

Please sign in to comment.