From 13f74d83d24ad66632889ec17f49efd26388e16b Mon Sep 17 00:00:00 2001 From: PabloAndresCQ Date: Wed, 23 Oct 2024 11:08:24 +0000 Subject: [PATCH] GeneralState and GeneralBraOpKet as context managers --- .../backends/cutensornet_backend.py | 14 ++-- .../general_state/tensor_network_state.py | 40 +++++++-- tests/test_general_state.py | 82 ++++++++----------- 3 files changed, 72 insertions(+), 64 deletions(-) diff --git a/pytket/extensions/cutensornet/backends/cutensornet_backend.py b/pytket/extensions/cutensornet/backends/cutensornet_backend.py index 94ee1b96..4d8a0386 100644 --- a/pytket/extensions/cutensornet/backends/cutensornet_backend.py +++ b/pytket/extensions/cutensornet/backends/cutensornet_backend.py @@ -221,10 +221,10 @@ def process_circuits( self._check_all_circuits(circuit_list) handle_list = [] for circuit in circuit_list: - tn = GeneralState( + with GeneralState( circuit, attributes=tnconfig, scratch_fraction=scratch_fraction - ) - sv = tn.get_statevector() + ) as tn: + sv = tn.get_statevector() res_qubits = [qb for qb in sorted(circuit.qubits)] handle = ResultHandle(str(uuid4())) self._cache[handle] = {"result": BackendResult(q_bits=res_qubits, state=sv)} @@ -310,11 +310,11 @@ def process_circuits( self._check_all_circuits(circuit_list) handle_list = [] for circuit, circ_shots in zip(circuit_list, all_shots): - tn = GeneralState( - circuit, attributes=tnconfig, scratch_fraction=scratch_fraction - ) handle = ResultHandle(str(uuid4())) - self._cache[handle] = {"result": tn.sample(circ_shots, seed=seed)} + with GeneralState( + circuit, attributes=tnconfig, scratch_fraction=scratch_fraction + ) as tn: + self._cache[handle] = {"result": tn.sample(circ_shots, seed=seed)} handle_list.append(handle) return handle_list diff --git a/pytket/extensions/cutensornet/general_state/tensor_network_state.py b/pytket/extensions/cutensornet/general_state/tensor_network_state.py index 61f79375..8123b81e 100644 --- a/pytket/extensions/cutensornet/general_state/tensor_network_state.py +++ b/pytket/extensions/cutensornet/general_state/tensor_network_state.py @@ -14,7 +14,7 @@ from __future__ import annotations import logging -from typing import Union, Optional, Tuple, Dict +from typing import Union, Optional, Tuple, Dict, Any import warnings try: @@ -37,7 +37,12 @@ class GeneralState: # TODO: Write it as a context manager so that I can call free() - """Wrapper of cuTensorNet's NetworkState for exact simulation of states.""" + """Wrapper of cuTensorNet's NetworkState for exact simulation of states. + + Note: + Preferably used as ``with GeneralState(...) as state:`` so that GPU memory is + automatically released after execution. + """ def __init__( self, @@ -312,16 +317,27 @@ def destroy(self) -> None: """Destroy the tensor network and free up GPU memory. Note: - Users are required to call `destroy()` when done using a - `GeneralState` object. GPU memory deallocation is not - guaranteed otherwise. + The preferred approach is to use a context manager as in + ``with GeneralState(...) as state:``. Otherwise, the user must release + memory explicitly by calling ``destroy()``. """ self._logger.debug("Freeing memory of GeneralState") self.tn_state.free() + def __enter__(self) -> GeneralState: + return self + + def __exit__(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> None: + self.destroy() + class GeneralBraOpKet: # TODO: Write it as a context manager - """Wrapper of cuTensorNet's NetworkState for exact simulation of ````.""" + """Wrapper of cuTensorNet's NetworkState for exact simulation of ````. + + Note: + Preferably used as ``with GeneralBraOpKet(...) as braket:`` so that GPU memory + is automatically released after execution. + """ def __init__( self, @@ -568,13 +584,19 @@ def destroy(self) -> None: """Destroy the tensor network and free up GPU memory. Note: - Users are required to call `destroy()` when done using a - `GeneralState` object. GPU memory deallocation is not - guaranteed otherwise. + The preferred approach is to use a context manager as in + ``with GeneralBraOpKet(...) as braket:``. Otherwise, the user must release + memory explicitly by calling ``destroy()``. """ self._logger.debug("Freeing memory of GeneralBraOpKet") self.tn.free() + def __enter__(self) -> GeneralBraOpKet: + return self + + def __exit__(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> None: + self.destroy() + def _formatted_tensor(matrix: NDArray, n_qubits: int) -> cp.ndarray: """Convert a matrix to the tensor format accepted by NVIDIA's API.""" diff --git a/tests/test_general_state.py b/tests/test_general_state.py index b963ec8a..77ea5998 100644 --- a/tests/test_general_state.py +++ b/tests/test_general_state.py @@ -39,11 +39,7 @@ ], ) def test_basic_circs_state(circuit: Circuit) -> None: - state = GeneralState(circuit) - sv = state.get_statevector() - sv_pytket = circuit.get_statevector() - assert np.allclose(sv, sv_pytket, atol=1e-10) op = QubitPauliOperator( { @@ -51,23 +47,23 @@ def test_basic_circs_state(circuit: Circuit) -> None: } ) - # Calculate the inner product as the expectation value - # of the identity operator: = - state = GeneralState(circuit) - ovl = state.expectation_value(op) - assert ovl == pytest.approx(1.0) + with GeneralState(circuit) as state: + sv = state.get_statevector() + assert np.allclose(sv, sv_pytket, atol=1e-10) - # Check that all amplitudes agree - for i in range(len(sv)): - assert np.isclose(sv[i], state.get_amplitude(i)) + # Calculate the inner product as the expectation value + # of the identity operator: = + ovl = state.expectation_value(op) + assert ovl == pytest.approx(1.0) - # Calculate the inner product again, using GeneralBraOpKet - braket = GeneralBraOpKet(circuit, circuit) - ovl = braket.contract() - assert ovl == pytest.approx(1.0) + # Check that all amplitudes agree + for i in range(len(sv)): + assert np.isclose(sv[i], state.get_amplitude(i)) - braket.destroy() - state.destroy() + # Calculate the inner product again, using GeneralBraOpKet + with GeneralBraOpKet(circuit, circuit) as braket: + ovl = braket.contract() + assert ovl == pytest.approx(1.0) def test_sv_toffoli_box_with_implicit_swaps() -> None: @@ -92,9 +88,8 @@ def test_sv_toffoli_box_with_implicit_swaps() -> None: Transform.OptimiseCliffords().apply(ket_circ) # Convert and contract - state = GeneralState(ket_circ) - ket_net_vector = state.get_statevector() - state.destroy() + with GeneralState(ket_circ) as state: + ket_net_vector = state.get_statevector() # Compare to pytket statevector ket_pytket_vector = ket_circ.get_statevector() @@ -127,8 +122,8 @@ def to_bool_tuple(n_qubits: int, x: int) -> tuple: CnXPairwiseDecomposition().apply(ket_circ) Transform.OptimiseCliffords().apply(ket_circ) - state = GeneralState(ket_circ) - ket_net_vector = state.get_statevector() + with GeneralState(ket_circ) as state: + ket_net_vector = state.get_statevector() ket_pytket_vector = ket_circ.get_statevector() assert np.allclose(ket_net_vector, ket_pytket_vector) @@ -141,12 +136,10 @@ def to_bool_tuple(n_qubits: int, x: int) -> tuple: } ) - state = GeneralState(ket_circ) - ovl = state.expectation_value(op) + with GeneralState(ket_circ) as state: + ovl = state.expectation_value(op) assert ovl == pytest.approx(1.0) - state.destroy() - @pytest.mark.parametrize( "circuit", @@ -204,18 +197,16 @@ def test_expectation_value(circuit: Circuit, observable: QubitPauliOperator) -> exp_val_tket = observable.state_expectation(circuit.get_statevector()) # Calculate using GeneralState - state = GeneralState(circuit) - exp_val = state.expectation_value(observable) + with GeneralState(circuit) as state: + exp_val = state.expectation_value(observable) assert np.isclose(exp_val, exp_val_tket) - state.destroy() # Calculate using GeneralBraOpKet - braket = GeneralBraOpKet(circuit, circuit) - exp_val = braket.contract(observable) + with GeneralBraOpKet(circuit, circuit) as braket: + exp_val = braket.contract(observable) assert np.isclose(exp_val, exp_val_tket) - braket.destroy() @pytest.mark.parametrize( @@ -266,8 +257,8 @@ def test_sampler(circuit: Circuit, measure_all: bool) -> None: circuit.Measure(q, Bit(i)) # Sample using our library - state = GeneralState(circuit) - results = state.sample(n_shots) + with GeneralState(circuit) as state: + results = state.sample(n_shots) # Verify distribution matches theoretical probabilities for bit_tuple, count in results.get_counts().items(): @@ -288,8 +279,6 @@ def test_sampler(circuit: Circuit, measure_all: bool) -> None: assert np.isclose(count / n_shots, prob, atol=0.01) - state.destroy() - @pytest.mark.parametrize( "circuit", @@ -321,18 +310,15 @@ def test_parameterised(circuit: Circuit, symbol_map: dict[Symbol, float]) -> Non # Calculate the inner product as the expectation value # of the identity operator: = - state = GeneralState(circuit) - ovl = state.expectation_value(op) - assert ovl == pytest.approx(1.0) + with GeneralState(circuit) as state: + ovl = state.expectation_value(op) + assert ovl == pytest.approx(1.0) - # Check that all amplitudes agree - for i in range(len(sv)): - assert np.isclose(sv[i], state.get_amplitude(i)) + # Check that all amplitudes agree + for i in range(len(sv)): + assert np.isclose(sv[i], state.get_amplitude(i)) # Calculate the inner product again, using GeneralBraOpKet - braket = GeneralBraOpKet(circuit, circuit) - ovl = braket.contract() + with GeneralBraOpKet(circuit, circuit) as braket: + ovl = braket.contract() assert ovl == pytest.approx(1.0) - - braket.destroy() - state.destroy()