From 7ad5fc2a80b0dc7a33662ac9d82a890da59618ee Mon Sep 17 00:00:00 2001 From: PabloAndresCQ Date: Wed, 23 Oct 2024 09:54:41 +0000 Subject: [PATCH] Added support for RNG seeds --- .../cutensornet/backends/cutensornet_backend.py | 12 ++++-------- .../general_state/tensor_network_state.py | 8 +++++++- tests/test_cutensornet_backend.py | 16 ++++++++++++++++ 3 files changed, 27 insertions(+), 9 deletions(-) diff --git a/pytket/extensions/cutensornet/backends/cutensornet_backend.py b/pytket/extensions/cutensornet/backends/cutensornet_backend.py index d6c4b692..94ee1b96 100644 --- a/pytket/extensions/cutensornet/backends/cutensornet_backend.py +++ b/pytket/extensions/cutensornet/backends/cutensornet_backend.py @@ -281,6 +281,8 @@ def process_circuits( n_shots: Number of shots in case of shot-based calculation. Optionally, this can be a list of shots specifying the number of shots for each circuit separately. + seed: An optional RNG seed. Different calls to ``process_circuits`` with the + same seed will generate the same list of shot outcomes. valid_check: Whether to check for circuit correctness. tnconfig: Optional. A dict of cuTensorNet ``TNConfig`` keys and their values. @@ -292,13 +294,7 @@ def process_circuits( """ scratch_fraction = float(kwargs.get("scratch_fraction", 0.8)) # type: ignore tnconfig = kwargs.get("tnconfig", dict()) # type: ignore - - if "seed" in kwargs and kwargs["seed"] is not None: - # Current CuTensorNet does not support seeds for Sampler. I created - # a feature request in their repository. - raise NotImplementedError( # TODO: Support seeds! - "The backend does not currently support user-defined seeds." - ) + seed = kwargs.get("seed", None) if n_shots is None: raise ValueError( @@ -318,7 +314,7 @@ def process_circuits( circuit, attributes=tnconfig, scratch_fraction=scratch_fraction ) handle = ResultHandle(str(uuid4())) - self._cache[handle] = {"result": tn.sample(circ_shots)} + self._cache[handle] = {"result": tn.sample(circ_shots, seed=seed)} handle_list.append(handle) return handle_list diff --git a/pytket/extensions/cutensornet/general_state/tensor_network_state.py b/pytket/extensions/cutensornet/general_state/tensor_network_state.py index e8f8135d..61f79375 100644 --- a/pytket/extensions/cutensornet/general_state/tensor_network_state.py +++ b/pytket/extensions/cutensornet/general_state/tensor_network_state.py @@ -240,10 +240,11 @@ def expectation_value( self._logger.debug("(Expectation value) contracting the TN") return complex(self.tn_state.compute_expectation(tn_operator)) - def sample( # TODO: Support seeds (and test) + def sample( self, n_shots: int, symbol_map: Optional[dict[Symbol, float]] = None, + seed: Optional[int] = None, ) -> BackendResult: """Obtains samples from the measurements at the end of the circuit. @@ -251,6 +252,8 @@ def sample( # TODO: Support seeds (and test) n_shots: The number of samples to obtain. symbol_map: A dictionary where each element of ``sef.free_symbols`` is assigned a real number. + seed: An optional RNG seed. Different calls to ``sample`` with the same + seed will generate the same list of shot outcomes. Returns: A pytket ``BackendResult`` with the data from the shots. @@ -275,9 +278,12 @@ def sample( # TODO: Support seeds (and test) measured_modes = tuple(self._qubit_idx_map[qb] for qb in qbit_list) self._logger.debug("(Sampling) contracting the TN") + if seed is not None: + seed = abs(seed) # Must be a positive integer samples = self.tn_state.compute_sampling( nshots=n_shots, modes=measured_modes, + seed=seed, ) # Convert the data in `samples` to an `OutcomeArray` using `from_readouts` diff --git a/tests/test_cutensornet_backend.py b/tests/test_cutensornet_backend.py index 25c9d31c..80816b22 100644 --- a/tests/test_cutensornet_backend.py +++ b/tests/test_cutensornet_backend.py @@ -34,6 +34,22 @@ def test_sampler_bell() -> None: assert np.isclose(res.get_counts()[(1, 1)] / n_shots, 0.5, atol=0.01) +def test_sampler_bell_seed() -> None: + n_shots = 1000 + c = Circuit(2, 2) + c.H(0) + c.CX(0, 1) + c.measure_all() + b = CuTensorNetShotsBackend() + c = b.get_compiled_circuit(c) + res1 = b.run_circuit(c, n_shots=n_shots, seed=1234) + res2 = b.run_circuit(c, n_shots=n_shots, seed=1234) + res3 = b.run_circuit(c, n_shots=n_shots, seed=4321) + print(type(res1.get_shots())) + assert np.all(res1.get_shots() == res2.get_shots()) + assert np.any(res1.get_shots() != res3.get_shots()) + + def test_config_options() -> None: c = Circuit(2, 2) c.H(0)