From cb6a4e0a3136e90c6c872c8e8f551eac30865ede Mon Sep 17 00:00:00 2001 From: Johnnie Gray Date: Wed, 3 Jul 2024 17:28:58 -0700 Subject: [PATCH] add CircuitMPS and CircuitPermMPS sampling tests --- tests/test_tensor/test_circuit.py | 43 ++++++++++++++++++++++++----- tests/test_tensor/test_tensor_1d.py | 10 +++++++ 2 files changed, 46 insertions(+), 7 deletions(-) diff --git a/tests/test_tensor/test_circuit.py b/tests/test_tensor/test_circuit.py index 3eb98f3b..a8a0ca57 100644 --- a/tests/test_tensor/test_circuit.py +++ b/tests/test_tensor/test_circuit.py @@ -151,13 +151,6 @@ def test_from_qsim(self): qc = qtn.Circuit.from_qsim_str(qsim) assert (qc.psi.H & qc.psi) ^ all == pytest.approx(1.0) - def test_from_qsim_mps_swapsplit(self): - G = rand_reg_graph(reg=3, n=18, seed=42) - qsim = graph_to_qsim(G) - qc = qtn.CircuitMPS.from_qsim_str(qsim) - assert len(qc.psi.tensors) == 18 - assert (qc.psi.H & qc.psi) ^ all == pytest.approx(1.0) - def test_from_openqasm2(self): qc = qtn.Circuit.from_openqasm2_str(example_openqasm2_qft()) assert (qc.psi.H & qc.psi) ^ all == pytest.approx(1.0) @@ -641,6 +634,15 @@ def test_multi_controlled_circuit(self): (b,) = circ.sample(1, group_size=3) assert b[N - 2] == "0" + +class TestCircuitMPS: + def test_from_qsim_mps_swapsplit(self): + G = rand_reg_graph(reg=3, n=18, seed=42) + qsim = graph_to_qsim(G) + qc = qtn.CircuitMPS.from_qsim_str(qsim) + assert len(qc.psi.tensors) == 18 + assert (qc.psi.H & qc.psi) ^ all == pytest.approx(1.0) + def test_multi_controlled_mps_circuit(self): N = 10 rng = np.random.default_rng(42) @@ -681,6 +683,33 @@ def test_multi_controlled_mps_circuit(self): assert mps.norm() == pytest.approx(1.0) assert mps.distance_normalized(psi_lazy) < 1e-6 + def test_mps_sampling(self): + N = 6 + circ = qtn.CircuitMPS(N) + circ.h(3) + circ.cx(3, 2) + circ.cx(2, 1) + circ.cx(1, 0) + circ.cx(0, 5) + circ.cx(5, 4) + circ.x(4) + for x in circ.sample(10): + assert x in {"000010", "111101"} + + def test_permmps_sampling(self): + N = 6 + circ = qtn.CircuitPermMPS(N) + circ.h(3) + circ.cx(3, 2) + circ.cx(2, 1) + circ.cx(1, 0) + circ.cx(0, 5) + circ.cx(5, 4) + circ.x(4) + assert circ.qubits != tuple(range(N)) + for x in circ.sample(10): + assert x in {"000010", "111101"} + class TestCircuitGen: @pytest.mark.parametrize( diff --git a/tests/test_tensor/test_tensor_1d.py b/tests/test_tensor/test_tensor_1d.py index 3e96df75..296d5dc3 100644 --- a/tests/test_tensor/test_tensor_1d.py +++ b/tests/test_tensor/test_tensor_1d.py @@ -757,6 +757,16 @@ def test_gate_non_local(self, where, phys_dim): psi.gate(G, where, contract=False) ) == pytest.approx(0.0, abs=1e-6) + def test_sample_configuration(self): + psi = qtn.MPS_rand_state(10, 7) + config, omega = psi.sample_configuration() + assert len(config) == 10 + assert abs( + psi.isel( + {psi.site_ind(i): xi for i, xi in enumerate(config)} + ).contract() + ) ** 2 == pytest.approx(omega) + class TestMatrixProductOperator: @pytest.mark.parametrize("cyclic", [False, True])