Skip to content

Commit

Permalink
GeneralState and GeneralBraOpKet as context managers
Browse files Browse the repository at this point in the history
  • Loading branch information
PabloAndresCQ committed Oct 23, 2024
1 parent 7240e55 commit 13f74d8
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 64 deletions.
14 changes: 7 additions & 7 deletions pytket/extensions/cutensornet/backends/cutensornet_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,10 +221,10 @@ def process_circuits(
self._check_all_circuits(circuit_list)
handle_list = []
for circuit in circuit_list:
tn = GeneralState(
with GeneralState(
circuit, attributes=tnconfig, scratch_fraction=scratch_fraction
)
sv = tn.get_statevector()
) as tn:
sv = tn.get_statevector()
res_qubits = [qb for qb in sorted(circuit.qubits)]
handle = ResultHandle(str(uuid4()))
self._cache[handle] = {"result": BackendResult(q_bits=res_qubits, state=sv)}
Expand Down Expand Up @@ -310,11 +310,11 @@ def process_circuits(
self._check_all_circuits(circuit_list)
handle_list = []
for circuit, circ_shots in zip(circuit_list, all_shots):
tn = GeneralState(
circuit, attributes=tnconfig, scratch_fraction=scratch_fraction
)
handle = ResultHandle(str(uuid4()))
self._cache[handle] = {"result": tn.sample(circ_shots, seed=seed)}
with GeneralState(
circuit, attributes=tnconfig, scratch_fraction=scratch_fraction
) as tn:
self._cache[handle] = {"result": tn.sample(circ_shots, seed=seed)}
handle_list.append(handle)
return handle_list

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from __future__ import annotations
import logging
from typing import Union, Optional, Tuple, Dict
from typing import Union, Optional, Tuple, Dict, Any
import warnings

try:
Expand All @@ -37,7 +37,12 @@


class GeneralState: # TODO: Write it as a context manager so that I can call free()
"""Wrapper of cuTensorNet's NetworkState for exact simulation of states."""
"""Wrapper of cuTensorNet's NetworkState for exact simulation of states.
Note:
Preferably used as ``with GeneralState(...) as state:`` so that GPU memory is
automatically released after execution.
"""

def __init__(
self,
Expand Down Expand Up @@ -312,16 +317,27 @@ def destroy(self) -> None:
"""Destroy the tensor network and free up GPU memory.
Note:
Users are required to call `destroy()` when done using a
`GeneralState` object. GPU memory deallocation is not
guaranteed otherwise.
The preferred approach is to use a context manager as in
``with GeneralState(...) as state:``. Otherwise, the user must release
memory explicitly by calling ``destroy()``.
"""
self._logger.debug("Freeing memory of GeneralState")
self.tn_state.free()

def __enter__(self) -> GeneralState:
return self

def __exit__(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> None:
self.destroy()


class GeneralBraOpKet: # TODO: Write it as a context manager
"""Wrapper of cuTensorNet's NetworkState for exact simulation of ``<bra|O|ket>``."""
"""Wrapper of cuTensorNet's NetworkState for exact simulation of ``<bra|O|ket>``.
Note:
Preferably used as ``with GeneralBraOpKet(...) as braket:`` so that GPU memory
is automatically released after execution.
"""

def __init__(
self,
Expand Down Expand Up @@ -568,13 +584,19 @@ def destroy(self) -> None:
"""Destroy the tensor network and free up GPU memory.
Note:
Users are required to call `destroy()` when done using a
`GeneralState` object. GPU memory deallocation is not
guaranteed otherwise.
The preferred approach is to use a context manager as in
``with GeneralBraOpKet(...) as braket:``. Otherwise, the user must release
memory explicitly by calling ``destroy()``.
"""
self._logger.debug("Freeing memory of GeneralBraOpKet")
self.tn.free()

def __enter__(self) -> GeneralBraOpKet:
return self

def __exit__(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> None:
self.destroy()


def _formatted_tensor(matrix: NDArray, n_qubits: int) -> cp.ndarray:
"""Convert a matrix to the tensor format accepted by NVIDIA's API."""
Expand Down
82 changes: 34 additions & 48 deletions tests/test_general_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,35 +39,31 @@
],
)
def test_basic_circs_state(circuit: Circuit) -> None:
state = GeneralState(circuit)
sv = state.get_statevector()

sv_pytket = circuit.get_statevector()
assert np.allclose(sv, sv_pytket, atol=1e-10)

op = QubitPauliOperator(
{
QubitPauliString({q: Pauli.I for q in circuit.qubits}): 1.0,
}
)

# Calculate the inner product as the expectation value
# of the identity operator: <psi|psi> = <psi|I|psi>
state = GeneralState(circuit)
ovl = state.expectation_value(op)
assert ovl == pytest.approx(1.0)
with GeneralState(circuit) as state:
sv = state.get_statevector()
assert np.allclose(sv, sv_pytket, atol=1e-10)

# Check that all amplitudes agree
for i in range(len(sv)):
assert np.isclose(sv[i], state.get_amplitude(i))
# Calculate the inner product as the expectation value
# of the identity operator: <psi|psi> = <psi|I|psi>
ovl = state.expectation_value(op)
assert ovl == pytest.approx(1.0)

# Calculate the inner product again, using GeneralBraOpKet
braket = GeneralBraOpKet(circuit, circuit)
ovl = braket.contract()
assert ovl == pytest.approx(1.0)
# Check that all amplitudes agree
for i in range(len(sv)):
assert np.isclose(sv[i], state.get_amplitude(i))

braket.destroy()
state.destroy()
# Calculate the inner product again, using GeneralBraOpKet
with GeneralBraOpKet(circuit, circuit) as braket:
ovl = braket.contract()
assert ovl == pytest.approx(1.0)


def test_sv_toffoli_box_with_implicit_swaps() -> None:
Expand All @@ -92,9 +88,8 @@ def test_sv_toffoli_box_with_implicit_swaps() -> None:
Transform.OptimiseCliffords().apply(ket_circ)

# Convert and contract
state = GeneralState(ket_circ)
ket_net_vector = state.get_statevector()
state.destroy()
with GeneralState(ket_circ) as state:
ket_net_vector = state.get_statevector()

# Compare to pytket statevector
ket_pytket_vector = ket_circ.get_statevector()
Expand Down Expand Up @@ -127,8 +122,8 @@ def to_bool_tuple(n_qubits: int, x: int) -> tuple:
CnXPairwiseDecomposition().apply(ket_circ)
Transform.OptimiseCliffords().apply(ket_circ)

state = GeneralState(ket_circ)
ket_net_vector = state.get_statevector()
with GeneralState(ket_circ) as state:
ket_net_vector = state.get_statevector()

ket_pytket_vector = ket_circ.get_statevector()
assert np.allclose(ket_net_vector, ket_pytket_vector)
Expand All @@ -141,12 +136,10 @@ def to_bool_tuple(n_qubits: int, x: int) -> tuple:
}
)

state = GeneralState(ket_circ)
ovl = state.expectation_value(op)
with GeneralState(ket_circ) as state:
ovl = state.expectation_value(op)
assert ovl == pytest.approx(1.0)

state.destroy()


@pytest.mark.parametrize(
"circuit",
Expand Down Expand Up @@ -204,18 +197,16 @@ def test_expectation_value(circuit: Circuit, observable: QubitPauliOperator) ->
exp_val_tket = observable.state_expectation(circuit.get_statevector())

# Calculate using GeneralState
state = GeneralState(circuit)
exp_val = state.expectation_value(observable)
with GeneralState(circuit) as state:
exp_val = state.expectation_value(observable)

assert np.isclose(exp_val, exp_val_tket)
state.destroy()

# Calculate using GeneralBraOpKet
braket = GeneralBraOpKet(circuit, circuit)
exp_val = braket.contract(observable)
with GeneralBraOpKet(circuit, circuit) as braket:
exp_val = braket.contract(observable)

assert np.isclose(exp_val, exp_val_tket)
braket.destroy()


@pytest.mark.parametrize(
Expand Down Expand Up @@ -266,8 +257,8 @@ def test_sampler(circuit: Circuit, measure_all: bool) -> None:
circuit.Measure(q, Bit(i))

# Sample using our library
state = GeneralState(circuit)
results = state.sample(n_shots)
with GeneralState(circuit) as state:
results = state.sample(n_shots)

# Verify distribution matches theoretical probabilities
for bit_tuple, count in results.get_counts().items():
Expand All @@ -288,8 +279,6 @@ def test_sampler(circuit: Circuit, measure_all: bool) -> None:

assert np.isclose(count / n_shots, prob, atol=0.01)

state.destroy()


@pytest.mark.parametrize(
"circuit",
Expand Down Expand Up @@ -321,18 +310,15 @@ def test_parameterised(circuit: Circuit, symbol_map: dict[Symbol, float]) -> Non

# Calculate the inner product as the expectation value
# of the identity operator: <psi|psi> = <psi|I|psi>
state = GeneralState(circuit)
ovl = state.expectation_value(op)
assert ovl == pytest.approx(1.0)
with GeneralState(circuit) as state:
ovl = state.expectation_value(op)
assert ovl == pytest.approx(1.0)

# Check that all amplitudes agree
for i in range(len(sv)):
assert np.isclose(sv[i], state.get_amplitude(i))
# Check that all amplitudes agree
for i in range(len(sv)):
assert np.isclose(sv[i], state.get_amplitude(i))

# Calculate the inner product again, using GeneralBraOpKet
braket = GeneralBraOpKet(circuit, circuit)
ovl = braket.contract()
with GeneralBraOpKet(circuit, circuit) as braket:
ovl = braket.contract()
assert ovl == pytest.approx(1.0)

braket.destroy()
state.destroy()

0 comments on commit 13f74d8

Please sign in to comment.