From f732b73122d43e925d5ef357a180e15598d65792 Mon Sep 17 00:00:00 2001 From: Johnnie Gray Date: Mon, 11 Dec 2023 14:26:22 -0800 Subject: [PATCH] kraus_op: fix multisubsystem bug (#214) --- quimb/calc.py | 42 ++++++++++++++---------------------------- tests/test_calc.py | 13 +++++++++++-- 2 files changed, 25 insertions(+), 30 deletions(-) diff --git a/quimb/calc.py b/quimb/calc.py index fb9d8eec..b08b907d 100644 --- a/quimb/calc.py +++ b/quimb/calc.py @@ -168,7 +168,7 @@ def kraus_op(rho, Ek, dims=None, where=None, check=False): if norm(SEk - eye(Ek.shape[-1]), "fro") > 1e-12: raise ValueError("Did not find ``sum(E_k.H @ Ek) == 1``.") - if int(dims is None) + int(where is None) == 1: + if (dims is None) and (where is None): raise ValueError("If `dims` is specified so should `where`.") if isinstance(where, numbers.Integral): @@ -184,34 +184,20 @@ def kraus_op(rho, Ek, dims=None, where=None, check=False): kdims = tuple(dims[i] for i in where) Ek = Ek.reshape((-1,) + kdims + kdims) - rho_inds, out, Ei_inds, Ej_inds = [], [], ["K"], ["K"] - for i in range(N): - if i in where: - xi, xj = f"i{i}k", f"j{i}k" - for inds in (rho_inds, Ei_inds): - inds.append(xi) - for inds in (rho_inds, Ej_inds): - inds.append(xj) - xi, xj = f"i{i}new", f"j{i}new" - for inds in (out, Ei_inds): - inds.append(xi) - for inds in (out, Ej_inds): - inds.append(xj) - else: - xi, xj = f"i{i}", f"j{i}" - for inds in (rho_inds, out): - inds.append(xi) - inds.append(xj) - - rho_inds = tuple(sorted(rho_inds)) - out = tuple(sorted(out)) - Ei_inds = tuple(sorted(Ei_inds)) - Ej_inds = tuple(sorted(Ej_inds)) + rho_inds = ( + *(f"i*{q}" if q in where else f"i{q}" for q in range(N)), + *(f"j*{q}" if q in where else f"j{q}" for q in range(N)), + ) + Ei_inds = ("K", *(f"i{q}" for q in where), *(f"i*{q}" for q in where)) + Ej_inds = ("K", *(f"j{q}" for q in where), *(f"j*{q}" for q in where)) + out = (*(f"i{q}" for q in range(N)), *(f"j{q}" for q in range(N))) else: - rho_inds = ("ik", "jk") - out = ("inew", "jnew") - Ei_inds = ("K", "inew", "ik") - Ej_inds = ("K", "jnew", "jk") + Ei_inds = ("K", "i", "i*") + rho_inds = ("i*", "j*") + Ej_inds = ("K", "j", "j*") + out = ("i", "j") + + print(Ei_inds, rho_inds, Ej_inds, out) sigma = array_contract( (Ek, rho, Ek.conj()), diff --git a/tests/test_calc.py b/tests/test_calc.py index 17193672..8fdae395 100644 --- a/tests/test_calc.py +++ b/tests/test_calc.py @@ -148,11 +148,20 @@ def test_multisubsystem(self): else: assert p[0] == p[1] == 'I' K = qu.rand_iso(3 * 4, 4).reshape(3, 4, 4) - KIIXK = qu.kraus_op(IIX, K, dims=dims, where=[0, 2]) + KIIXK = qu.kraus_op(IIX, K, dims=dims, where=[0, 2], check=True) dcmp = qu.pauli_decomp(KIIXK, mode='c') for p, x in dcmp.items(): if abs(x) > 1e-12: - assert (p == 'III') or p[0] != 'I' + assert (p == 'III') or p[1] == 'I' + + @pytest.mark.parametrize("subsystem", [(0, 1), (1, 2), (2, 0)]) + def test_multisubsytem_kraus_identity(self, subsystem): + n = 3 + qu.seed_rand(7) + rho = qu.rand_rho(2**n) + Ek = np.array([qu.eye(2**len(subsystem))]) + sigma = qu.kraus_op(rho, Ek, dims=[2] * n, where=[0, 1], check=True) + assert qu.fidelity(rho, sigma) == pytest.approx(1.0) class TestProjector: