Skip to content

Commit

Permalink
Now allowing apply_unitary to receive np.ndarray as well as cp.ndarray.
Browse files Browse the repository at this point in the history
  • Loading branch information
PabloAndresCQ committed Nov 29, 2024
1 parent 5716d28 commit 1deb246
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 14 deletions.
4 changes: 0 additions & 4 deletions pytket/extensions/cutensornet/structured_state/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,10 +228,6 @@ def _apply_command(
except:
raise ValueError(f"The command {op.type} introduced is not supported.")

# Load the gate's unitary to the GPU memory
unitary = unitary.astype(dtype=self._cfg._complex_t, copy=False)
unitary = cp.asarray(unitary, dtype=self._cfg._complex_t)

if len(qubits) not in [1, 2]:
raise ValueError(
"Gates must act on only 1 or 2 qubits! "
Expand Down
15 changes: 10 additions & 5 deletions pytket/extensions/cutensornet/structured_state/mps.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from random import Random # type: ignore
import numpy as np # type: ignore
from numpy.typing import NDArray # type: ignore

try:
import cupy as cp # type: ignore
Expand Down Expand Up @@ -153,18 +154,17 @@ def is_valid(self) -> bool:

return chi_ok and phys_ok and shape_ok and ds_ok

def apply_unitary(
self, unitary: cp.ndarray, qubits: list[Qubit]
) -> StructuredState:
def apply_unitary(self, unitary: NDArray, qubits: list[Qubit]) -> StructuredState:
"""Applies the unitary to the specified qubits of the StructuredState.
Note:
It is assumed that the matrix provided by the user is unitary. If this is
not the case, the program will still run, but its behaviour is undefined.
Args:
unitary: The matrix to be applied as a CuPy ndarray. It should either be
a 2x2 matrix if acting on one qubit or a 4x4 matrix if acting on two.
unitary: The matrix to be applied as a NumPy or CuPy ndarray. It should
either be a 2x2 matrix if acting on one qubit or a 4x4 matrix if acting
on two.
qubits: The qubits the unitary acts on. Only one qubit and two qubit
unitaries are supported.
Expand All @@ -183,6 +183,11 @@ def apply_unitary(
"See the documentation of update_libhandle and CuTensorNetHandle.",
)

if not isinstance(unitary, cp.ndarray):
# Load the gate's unitary to the GPU memory
unitary = unitary.astype(dtype=self._cfg._complex_t, copy=False)
unitary = cp.asarray(unitary, dtype=self._cfg._complex_t)

self._logger.debug(f"Applying unitary {unitary} on {qubits}.")

if len(qubits) == 1:
Expand Down
15 changes: 10 additions & 5 deletions pytket/extensions/cutensornet/structured_state/ttn.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from random import Random # type: ignore
import math # type: ignore
import numpy as np # type: ignore
from numpy.typing import NDArray # type: ignore

try:
import cupy as cp # type: ignore
Expand Down Expand Up @@ -242,18 +243,17 @@ def is_valid(self) -> bool:
)
return chi_ok and phys_ok and rank_ok and shape_ok

def apply_unitary(
self, unitary: cp.ndarray, qubits: list[Qubit]
) -> StructuredState:
def apply_unitary(self, unitary: NDArray, qubits: list[Qubit]) -> StructuredState:
"""Applies the unitary to the specified qubits of the StructuredState.
Note:
It is assumed that the matrix provided by the user is unitary. If this is
not the case, the program will still run, but its behaviour is undefined.
Args:
unitary: The matrix to be applied as a CuPy ndarray. It should either be
a 2x2 matrix if acting on one qubit or a 4x4 matrix if acting on two.
unitary: The matrix to be applied as a NumPy or CuPy ndarray. It should
either be a 2x2 matrix if acting on one qubit or a 4x4 matrix if acting
on two.
qubits: The qubits the unitary acts on. Only one qubit and two qubit
unitaries are supported.
Expand All @@ -272,6 +272,11 @@ def apply_unitary(
"See the documentation of update_libhandle and CuTensorNetHandle.",
)

if not isinstance(unitary, cp.ndarray):
# Load the gate's unitary to the GPU memory
unitary = unitary.astype(dtype=self._cfg._complex_t, copy=False)
unitary = cp.asarray(unitary, dtype=self._cfg._complex_t)

self._logger.debug(f"Applying unitary {unitary} on {qubits}.")

if len(qubits) == 1:
Expand Down

0 comments on commit 1deb246

Please sign in to comment.