Skip to content

Commit

Permalink
add CircuitMPS and CircuitPermMPS sampling tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jcmgray committed Jul 4, 2024
1 parent 268eb88 commit cb6a4e0
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 7 deletions.
43 changes: 36 additions & 7 deletions tests/test_tensor/test_circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,13 +151,6 @@ def test_from_qsim(self):
qc = qtn.Circuit.from_qsim_str(qsim)
assert (qc.psi.H & qc.psi) ^ all == pytest.approx(1.0)

def test_from_qsim_mps_swapsplit(self):
G = rand_reg_graph(reg=3, n=18, seed=42)
qsim = graph_to_qsim(G)
qc = qtn.CircuitMPS.from_qsim_str(qsim)
assert len(qc.psi.tensors) == 18
assert (qc.psi.H & qc.psi) ^ all == pytest.approx(1.0)

def test_from_openqasm2(self):
qc = qtn.Circuit.from_openqasm2_str(example_openqasm2_qft())
assert (qc.psi.H & qc.psi) ^ all == pytest.approx(1.0)
Expand Down Expand Up @@ -641,6 +634,15 @@ def test_multi_controlled_circuit(self):
(b,) = circ.sample(1, group_size=3)
assert b[N - 2] == "0"


class TestCircuitMPS:
def test_from_qsim_mps_swapsplit(self):
G = rand_reg_graph(reg=3, n=18, seed=42)
qsim = graph_to_qsim(G)
qc = qtn.CircuitMPS.from_qsim_str(qsim)
assert len(qc.psi.tensors) == 18
assert (qc.psi.H & qc.psi) ^ all == pytest.approx(1.0)

def test_multi_controlled_mps_circuit(self):
N = 10
rng = np.random.default_rng(42)
Expand Down Expand Up @@ -681,6 +683,33 @@ def test_multi_controlled_mps_circuit(self):
assert mps.norm() == pytest.approx(1.0)
assert mps.distance_normalized(psi_lazy) < 1e-6

def test_mps_sampling(self):
N = 6
circ = qtn.CircuitMPS(N)
circ.h(3)
circ.cx(3, 2)
circ.cx(2, 1)
circ.cx(1, 0)
circ.cx(0, 5)
circ.cx(5, 4)
circ.x(4)
for x in circ.sample(10):
assert x in {"000010", "111101"}

def test_permmps_sampling(self):
N = 6
circ = qtn.CircuitPermMPS(N)
circ.h(3)
circ.cx(3, 2)
circ.cx(2, 1)
circ.cx(1, 0)
circ.cx(0, 5)
circ.cx(5, 4)
circ.x(4)
assert circ.qubits != tuple(range(N))
for x in circ.sample(10):
assert x in {"000010", "111101"}


class TestCircuitGen:
@pytest.mark.parametrize(
Expand Down
10 changes: 10 additions & 0 deletions tests/test_tensor/test_tensor_1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -757,6 +757,16 @@ def test_gate_non_local(self, where, phys_dim):
psi.gate(G, where, contract=False)
) == pytest.approx(0.0, abs=1e-6)

def test_sample_configuration(self):
psi = qtn.MPS_rand_state(10, 7)
config, omega = psi.sample_configuration()
assert len(config) == 10
assert abs(
psi.isel(
{psi.site_ind(i): xi for i, xi in enumerate(config)}
).contract()
) ** 2 == pytest.approx(omega)


class TestMatrixProductOperator:
@pytest.mark.parametrize("cyclic", [False, True])
Expand Down

0 comments on commit cb6a4e0

Please sign in to comment.