From bd5b538b94dc1403e6f20180e144c9bc76c1e2b1 Mon Sep 17 00:00:00 2001 From: "Kevin J. Sung" Date: Fri, 19 Jul 2024 07:54:09 -0400 Subject: [PATCH] fix MPS sample handling of RNG seed --- quimb/tensor/circuit.py | 6 ++++-- tests/test_tensor/test_circuit.py | 14 ++++++++++++++ 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/quimb/tensor/circuit.py b/quimb/tensor/circuit.py index 5723a5a3..e0667b4a 100644 --- a/quimb/tensor/circuit.py +++ b/quimb/tensor/circuit.py @@ -4043,7 +4043,8 @@ def sample( seed : None, int, or generator, optional A random seed or generator to use for reproducibility. """ - for config, _ in self._psi.sample(C, seed=seed): + rng = np.random.default_rng(seed) + for config, _ in self._psi.sample(C, seed=rng): yield "".join(map(str, config)) def fidelity_estimate(self): @@ -4154,9 +4155,10 @@ def sample(self, C, seed=None): str The next sample bitstring. """ + rng = np.random.default_rng(seed) # configuring is in physical order, so need to reorder for sampling ordering = self.calc_qubit_ordering() - for config, _ in self._psi.sample(C, seed=seed): + for config, _ in self._psi.sample(C, seed=rng): yield "".join(str(config[i]) for i in ordering) @property diff --git a/tests/test_tensor/test_circuit.py b/tests/test_tensor/test_circuit.py index a8a0ca57..b740f9b9 100644 --- a/tests/test_tensor/test_circuit.py +++ b/tests/test_tensor/test_circuit.py @@ -696,6 +696,13 @@ def test_mps_sampling(self): for x in circ.sample(10): assert x in {"000010", "111101"} + def test_mps_sampling_seed(self): + N = 1 + circ = qtn.CircuitMPS(N) + circ.h(0) + samples = list(circ.sample(10, seed=1234)) + assert len(set(samples)) == 2 + def test_permmps_sampling(self): N = 6 circ = qtn.CircuitPermMPS(N) @@ -710,6 +717,13 @@ def test_permmps_sampling(self): for x in circ.sample(10): assert x in {"000010", "111101"} + def test_permmps_sampling_seed(self): + N = 1 + circ = qtn.CircuitPermMPS(N) + circ.h(0) + samples = list(circ.sample(10, seed=1234)) + assert len(set(samples)) == 2 + class TestCircuitGen: @pytest.mark.parametrize(