From 63d0f6c02a84487f98afa5b35e920bc7ec96f411 Mon Sep 17 00:00:00 2001 From: Melf Date: Fri, 4 Oct 2024 15:04:13 +0100 Subject: [PATCH] ruff fix II --- .../backends/cutensornet_backend.py | 10 +++---- pytket/extensions/cutensornet/general.py | 4 +-- .../general_state/tensor_network_convert.py | 16 +++++------ .../general_state/tensor_network_state.py | 25 +++++++++-------- .../cutensornet/structured_state/general.py | 16 ++++++----- .../cutensornet/structured_state/mps.py | 9 +++--- .../cutensornet/structured_state/mps_gate.py | 6 +++- .../cutensornet/structured_state/mps_mpo.py | 19 +++++++------ .../cutensornet/structured_state/ttn.py | 28 +++++++++---------- .../cutensornet/structured_state/ttn_gate.py | 11 ++++---- tests/test_general_state.py | 5 +--- tests/test_structured_state_conditionals.py | 2 +- tests/test_tensor_network_convert.py | 4 +-- 13 files changed, 80 insertions(+), 75 deletions(-) diff --git a/pytket/extensions/cutensornet/backends/cutensornet_backend.py b/pytket/extensions/cutensornet/backends/cutensornet_backend.py index 302981d0..a48b55ff 100644 --- a/pytket/extensions/cutensornet/backends/cutensornet_backend.py +++ b/pytket/extensions/cutensornet/backends/cutensornet_backend.py @@ -17,7 +17,7 @@ import warnings from abc import abstractmethod from collections.abc import Sequence -from typing import List, Optional, Union +from typing import Optional, Union from uuid import uuid4 from pytket.backends import CircuitNotRunError, CircuitStatus, ResultHandle, StatusEnum @@ -69,7 +69,7 @@ def _result_id_type(self) -> _ResultIdTuple: return (str,) @property - def required_predicates(self) -> List[Predicate]: + def required_predicates(self) -> list[Predicate]: """Returns the minimum set of predicates that a circuit must satisfy. Predicates need to be satisfied before the circuit can be successfully run on @@ -145,7 +145,7 @@ def process_circuits( n_shots: Optional[Union[int, Sequence[int]]] = None, valid_check: bool = True, **kwargs: KwargTypes, - ) -> List[ResultHandle]: + ) -> list[ResultHandle]: """Submits circuits to the backend for running. The results will be stored in the backend's result cache to be retrieved by the @@ -197,7 +197,7 @@ def process_circuits( n_shots: Optional[Union[int, Sequence[int]]] = None, valid_check: bool = True, **kwargs: KwargTypes, - ) -> List[ResultHandle]: + ) -> list[ResultHandle]: """Submits circuits to the backend for running. The results will be stored in the backend's result cache to be retrieved by the @@ -274,7 +274,7 @@ def process_circuits( n_shots: Optional[Union[int, Sequence[int]]] = None, valid_check: bool = True, **kwargs: KwargTypes, - ) -> List[ResultHandle]: + ) -> list[ResultHandle]: """Submits circuits to the backend for running. The results will be stored in the backend's result cache to be retrieved by the diff --git a/pytket/extensions/cutensornet/general.py b/pytket/extensions/cutensornet/general.py index 8971539e..115f6562 100644 --- a/pytket/extensions/cutensornet/general.py +++ b/pytket/extensions/cutensornet/general.py @@ -16,7 +16,7 @@ import logging import warnings from logging import Logger -from typing import Any, Optional +from typing import Any try: import cupy as cp # type: ignore @@ -42,7 +42,7 @@ class CuTensorNetHandle: If not provided, defaults to ``cp.cuda.Device()``. """ - def __init__(self, device_id: Optional[int] = None): + def __init__(self, device_id: int | None = None): self._is_destroyed = False # Make sure CuPy uses the specified device diff --git a/pytket/extensions/cutensornet/general_state/tensor_network_convert.py b/pytket/extensions/cutensornet/general_state/tensor_network_convert.py index 0afd36c5..2480088d 100644 --- a/pytket/extensions/cutensornet/general_state/tensor_network_convert.py +++ b/pytket/extensions/cutensornet/general_state/tensor_network_convert.py @@ -24,7 +24,7 @@ import logging from collections import defaultdict from logging import Logger -from typing import Any, DefaultDict, List, Optional, Tuple, Union +from typing import Any, Optional, Union import networkx as nx # type: ignore import numpy as np @@ -88,7 +88,7 @@ def cuquantum_interleaved(self) -> list: """Returns an interleaved format of the circuit tensor network.""" return self._cuquantum_interleaved - def _get_gate_tensors(self, adj: bool = False) -> DefaultDict[Any, List[Any]]: + def _get_gate_tensors(self, adj: bool = False) -> defaultdict[Any, list[Any]]: """Computes and stores tensors for each gate type from the circuit. The unitaries are reshaped into tensors of bond dimension two prior to being @@ -174,7 +174,7 @@ def _get_gate_tensors(self, adj: bool = False) -> DefaultDict[Any, List[Any]]: self._logger.debug(f"Gate tensors: \n{gate_tensors}\n") return gate_tensors - def _assign_node_tensors(self, adj: bool = False) -> List[Any]: + def _assign_node_tensors(self, adj: bool = False) -> list[Any]: """Creates a list of tensors representing circuit gates (tensor network nodes). Args: @@ -236,7 +236,7 @@ def _assign_node_tensors(self, adj: bool = False) -> List[Any]: def _get_tn_indices( self, net: nx.MultiDiGraph, adj: bool = False - ) -> Tuple[List[Any], dict[Qubit, int]]: + ) -> tuple[list[Any], dict[Qubit, int]]: """Computes indices of the edges of the tensor network nodes (tensors). Indices are computed such that they range from high (for circuit leftmost gates) @@ -283,7 +283,7 @@ def _get_tn_indices( ] eids_sorted = sorted(eids, key=abs) qnames_graph_ordered = [qname for qname in self._graph.output_names.values()] - oids_graph_ordered = [oid for oid in self._graph.output_names.keys()] + oids_graph_ordered = [oid for oid in self._graph.output_names] eids_qubit_ordered = [ eids_sorted[qnames_graph_ordered.index(q)] for q in self._qubit_names_ilo ] # Order eid's in the same way as qnames_graph_ordered as compared to ILO @@ -363,7 +363,7 @@ def _get_tn_indices( @staticmethod def _order_edges_for_multiqubit_gate( - edge_indices: DefaultDict[Any, List[Tuple[Any, int]]], + edge_indices: defaultdict[Any, list[tuple[Any, int]]], edges: OutMultiEdgeView, edges_data: OutMultiEdgeDataView, offset: int, @@ -577,7 +577,7 @@ def __init__( self._logger = set_logger("PauliOperatorTensorNetwork", loglevel) self._pauli_tensors = [self.PAULI[pauli.name] for pauli in paulis.map.values()] self._logger.debug(f"Pauli tensors: {self._pauli_tensors}") - qubits = [q for q in paulis.map.keys()] + qubits = [q for q in paulis.map] # qubit_names = [ # "".join([q.reg_name, "".join([f"[{str(i)}]" for i in q.index])]) # for q in paulis.map.keys() @@ -655,7 +655,7 @@ def _make_interleaved(self) -> list: return tn_concatenated -def tk_to_tensor_network(tkc: Circuit) -> List[Union[NDArray, List]]: +def tk_to_tensor_network(tkc: Circuit) -> list[Union[NDArray, list]]: """Converts pytket circuit into a tensor network. Args: diff --git a/pytket/extensions/cutensornet/general_state/tensor_network_state.py b/pytket/extensions/cutensornet/general_state/tensor_network_state.py index 57f44aba..7efed94f 100644 --- a/pytket/extensions/cutensornet/general_state/tensor_network_state.py +++ b/pytket/extensions/cutensornet/general_state/tensor_network_state.py @@ -16,21 +16,24 @@ import logging import warnings -from typing import Dict, Optional, Tuple, Union +from typing import TYPE_CHECKING try: import cupy as cp # type: ignore except ImportError: warnings.warn("local settings failed to import cupy", ImportWarning) import numpy as np -from numpy.typing import NDArray from sympy import Expr # type: ignore from pytket.backends.backendresult import BackendResult from pytket.circuit import Bit, Circuit, OpType, Qubit from pytket.extensions.cutensornet.general import CuTensorNetHandle, set_logger from pytket.utils import OutcomeArray -from pytket.utils.operators import QubitPauliOperator + +if TYPE_CHECKING: + from numpy.typing import NDArray + + from pytket.utils.operators import QubitPauliOperator try: import cuquantum as cq # type: ignore @@ -115,10 +118,10 @@ def __init__( def get_statevector( self, - attributes: Optional[dict] = None, + attributes: dict | None = None, scratch_fraction: float = 0.75, on_host: bool = True, - ) -> Union[cp.ndarray, np.ndarray]: + ) -> cp.ndarray | np.ndarray: """Contracts the circuit and returns the final statevector. Args: @@ -230,7 +233,7 @@ def get_statevector( def expectation_value( self, operator: QubitPauliOperator, - attributes: Optional[dict] = None, + attributes: dict | None = None, scratch_fraction: float = 0.75, ) -> complex: """Calculates the expectation value of the given operator. @@ -275,7 +278,7 @@ def expectation_value( self._logger.debug(f" {numeric_coeff}, {pauli_string}") # Raise an error if the operator acts on qubits that are not in the circuit - if any(q not in self._circuit.qubits for q in pauli_string.map.keys()): + if any(q not in self._circuit.qubits for q in pauli_string.map): raise ValueError( f"The operator is acting on qubits {pauli_string.map.keys()}, " "but some of these are not present in the circuit, whose set of " @@ -289,9 +292,7 @@ def expectation_value( num_pauli = len(qubit_pauli_map) num_modes = (1,) * num_pauli - state_modes = tuple( - (self._qubit_idx_map[qb],) for qb in qubit_pauli_map - ) + state_modes = tuple((self._qubit_idx_map[qb],) for qb in qubit_pauli_map) gate_data = tuple(tensor.data.ptr for tensor in qubit_pauli_map.values()) cutn.network_operator_append_product( @@ -408,7 +409,7 @@ def expectation_value( def sample( self, n_shots: int, - attributes: Optional[dict] = None, + attributes: dict | None = None, scratch_fraction: float = 0.75, ) -> BackendResult: """Obtains samples from the measurements at the end of the circuit. @@ -568,7 +569,7 @@ def _formatted_tensor(matrix: NDArray, n_qubits: int) -> cp.ndarray: return cupy_matrix.reshape([2] * (2 * n_qubits), order="F") -def _remove_meas_and_implicit_swaps(circ: Circuit) -> Tuple[Circuit, Dict[Qubit, Bit]]: +def _remove_meas_and_implicit_swaps(circ: Circuit) -> tuple[Circuit, dict[Qubit, Bit]]: """Convert a pytket Circuit to an equivalent circuit with no measurements or implicit swaps. The measurements are returned as a map between qubits and bits. diff --git a/pytket/extensions/cutensornet/structured_state/general.py b/pytket/extensions/cutensornet/structured_state/general.py index 0a5ae6fb..13541e82 100644 --- a/pytket/extensions/cutensornet/structured_state/general.py +++ b/pytket/extensions/cutensornet/structured_state/general.py @@ -16,7 +16,7 @@ import logging import warnings from abc import ABC, abstractmethod -from typing import Any, Optional, Type +from typing import TYPE_CHECKING, Any import numpy as np # type: ignore @@ -28,17 +28,19 @@ OpType, Qubit, ) -from pytket.pauli import QubitPauliString try: import cupy as cp # type: ignore except ImportError: warnings.warn("local settings failed to import cupy", ImportWarning) -from pytket.extensions.cutensornet import CuTensorNetHandle from .classical import apply_classical_command, from_little_endian +if TYPE_CHECKING: + from pytket.extensions.cutensornet import CuTensorNetHandle + from pytket.pauli import QubitPauliString + # An alias for the CuPy type used for tensors try: Tensor = cp.ndarray @@ -51,10 +53,10 @@ class Config: def __init__( self, - chi: Optional[int] = None, - truncation_fidelity: Optional[float] = None, - seed: Optional[int] = None, - float_precision: Type[Any] = np.float64, + chi: int | None = None, + truncation_fidelity: float | None = None, + seed: int | None = None, + float_precision: type[Any] = np.float64, value_of_zero: float = 1e-16, leaf_size: int = 8, k: int = 4, diff --git a/pytket/extensions/cutensornet/structured_state/mps.py b/pytket/extensions/cutensornet/structured_state/mps.py index 6380a006..46b5677d 100644 --- a/pytket/extensions/cutensornet/structured_state/mps.py +++ b/pytket/extensions/cutensornet/structured_state/mps.py @@ -16,7 +16,6 @@ import warnings from enum import Enum from random import Random # type: ignore -from typing import Optional, Union import numpy as np # type: ignore @@ -69,7 +68,7 @@ def __init__( libhandle: CuTensorNetHandle, qubits: list[Qubit], config: Config, - bits: Optional[list[Bit]] = None, + bits: list[Bit] | None = None, ): """Initialise an MPS on the computational state ``|0>`` @@ -269,7 +268,7 @@ def add_qubit(self, new_qubit: Qubit, position: int, state: int = 0) -> MPS: options = {"handle": self._lib.handle, "device_id": self._lib.device_id} - if new_qubit in self.qubit_position.keys(): + if new_qubit in self.qubit_position: raise ValueError( f"Qubit {new_qubit} cannot be added, it already is in the MPS." ) @@ -501,7 +500,7 @@ def vdot(self, other: MPS) -> complex: # type: ignore def _get_interleaved_representation( self, conj: bool = False - ) -> list[Union[cp.ndarray, str]]: + ) -> list[cp.ndarray | str]: """Returns the interleaved representation of the MPS used by cuQuantum. Args: @@ -760,7 +759,7 @@ def expectation_value(self, pauli_string: QubitPauliString) -> float: """ self._flush() - for q in pauli_string.map.keys(): + for q in pauli_string.map: if q not in self.qubit_position: raise ValueError(f"Qubit {q} is not a qubit in the MPS.") diff --git a/pytket/extensions/cutensornet/structured_state/mps_gate.py b/pytket/extensions/cutensornet/structured_state/mps_gate.py index 7fada0a1..83216582 100644 --- a/pytket/extensions/cutensornet/structured_state/mps_gate.py +++ b/pytket/extensions/cutensornet/structured_state/mps_gate.py @@ -27,10 +27,14 @@ except ImportError: warnings.warn("local settings failed to import cutensornet", ImportWarning) -from pytket.circuit import Qubit + +from typing import TYPE_CHECKING from .mps import MPS, DirMPS +if TYPE_CHECKING: + from pytket.circuit import Qubit + class MPSxGate(MPS): """Implements a gate-by-gate contraction algorithm to calculate the output state diff --git a/pytket/extensions/cutensornet/structured_state/mps_mpo.py b/pytket/extensions/cutensornet/structured_state/mps_mpo.py index 7ae84496..f483e473 100644 --- a/pytket/extensions/cutensornet/structured_state/mps_mpo.py +++ b/pytket/extensions/cutensornet/structured_state/mps_mpo.py @@ -14,7 +14,7 @@ from __future__ import annotations # type: ignore import warnings -from typing import Optional, Union +from typing import TYPE_CHECKING import numpy as np # type: ignore @@ -28,16 +28,19 @@ except ImportError: warnings.warn("local settings failed to import cutensornet", ImportWarning) -from pytket.circuit import Bit, Qubit -from pytket.extensions.cutensornet import CuTensorNetHandle -from .general import Config, Tensor from .mps import ( MPS, DirMPS, ) from .mps_gate import MPSxGate +if TYPE_CHECKING: + from pytket.circuit import Bit, Qubit + from pytket.extensions.cutensornet import CuTensorNetHandle + + from .general import Config, Tensor + class MPSxMPO(MPS): """Implements a batched--gate contraction algorithm (DMRG-like) to calculate @@ -50,7 +53,7 @@ def __init__( libhandle: CuTensorNetHandle, qubits: list[Qubit], config: Config, - bits: Optional[list[Bit]] = None, + bits: list[Bit] | None = None, ): """Initialise an MPS on the computational state ``|0>``. @@ -430,7 +433,7 @@ def update_sweep_cache(pos: int, direction: DirMPS) -> None: # The MPO tensor at this position interleaved_rep.append(mpo_tensor) - mpo_bonds: list[Union[int, str]] = list(self._bond_ids[pos][i]) + mpo_bonds: list[int | str] = list(self._bond_ids[pos][i]) if i == 0: # The input bond of the first MPO tensor must connect to the # physical bond of the correspondong ``self.tensors`` tensor @@ -479,7 +482,7 @@ def update_sweep_cache(pos: int, direction: DirMPS) -> None: self._logger.debug("Completed update of the sweep cache.") def update_variational_tensor( - pos: int, left_tensor: Optional[Tensor], right_tensor: Optional[Tensor] + pos: int, left_tensor: Tensor | None, right_tensor: Tensor | None ) -> float: """Update the tensor at ``pos`` of the variational MPS using ``left_tensor`` (and ``right_tensor``) which is meant to contain the contraction of all @@ -501,7 +504,7 @@ def update_variational_tensor( # The MPO tensor at this position interleaved_rep.append(mpo_tensor) - mpo_bonds: list[Union[int, str]] = list(self._bond_ids[pos][i]) + mpo_bonds: list[int | str] = list(self._bond_ids[pos][i]) if i == 0: # The input bond of the first MPO tensor must connect to the # physical bond of the correspondong ``self.tensors`` tensor diff --git a/pytket/extensions/cutensornet/structured_state/ttn.py b/pytket/extensions/cutensornet/structured_state/ttn.py index c63f5051..1b02f715 100644 --- a/pytket/extensions/cutensornet/structured_state/ttn.py +++ b/pytket/extensions/cutensornet/structured_state/ttn.py @@ -17,9 +17,7 @@ import warnings from enum import IntEnum from random import Random # type: ignore -from typing import Optional, Union - -import numpy as np # type: ignore +from typing import TYPE_CHECKING try: import cupy as cp # type: ignore @@ -33,10 +31,14 @@ from pytket.circuit import Bit, Qubit from pytket.extensions.cutensornet.general import CuTensorNetHandle, set_logger -from pytket.pauli import QubitPauliString from .general import Config, StructuredState, Tensor +if TYPE_CHECKING: + import numpy as np + + from pytket.pauli import QubitPauliString + class DirTTN(IntEnum): """An enum to refer to relative directions within the TTN.""" @@ -66,7 +68,7 @@ class TreeNode: def __init__(self, tensor: Tensor, is_leaf: bool = False): self.tensor = tensor self.is_leaf = is_leaf - self.canonical_form: Optional[DirTTN] = None + self.canonical_form: DirTTN | None = None def copy(self) -> TreeNode: new_node = TreeNode( @@ -96,7 +98,7 @@ def __init__( libhandle: CuTensorNetHandle, qubit_partition: dict[int, list[Qubit]], config: Config, - bits: Optional[list[Bit]] = None, + bits: list[Bit] | None = None, ): """Initialise a TTN on the computational state ``|0>``. @@ -215,7 +217,7 @@ def is_valid(self) -> bool: """ chi_ok = all( self.get_dimension(path, DirTTN.PARENT) <= self._cfg.chi - for path in self.nodes.keys() + for path in self.nodes ) phys_ok = all( self.nodes[path].tensor.shape[bond] == 2 @@ -227,7 +229,7 @@ def is_valid(self) -> bool: shape_ok = all( self.get_dimension(path, DirTTN.PARENT) == self.get_dimension(path[:-1], path[-1]) - for path in self.nodes.keys() + for path in self.nodes if len(path) != 0 ) shape_ok = shape_ok and self.get_dimension((), DirTTN.PARENT) == 1 @@ -332,9 +334,7 @@ def apply_qubit_relabelling(self, qubit_map: dict[Qubit, Qubit]) -> TTN: self._logger.debug(f"Relabelled qubits... {qubit_map}") return self - def canonicalise( - self, center: Union[RootPath, Qubit], unsafe: bool = False - ) -> Tensor: + def canonicalise(self, center: RootPath | Qubit, unsafe: bool = False) -> Tensor: """Canonicalise the TTN so that all tensors are isometries from ``center``. Args: @@ -368,7 +368,7 @@ def canonicalise( # Separate nodes to be canonicalised towards children from those towards parent towards_child = [] towards_parent = [] - for path in self.nodes.keys(): + for path in self.nodes: # Nodes towards children are closer to the root and coincide in the path if len(path) < len(target_path) and all( path[l] == target_path[l] for l in range(len(path)) @@ -759,9 +759,7 @@ def get_qubits(self) -> set[Qubit]: """Returns the set of qubits that this TTN is defined on.""" return set(self.qubit_position.keys()) - def get_interleaved_representation( - self, conj: bool = False - ) -> list[Union[Tensor, str]]: + def get_interleaved_representation(self, conj: bool = False) -> list[Tensor | str]: """Returns the interleaved representation of the TTN used by cuQuantum. Args: diff --git a/pytket/extensions/cutensornet/structured_state/ttn_gate.py b/pytket/extensions/cutensornet/structured_state/ttn_gate.py index 7f33c721..2e5452bb 100644 --- a/pytket/extensions/cutensornet/structured_state/ttn_gate.py +++ b/pytket/extensions/cutensornet/structured_state/ttn_gate.py @@ -26,10 +26,14 @@ except ImportError: warnings.warn("local settings failed to import cutensornet", ImportWarning) -from pytket.circuit import Qubit + +from typing import TYPE_CHECKING from .ttn import TTN, DirTTN, RootPath +if TYPE_CHECKING: + from pytket.circuit import Qubit + class TTNxGate(TTN): """Implements a gate-by-gate contraction algorithm to calculate the output state @@ -643,10 +647,7 @@ def _contract_decomp_bond_tensor_into_ttn( # Contract V to the parent node of the bond direction = bond_address[-1] - if direction == DirTTN.LEFT: - indices = "lrp,sl->srp" - else: - indices = "lrp,sr->lsp" + indices = "lrp,sl->srp" if direction == DirTTN.LEFT else "lrp,sr->lsp" self.nodes[bond_address[:-1]].tensor = cq.contract( indices, self.nodes[bond_address[:-1]].tensor, diff --git a/tests/test_general_state.py b/tests/test_general_state.py index 74c4108e..a0b7deaf 100644 --- a/tests/test_general_state.py +++ b/tests/test_general_state.py @@ -244,10 +244,7 @@ def test_sampler(circuit: Circuit, measure_all: bool) -> None: sv_pytket = circuit.get_statevector() # Add measurements to qubits - if measure_all: - num_measured = circuit.n_qubits - else: - num_measured = circuit.n_qubits // 2 + num_measured = circuit.n_qubits if measure_all else circuit.n_qubits // 2 for i, q in enumerate(circuit.qubits): if i < num_measured: # Skip the least significant qubits diff --git a/tests/test_structured_state_conditionals.py b/tests/test_structured_state_conditionals.py index 369db6d0..8e813d99 100644 --- a/tests/test_structured_state_conditionals.py +++ b/tests/test_structured_state_conditionals.py @@ -333,7 +333,7 @@ def test_correctness_copy_bits() -> None: cfg = Config() state = simulate(libhandle, circ, SimulationAlgorithm.MPSxGate, cfg) # Check that the copied register has the correct values - assert state.get_bits()[copied[0]] == False and state.get_bits()[copied[1]] == True + assert state.get_bits()[copied[0]] is False and state.get_bits()[copied[1]] is True def test_correctness_teleportation_bit() -> None: diff --git a/tests/test_tensor_network_convert.py b/tests/test_tensor_network_convert.py index fc0b0e12..8c246bfb 100644 --- a/tests/test_tensor_network_convert.py +++ b/tests/test_tensor_network_convert.py @@ -1,7 +1,7 @@ import cmath import random import warnings -from typing import List, Union +from typing import Union import numpy as np import pytest @@ -26,7 +26,7 @@ from pytket.utils.operators import QubitPauliOperator -def state_contract(tn: List[Union[NDArray, List]]) -> NDArray: +def state_contract(tn: list[Union[NDArray, list]]) -> NDArray: """Calls cuQuantum contract function to contract an input state tensor network.""" state_tn = tn.copy() state: NDArray = cq.contract(*state_tn).flatten()