diff --git a/pytket/extensions/cutensornet/structured_state/classical.py b/pytket/extensions/cutensornet/structured_state/classical.py index d7bc9c9..474529a 100644 --- a/pytket/extensions/cutensornet/structured_state/classical.py +++ b/pytket/extensions/cutensornet/structured_state/classical.py @@ -71,14 +71,26 @@ def apply_classical_command( ) for var_id, reg_pos_list in op.expr.reg_posn.items() } - result = evaluate_clexpr(op.expr.expr, bitvar_val, regvar_val) + # Identify number of bits on each register + regvar_size = { + var_id: len(reg_pos_list) + for var_id, reg_pos_list in op.expr.reg_posn.items() + } + # Identify number of bits in output register + output_size = len(op.expr.output_posn) + result = evaluate_clexpr( + op.expr.expr, bitvar_val, regvar_val, regvar_size, output_size + ) # The result is an int in little-endian encoding. We update the # output register accordingly. for bit_pos in op.expr.output_posn: bits_dict[args[bit_pos]] = (result % 2) == 1 result = result >> 1 - assert result == 0 # All bits consumed + # If there has been overflow in the operations, error out. + # This can be detected if `result != 0` + if result != 0: + raise ValueError("Evaluation of the ClExpr resulted in overflow.") elif isinstance(op, ClassicalExpBox): the_exp = op.get_exp() @@ -99,7 +111,11 @@ def apply_classical_command( def evaluate_clexpr( - expr: ClExpr, bitvar_val: dict[int, int], regvar_val: dict[int, int] + expr: ClExpr, + bitvar_val: dict[int, int], + regvar_val: dict[int, int], + regvar_size: dict[int, int], + output_size: int, ) -> int: """Recursive evaluation of a ClExpr.""" @@ -113,7 +129,9 @@ def evaluate_clexpr( elif isinstance(arg, ClRegVar): value = regvar_val[arg.index] elif isinstance(arg, ClExpr): - value = evaluate_clexpr(arg, bitvar_val, regvar_val) + value = evaluate_clexpr( + arg, bitvar_val, regvar_val, regvar_size, output_size + ) else: raise Exception(f"Unrecognised argument type of ClExpr: {type(arg)}.") @@ -140,31 +158,39 @@ def evaluate_clexpr( result = int(args_val[0] < args_val[1]) elif expr.op == ClOp.BitNot: result = 1 - args_val[0] - # elif expr.op == ClOp.RegNot: - # result = int(args_val[0] == 0) + elif expr.op == ClOp.RegNot: # Bit-wise NOT (flip all bits) + n_bits = regvar_size[expr.args[0].index] # type: ignore + result = (2**n_bits - 1) ^ args_val[0] # XOR with all 1s bitstring elif expr.op in [ClOp.BitZero, ClOp.RegZero]: result = 0 - elif expr.op in [ClOp.BitOne, ClOp.RegOne]: + elif expr.op == ClOp.BitOne: result = 1 - # elif expr.op == ClOp.RegAdd: - # result = args_val[0] + args_val[1] - # elif expr.op == ClOp.RegSub: - # result = args_val[0] - args_val[1] - # elif expr.op == ClOp.RegMul: - # result = args_val[0] * args_val[1] - # elif expr.op == ClOp.RegPow: - # result = int(args_val[0] ** args_val[1]) + elif expr.op == ClOp.RegOne: # All 1s bitstring + n_bits = output_size + result = 2**n_bits - 1 + elif expr.op == ClOp.RegAdd: + result = args_val[0] + args_val[1] + elif expr.op == ClOp.RegSub: + if args_val[0] < args_val[1]: + raise NotImplementedError( + "Currently not supporting ClOp.RegSub where the outcome is negative." + ) + result = args_val[0] - args_val[1] + elif expr.op == ClOp.RegMul: + result = args_val[0] * args_val[1] + elif expr.op == ClOp.RegDiv: # floor(a / b) + result = args_val[0] // args_val[1] + elif expr.op == ClOp.RegPow: + result = int(args_val[0] ** args_val[1]) + elif expr.op == ClOp.RegLsh: + result = args_val[0] << args_val[1] elif expr.op == ClOp.RegRsh: result = args_val[0] >> args_val[1] # elif expr.op == ClOp.RegNeg: # result = -args_val[0] else: - # TODO: Currently not supporting ClOp's RegDiv since it does not return int, - # so I am unsure what the semantic is meant to be. - # TODO: I don't now what to do with RegNot, since input - # is not guaranteed to be 0 or 1. - # TODO: It is not clear what to do with overflow of ADD, etc. - # so I have decided to not support them for now. + # TODO: Not supporting RegNeg because I do not know if we have agreed how to + # specify signed ints. raise NotImplementedError( f"Evaluation of {expr.op} not supported in ClExpr ", "by pytket-cutensornet.", @@ -231,4 +257,5 @@ def evaluate_logic_exp(exp: ExtendedLogicExp, bits_dict: dict[Bit, bool]) -> int def from_little_endian(bitstring: list[bool]) -> int: """Obtain the integer from the little-endian encoded bitstring (i.e. bitstring [False, True] is interpreted as the integer 2).""" + # TODO: Assumes unisigned integer. What are the specs for signed integers? return sum(1 << i for i, b in enumerate(bitstring) if b) diff --git a/tests/test_structured_state_conditionals.py b/tests/test_structured_state_conditionals.py index 53e90f0..3fdaa17 100644 --- a/tests/test_structured_state_conditionals.py +++ b/tests/test_structured_state_conditionals.py @@ -10,6 +10,9 @@ Bit, if_not_bit, reg_eq, + WiredClExpr, + ClExpr, + ClOp, ) from pytket.circuit.logic_exp import BitWiseOp, create_bit_logic_exp from pytket.circuit.clexpr import wired_clexpr_from_logic_exp @@ -36,11 +39,11 @@ def test_circuit_with_clexpr_i() -> None: c = circ.add_c_register("c", 5) d = circ.add_c_register("d", 5) circ.H(0) - wexpr, args = wired_clexpr_from_logic_exp(a | b, c) # type: ignore + wexpr, args = wired_clexpr_from_logic_exp(a | b, c.to_list()) circ.add_clexpr(wexpr, args) - wexpr, args = wired_clexpr_from_logic_exp(c | b, d) # type: ignore + wexpr, args = wired_clexpr_from_logic_exp(c | b, d.to_list()) circ.add_clexpr(wexpr, args) - wexpr, args = wired_clexpr_from_logic_exp(c | b, d) # type: ignore + wexpr, args = wired_clexpr_from_logic_exp(c | b, d.to_list()) circ.add_clexpr(wexpr, args, condition=a[4]) circ.H(0) circ.Measure(Qubit(0), d[4]) @@ -66,9 +69,9 @@ def test_circuit_with_classicalexpbox_i() -> None: c = circ.add_c_register("c", 5) d = circ.add_c_register("d", 5) circ.H(0) - circ.add_classicalexpbox_register(a | b, c) # type: ignore - circ.add_classicalexpbox_register(c | b, d) # type: ignore - circ.add_classicalexpbox_register(c | b, d, condition=a[4]) # type: ignore + circ.add_classicalexpbox_register(a | b, c.to_list()) + circ.add_classicalexpbox_register(c | b, d.to_list()) + circ.add_classicalexpbox_register(c | b, d.to_list(), condition=a[4]) circ.H(0) circ.Measure(Qubit(0), d[4]) circ.H(1) @@ -93,11 +96,11 @@ def test_circuit_with_clexpr_ii() -> None: c = circ.add_c_register("c", 5) d = circ.add_c_register("d", 5) circ.H(0) - wexpr, args = wired_clexpr_from_logic_exp(a | b, c) # type: ignore + wexpr, args = wired_clexpr_from_logic_exp(a | b, c.to_list()) circ.add_clexpr(wexpr, args) - wexpr, args = wired_clexpr_from_logic_exp(c | b, d) # type: ignore + wexpr, args = wired_clexpr_from_logic_exp(c | b, d.to_list()) circ.add_clexpr(wexpr, args) - wexpr, args = wired_clexpr_from_logic_exp(c | b, d) # type: ignore + wexpr, args = wired_clexpr_from_logic_exp(c | b, d.to_list()) circ.add_clexpr(wexpr, args, condition=if_not_bit(a[4])) circ.H(0) circ.Measure(Qubit(0), d[4]) @@ -123,11 +126,9 @@ def test_circuit_with_classicalexpbox_ii() -> None: c = circ.add_c_register("c", 5) d = circ.add_c_register("d", 5) circ.H(0) - circ.add_classicalexpbox_register(a | b, c) # type: ignore - circ.add_classicalexpbox_register(c | b, d) # type: ignore - circ.add_classicalexpbox_register( - c | b, d, condition=if_not_bit(a[4]) # type: ignore - ) + circ.add_classicalexpbox_register(a | b, c.to_list()) + circ.add_classicalexpbox_register(c | b, d.to_list()) + circ.add_classicalexpbox_register(c | b, d.to_list(), condition=if_not_bit(a[4])) circ.H(0) circ.Measure(Qubit(0), d[4]) circ.H(1) @@ -160,9 +161,9 @@ def test_circuit_with_clexpr_iii() -> None: big_exp = bits[4] | bits[5] ^ bits[6] | bits[7] & bits[8] circ.H(0, condition=big_exp) - wexpr, args = wired_clexpr_from_logic_exp(a + b - d, c) # type: ignore + wexpr, args = wired_clexpr_from_logic_exp(a + b - d, c.to_list()) circ.add_clexpr(wexpr, args) - wexpr, args = wired_clexpr_from_logic_exp(a * b * d * c, e) # type: ignore + wexpr, args = wired_clexpr_from_logic_exp(a * b * d * c, e.to_list()) circ.add_clexpr(wexpr, args) with CuTensorNetHandle() as libhandle: @@ -190,8 +191,8 @@ def test_circuit_with_classicalexpbox_iii() -> None: big_exp = bits[4] | bits[5] ^ bits[6] | bits[7] & bits[8] circ.H(0, condition=big_exp) - circ.add_classicalexpbox_register(a + b - d, c) # type: ignore - circ.add_classicalexpbox_register(a * b * d * c, e) # type: ignore + circ.add_classicalexpbox_register(a + b - d, c.to_list()) + circ.add_classicalexpbox_register(a * b * d * c, e.to_list()) with CuTensorNetHandle() as libhandle: cfg = Config() @@ -268,7 +269,7 @@ def test_circuit_with_conditional_gate_iv() -> None: assert state.get_fidelity() == 1.0 -def test_pytket_qir_conditional_8() -> None: +def test_pytket_basic_conditional_i() -> None: c = Circuit(4) c.H(0) c.H(1) @@ -287,7 +288,7 @@ def test_pytket_qir_conditional_8() -> None: assert state.get_fidelity() == 1.0 -def test_pytket_qir_conditional_9() -> None: +def test_pytket_basic_conditional_ii() -> None: c = Circuit(4) c.X(0) c.Y(1) @@ -306,7 +307,7 @@ def test_pytket_qir_conditional_9() -> None: assert state.get_fidelity() == 1.0 -def test_pytket_qir_conditional_10() -> None: +def test_pytket_basic_conditional_iii_classicalexpbox() -> None: box_circ = Circuit(4) box_circ.X(0) box_circ.Y(1) @@ -315,7 +316,7 @@ def test_pytket_qir_conditional_10() -> None: box_c = box_circ.add_c_register("c", 5) box_circ.H(0) - box_circ.add_classicalexpbox_register(box_c | box_c, box_c) # type: ignore + box_circ.add_classicalexpbox_register(box_c | box_c, box_c.to_list()) cbox = CircBox(box_circ) d = Circuit(4, 5) @@ -330,7 +331,7 @@ def test_pytket_qir_conditional_10() -> None: assert state.get_fidelity() == 1.0 -def test_pytket_qir_conditional_11() -> None: +def test_pytket_basic_conditional_iii_clexpr() -> None: box_circ = Circuit(4) box_circ.X(0) box_circ.Y(1) @@ -340,7 +341,7 @@ def test_pytket_qir_conditional_11() -> None: box_circ.H(0) - wexpr, args = wired_clexpr_from_logic_exp(box_c | box_c, box_c) # type: ignore + wexpr, args = wired_clexpr_from_logic_exp(box_c | box_c, box_c.to_list()) box_circ.add_clexpr(wexpr, args) cbox = CircBox(box_circ) @@ -697,3 +698,41 @@ def test_repeat_until_success_ii_classicalexpblox() -> None: assert np.isclose(abs(global_phase), 1.0) output_state *= global_phase assert np.allclose(target_state, output_state) + + +def test_clexpr_on_regs() -> None: + """Non-exhaustive test on some ClOp on registers.""" + circ = Circuit(2) + a = circ.add_c_register("a", 5) + b = circ.add_c_register("b", 5) + c = circ.add_c_register("c", 5) + d = circ.add_c_register("d", 5) + e = circ.add_c_register("e", 5) + + w_expr_regone = WiredClExpr(ClExpr(ClOp.RegOne, []), output_posn=list(range(5))) + circ.add_clexpr(w_expr_regone, a.to_list()) # a = 0b11111 = 31 + circ.add_c_setbits([True, True, False, False, False], b.to_list()) # b = 3 + circ.add_c_setbits([False, True, False, True, False], c.to_list()) # c = 10 + circ.add_clexpr(*wired_clexpr_from_logic_exp(b | c, d.to_list())) # d = 11 + circ.add_clexpr(*wired_clexpr_from_logic_exp(a - d, e.to_list())) # e = 20 + + with CuTensorNetHandle() as libhandle: + cfg = Config() + + state = simulate(libhandle, circ, SimulationAlgorithm.MPSxGate, cfg) + assert state.is_valid() + assert np.isclose(state.vdot(state), 1.0, atol=cfg._atol) + assert state.get_fidelity() == 1.0 + + # Check the bits + bits_dict = state.get_bits() + a_bitstring = list(bits_dict[bit] for bit in a) + assert all(a_bitstring) # a = 0b11111 + b_bitstring = list(bits_dict[bit] for bit in b) + assert b_bitstring == [True, True, False, False, False] # b = 0b11000 + c_bitstring = list(bits_dict[bit] for bit in c) + assert c_bitstring == [False, True, False, True, False] # c = 0b01010 + d_bitstring = list(bits_dict[bit] for bit in d) + assert d_bitstring == [True, True, False, True, False] # d = 0b11010 + e_bitstring = list(bits_dict[bit] for bit in e) + assert e_bitstring == [False, False, True, False, True] # e = 0b00101