From 1deb2461814383c47792e9e5366ea2bffba45120 Mon Sep 17 00:00:00 2001 From: PabloAndresCQ Date: Fri, 29 Nov 2024 15:07:58 +0000 Subject: [PATCH] Now allowing apply_unitary to receive np.ndarray as well as cp.ndarray. --- .../cutensornet/structured_state/general.py | 4 ---- .../cutensornet/structured_state/mps.py | 15 ++++++++++----- .../cutensornet/structured_state/ttn.py | 15 ++++++++++----- 3 files changed, 20 insertions(+), 14 deletions(-) diff --git a/pytket/extensions/cutensornet/structured_state/general.py b/pytket/extensions/cutensornet/structured_state/general.py index 0cc56230..dcaed56d 100644 --- a/pytket/extensions/cutensornet/structured_state/general.py +++ b/pytket/extensions/cutensornet/structured_state/general.py @@ -228,10 +228,6 @@ def _apply_command( except: raise ValueError(f"The command {op.type} introduced is not supported.") - # Load the gate's unitary to the GPU memory - unitary = unitary.astype(dtype=self._cfg._complex_t, copy=False) - unitary = cp.asarray(unitary, dtype=self._cfg._complex_t) - if len(qubits) not in [1, 2]: raise ValueError( "Gates must act on only 1 or 2 qubits! " diff --git a/pytket/extensions/cutensornet/structured_state/mps.py b/pytket/extensions/cutensornet/structured_state/mps.py index 0c430b6f..25835e86 100644 --- a/pytket/extensions/cutensornet/structured_state/mps.py +++ b/pytket/extensions/cutensornet/structured_state/mps.py @@ -18,6 +18,7 @@ from random import Random # type: ignore import numpy as np # type: ignore +from numpy.typing import NDArray # type: ignore try: import cupy as cp # type: ignore @@ -153,9 +154,7 @@ def is_valid(self) -> bool: return chi_ok and phys_ok and shape_ok and ds_ok - def apply_unitary( - self, unitary: cp.ndarray, qubits: list[Qubit] - ) -> StructuredState: + def apply_unitary(self, unitary: NDArray, qubits: list[Qubit]) -> StructuredState: """Applies the unitary to the specified qubits of the StructuredState. Note: @@ -163,8 +162,9 @@ def apply_unitary( not the case, the program will still run, but its behaviour is undefined. Args: - unitary: The matrix to be applied as a CuPy ndarray. It should either be - a 2x2 matrix if acting on one qubit or a 4x4 matrix if acting on two. + unitary: The matrix to be applied as a NumPy or CuPy ndarray. It should + either be a 2x2 matrix if acting on one qubit or a 4x4 matrix if acting + on two. qubits: The qubits the unitary acts on. Only one qubit and two qubit unitaries are supported. @@ -183,6 +183,11 @@ def apply_unitary( "See the documentation of update_libhandle and CuTensorNetHandle.", ) + if not isinstance(unitary, cp.ndarray): + # Load the gate's unitary to the GPU memory + unitary = unitary.astype(dtype=self._cfg._complex_t, copy=False) + unitary = cp.asarray(unitary, dtype=self._cfg._complex_t) + self._logger.debug(f"Applying unitary {unitary} on {qubits}.") if len(qubits) == 1: diff --git a/pytket/extensions/cutensornet/structured_state/ttn.py b/pytket/extensions/cutensornet/structured_state/ttn.py index 3a1a412e..2144471d 100644 --- a/pytket/extensions/cutensornet/structured_state/ttn.py +++ b/pytket/extensions/cutensornet/structured_state/ttn.py @@ -19,6 +19,7 @@ from random import Random # type: ignore import math # type: ignore import numpy as np # type: ignore +from numpy.typing import NDArray # type: ignore try: import cupy as cp # type: ignore @@ -242,9 +243,7 @@ def is_valid(self) -> bool: ) return chi_ok and phys_ok and rank_ok and shape_ok - def apply_unitary( - self, unitary: cp.ndarray, qubits: list[Qubit] - ) -> StructuredState: + def apply_unitary(self, unitary: NDArray, qubits: list[Qubit]) -> StructuredState: """Applies the unitary to the specified qubits of the StructuredState. Note: @@ -252,8 +251,9 @@ def apply_unitary( not the case, the program will still run, but its behaviour is undefined. Args: - unitary: The matrix to be applied as a CuPy ndarray. It should either be - a 2x2 matrix if acting on one qubit or a 4x4 matrix if acting on two. + unitary: The matrix to be applied as a NumPy or CuPy ndarray. It should + either be a 2x2 matrix if acting on one qubit or a 4x4 matrix if acting + on two. qubits: The qubits the unitary acts on. Only one qubit and two qubit unitaries are supported. @@ -272,6 +272,11 @@ def apply_unitary( "See the documentation of update_libhandle and CuTensorNetHandle.", ) + if not isinstance(unitary, cp.ndarray): + # Load the gate's unitary to the GPU memory + unitary = unitary.astype(dtype=self._cfg._complex_t, copy=False) + unitary = cp.asarray(unitary, dtype=self._cfg._complex_t) + self._logger.debug(f"Applying unitary {unitary} on {qubits}.") if len(qubits) == 1: