diff --git a/src/main/scala/gemmini/AccumulatorScale.scala b/src/main/scala/gemmini/AccumulatorScale.scala index 2541f9b1..2bed3702 100644 --- a/src/main/scala/gemmini/AccumulatorScale.scala +++ b/src/main/scala/gemmini/AccumulatorScale.scala @@ -386,22 +386,35 @@ object AccumulatorScale { } def iexp[T <: Data](q: T, qln2: T, qln2_inv: T, qb: T, qc: T)(implicit ev: Arithmetic[T]): T = { + // import ev._ + + // val zero = q.zero + // def neg(x: T) = zero-x + + // // qln2_inv needs scale to be 1 / (2 ** 16) / S + // // qln2_inv / S / (2 ** 16) = 1 / ln2 + // // q * qln2_inv = x / S / ln2 * S * (2 ** 16) = x / ln2 * (2 ** 16) + // val neg_q_iexp = neg(q) + // val z_iexp = (neg_q_iexp * qln2_inv).asUInt.do_>>(16).asTypeOf(q) // q is non-positive + // val z_iexp_saturated = Wire(z_iexp.cloneType) + // z_iexp_saturated := Mux((5 until 16).map(z_iexp.asUInt(_)).reduce(_ | _), 32.S, z_iexp) + // val qp_iexp = q.mac(z_iexp, qln2).withWidthOf(q) + // val q_poly_iexp = qc.mac(qp_iexp + qb, qp_iexp + qb).withWidthOf(q) + // // we dont want a rounding shift + // // TODO: z overflow + // (q_poly_iexp.asUInt.do_>>(z_iexp_saturated.asUInt)).asTypeOf(q) + import ev._ val zero = q.zero + val one = q.identity def neg(x: T) = zero-x - // qln2_inv needs scale to be 1 / (2 ** 16) / S - // qln2_inv / S / (2 ** 16) = 1 / ln2 - // q * qln2_inv = x / S / ln2 * S * (2 ** 16) = x / ln2 * (2 ** 16) - val neg_q_iexp = neg(q) - val z_iexp = (neg_q_iexp * qln2_inv).asUInt.do_>>(16).asTypeOf(q) // q is non-positive - val z_iexp_saturated = Wire(z_iexp.cloneType) - z_iexp_saturated := Mux((5 until 16).map(z_iexp.asUInt(_)).reduce(_ | _), 32.S, z_iexp) - val qp_iexp = q.mac(z_iexp, qln2).withWidthOf(q) - val q_poly_iexp = qc.mac(qp_iexp + qb, qp_iexp + qb).withWidthOf(q) - // we dont want a rounding shift - // TODO: z overflow - (q_poly_iexp.asUInt.do_>>(z_iexp_saturated.asUInt)).asTypeOf(q) + val q_sign = Mux(q.zero > q, neg(one), one) + val q_abs = Mux(q.zero > q, neg(q), q) + val q_clipped = Mux(q_abs > neg(qb), neg(qb), q_abs) + val q_poly = qc.mac(q_clipped + qb, q_clipped + qb).withWidthOf(q) + val q_erf = (q_sign * q_poly).withWidthOf(q) + (q * (q_erf + qc)).withWidthOf(q) }} diff --git a/src/main/scala/gemmini/Scratchpad.scala b/src/main/scala/gemmini/Scratchpad.scala index cdd63062..9725d162 100644 --- a/src/main/scala/gemmini/Scratchpad.scala +++ b/src/main/scala/gemmini/Scratchpad.scala @@ -251,9 +251,9 @@ class Scratchpad[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T, // From acc are ordered val write_norm_q = Module(new Queue(new ScratchpadMemWriteRequest(local_addr_t, accType.getWidth, acc_scale_t_bits), spad_read_delay+2)) val write_scale_q = Module(new Queue(new ScratchpadMemWriteRequest(local_addr_t, accType.getWidth, acc_scale_t_bits), spad_read_delay+2)) - val write_issue_q = Module(new Queue(new ScratchpadMemWriteRequest(local_addr_t, accType.getWidth, acc_scale_t_bits), spad_read_delay+1, pipe=true)) - val read_issue_q = Module(new Queue(new ScratchpadMemReadRequest(local_addr_t, mvin_scale_t_bits), spad_read_delay+1, pipe=true)) // TODO can't this just be a normal queue? - + val write_issue_q = Module(new Queue(new ScratchpadMemWriteRequest(local_addr_t, accType.getWidth, acc_scale_t_bits), spad_read_delay+1, pipe=true, flow=true)) + val read_issue_q = Module(new Queue(new ScratchpadMemReadRequest(local_addr_t, mvin_scale_t_bits), spad_read_delay+1, pipe=true)) + write_dispatch_q.ready := false.B write_norm_q.io.enq.valid := false.B