diff --git a/pytket/extensions/cutensornet/mps/mps.py b/pytket/extensions/cutensornet/mps/mps.py index dec6426a..b7334d72 100644 --- a/pytket/extensions/cutensornet/mps/mps.py +++ b/pytket/extensions/cutensornet/mps/mps.py @@ -450,7 +450,13 @@ def canonicalise_tensor(self, pos: int, form: DirectionMPS) -> None: # Contract R into Tnext subscripts = R_bonds + "," + Tnext_bonds + "->" + result_bonds - result = cq.contract(subscripts, R, Tnext) + result = cq.contract( + subscripts, + R, + Tnext, + options=options, + optimize={"path": [(0, 1)]}, + ) self._logger.debug(f"Contraction with {next_pos} applied.") # Update self.tensors @@ -502,38 +508,69 @@ def vdot(self, other: MPS) -> complex: self._flush() other._flush() - # Special case if only one tensor remains - if len(self) == 1: - self._logger.debug("Applying trivial vdot on single tensor MPS.") - result = cq.contract("LRp,lrp->", self.tensors[0].conj(), other.tensors[0]) - - else: - self._logger.debug("Applying vdot between two MPS.") - - # The two MPS will be contracted from left to right, storing the - # ``partial_result`` tensor. - partial_result = cq.contract( - "LRp,lrp->Rr", self.tensors[0].conj(), other.tensors[0] - ) - # Contract all tensors in the middle - for pos in range(1, len(self) - 1): - partial_result = cq.contract( - "Ll,LRp,lrp->Rr", - partial_result, - self.tensors[pos].conj(), - other.tensors[pos], - ) - # Finally, contract the last tensor - result = cq.contract( - "Ll,LRp,lrp->", # R and r are dim 1, so they are omitted; scalar result - partial_result, - self.tensors[-1].conj(), - other.tensors[-1], - ) + self._logger.debug("Applying vdot between two MPS.") + + # We convert both MPS to their interleaved representation and + # contract them using cuQuantum. + mps1 = self._get_interleaved_representation(conj=True) + mps2 = other._get_interleaved_representation(conj=False) + interleaved_rep = mps1 + mps2 + interleaved_rep.append([]) # Discards dim=1 bonds with [] + + # We define the contraction path ourselves + end_mps1 = len(self) - 1 # Rightmost tensor of mps1 in interleaved_rep + end_mps2 = len(self) + len(other) - 1 # Rightmost tensor of mps2 + contraction_path = [(end_mps1, end_mps2)] # Contract ends of mps1 and mps2 + for _ in range(len(self) - 1): + # Update the position markers + end_mps1 -= 1 # One tensor was removed from mps1 + end_mps2 -= 2 # One tensor removed from mps1 and another from mps2 + # Contract the result from last iteration with the ends of mps1 and mps2 + contraction_path.append((end_mps2, end_mps2 + 1)) # End of mps2 and result + contraction_path.append((end_mps1, end_mps2)) # End of mps1 and ^ outcome + + # Apply the contraction + result = cq.contract( + *interleaved_rep, + options={"handle": self._lib.handle, "device_id": self._lib.device_id}, + optimize={"path": contraction_path}, + ) self._logger.debug(f"Result from vdot={result}") return complex(result) + def _get_interleaved_representation( + self, conj: bool = False + ) -> list[Union[cp.ndarray, str]]: + """Returns the interleaved representation of the MPS used by cuQuantum. + + Args: + conj: If True, all tensors are conjugated and bonds IDs are prefixed + with * (except physical bonds). Defaults to False. + """ + self._logger.debug("Creating interleaved representation...") + + # Auxiliar dictionary of physical bonds to qubit IDs + qubit_id = {location: qubit for qubit, location in self.qubit_position.items()} + + interleaved_rep = [] + for i, t in enumerate(self.tensors): + # Append the tensor + if conj: + interleaved_rep.append(t.conj()) + else: + interleaved_rep.append(t) + + # Create the ID for the bonds involved + bonds = [str(i), str(i + 1), str(qubit_id[i])] + if conj: + bonds[0] = "*" + bonds[0] + bonds[1] = "*" + bonds[1] + interleaved_rep.append(bonds) + self._logger.debug(f"Bond IDs: {bonds}") + + return interleaved_rep + def sample(self) -> dict[Qubit, int]: """Returns a sample from a Z measurement applied on every qubit. @@ -602,11 +639,13 @@ def measure(self, qubits: set[Qubit]) -> dict[Qubit, int]: # Since the MPS is in canonical form, this corresponds to the probability # if we were to take all of the other tensors into account. prob = cq.contract( - "lrp,lrP,p,P->", # No open bonds remain; this is just a scalar + "lrp,p,lrP,P->", # No open bonds remain; this is just a scalar self.tensors[pos].conj(), - self.tensors[pos], zero_tensor, + self.tensors[pos], zero_tensor, + options={"handle": self._lib.handle, "device_id": self._lib.device_id}, + optimize={"path": [(0, 1), (0, 1), (0, 1)]}, ) # Throw a coin to decide measurement outcome @@ -687,6 +726,8 @@ def _postselect_qubit(self, qubit: Qubit, postselection_tensor: cp.ndarray) -> N "lrp,p->lr", self.tensors[pos], postselection_tensor, + options={"handle": self._lib.handle, "device_id": self._lib.device_id}, + optimize={"path": [(0, 1)]}, ) # Glossary of bond IDs: @@ -704,6 +745,8 @@ def _postselect_qubit(self, qubit: Qubit, postselection_tensor: cp.ndarray) -> N "sv,VsP->VvP", self.tensors[pos], self.tensors[pos - 1], + options={"handle": self._lib.handle, "device_id": self._lib.device_id}, + optimize={"path": [(0, 1)]}, ) self.canonical_form[pos - 1] = None else: # There are no tensors on the left, contract with the one on the right @@ -711,6 +754,8 @@ def _postselect_qubit(self, qubit: Qubit, postselection_tensor: cp.ndarray) -> N "vs,sVP->vVP", self.tensors[pos], self.tensors[pos + 1], + options={"handle": self._lib.handle, "device_id": self._lib.device_id}, + optimize={"path": [(0, 1)]}, ) self.canonical_form[pos + 1] = None @@ -758,7 +803,14 @@ def expectation_value(self, pauli_string: QubitPauliString) -> float: # Contract the Pauli to the MPS tensor of the corresponding qubit mps_copy.tensors[pos] = cq.contract( - "lrp,Pp->lrP", mps_copy.tensors[pos], pauli_tensor + "lrp,Pp->lrP", + mps_copy.tensors[pos], + pauli_tensor, + options={ + "handle": self._lib.handle, + "device_id": self._lib.device_id, + }, + optimize={"path": [(0, 1)]}, ) # Obtain the inner product @@ -796,8 +848,22 @@ def get_statevector(self) -> np.ndarray: output_bonds.append("p" + str(self.qubit_position[q])) interleaved_rep.append(output_bonds) + # We define the contraction path ourselves + end_mps = len(self) - 1 + contraction_path = [(end_mps - 1, end_mps)] # Contract the last two tensors + end_mps -= 2 # Two tensors removed from the MPS + for _ in range(len(self) - 2): + # Contract the result from last iteration and the last tensor in the MPS + contraction_path.append((end_mps, end_mps + 1)) + # Update the position marker + end_mps -= 1 # One tensor was removed from the MPS + # Contract - result_tensor = cq.contract(*interleaved_rep) + result_tensor = cq.contract( + *interleaved_rep, + options={"handle": self._lib.handle, "device_id": self._lib.device_id}, + optimize={"path": contraction_path}, + ) # Convert to numpy vector and flatten statevector: np.ndarray = cp.asnumpy(result_tensor).flatten() @@ -818,6 +884,9 @@ def get_amplitude(self, state: int) -> complex: The amplitude of the computational state in the MPS. """ + # Auxiliar dictionary of physical bonds to qubit IDs + qubit_id = {location: qubit for qubit, location in self.qubit_position.items()} + # Find out what the map MPS_position -> bit value is ilo_qubits = sorted(self.qubit_position.keys()) mps_pos_bitvalue = dict() @@ -827,21 +896,37 @@ def get_amplitude(self, state: int) -> complex: bitvalue = 1 if state & 2 ** (len(self) - i - 1) else 0 mps_pos_bitvalue[pos] = bitvalue - # Carry out the contraction, starting from a dummy tensor - result_tensor = cp.ones(1, dtype=self._cfg._complex_t) # rank-1, dimension 1 - + # Create the interleaved representation including all postselection tensors + interleaved_rep = self._get_interleaved_representation() for pos in range(len(self)): postselection_tensor = cp.zeros(2, dtype=self._cfg._complex_t) postselection_tensor[mps_pos_bitvalue[pos]] = 1 - # Contract postselection with qubit into the result_tensor - result_tensor = cq.contract( - "l,lrp,p->r", result_tensor, self.tensors[pos], postselection_tensor - ) + interleaved_rep.append(postselection_tensor) + interleaved_rep.append([str(qubit_id[pos])]) + # Append [] so that all dim=1 bonds are ignored in the result of contract + interleaved_rep.append([]) + + # We define the contraction path ourselves + end_mps = len(self) - 1 # Rightmost tensor of MPS in interleaved_rep + end_rep = 2 * len(self) - 1 # Last position in the representation + contraction_path = [(end_mps, end_rep)] # Contract ends + for _ in range(len(self) - 1): + # Update the position markers + end_mps -= 1 # One tensor was removed from mps + end_rep -= 2 # One tensor removed from mps and another from postselect + # Contract the result from last iteration with the ends + contraction_path.append((end_mps, end_rep + 1)) # End of mps and result + contraction_path.append((end_rep - 1, end_rep)) # End of mps1 and ^ outcome + + # Apply the contraction + result = cq.contract( + *interleaved_rep, + options={"handle": self._lib.handle, "device_id": self._lib.device_id}, + optimize={"samples": 1}, + ) - assert result_tensor.shape == (1,) - result = complex(result_tensor[0]) self._logger.debug(f"Amplitude of state {state} is {result}.") - return result + return complex(result) def get_qubits(self) -> set[Qubit]: """Returns the set of qubits that this MPS is defined on.""" diff --git a/pytket/extensions/cutensornet/mps/mps_gate.py b/pytket/extensions/cutensornet/mps/mps_gate.py index c7dc496e..93961bf2 100644 --- a/pytket/extensions/cutensornet/mps/mps_gate.py +++ b/pytket/extensions/cutensornet/mps/mps_gate.py @@ -70,6 +70,8 @@ def _apply_1q_gate(self, position: int, gate: Op) -> MPSxGate: gate_bonds + "," + T_bonds + "->" + result_bonds, gate_tensor, self.tensors[position], + options={"handle": self._lib.handle, "device_id": self._lib.device_id}, + optimize={"path": [(0, 1)]}, ) # Update ``self.tensors`` @@ -145,6 +147,8 @@ def _apply_2q_gate(self, positions: tuple[int, int], gate: Op) -> MPSxGate: gate_tensor, self.tensors[l_pos], self.tensors[r_pos], + options={"handle": self._lib.handle, "device_id": self._lib.device_id}, + optimize={"path": [(0, 1), (0, 1)]}, ) self._logger.debug(f"Intermediate tensor of size (MiB)={T.nbytes / 2**20}") @@ -223,7 +227,13 @@ def _apply_2q_gate(self, positions: tuple[int, int], gate: Op) -> MPSxGate: # Use some einsum index magic: since the virtual bond "s" appears in the # list of bonds of the output, it is not summed over. # This causes S to act as the intended diagonal matrix. - L = cq.contract("asL,s->asL", L, S) + L = cq.contract( + "asL,s->asL", + L, + S, + options={"handle": self._lib.handle, "device_id": self._lib.device_id}, + optimize={"path": [(0, 1)]}, + ) # We multiply the fidelity of the current step to the overall fidelity # to keep track of a lower bound for the fidelity. diff --git a/pytket/extensions/cutensornet/mps/mps_mpo.py b/pytket/extensions/cutensornet/mps/mps_mpo.py index 47274262..2cb4e27d 100644 --- a/pytket/extensions/cutensornet/mps/mps_mpo.py +++ b/pytket/extensions/cutensornet/mps/mps_mpo.py @@ -144,6 +144,8 @@ def _apply_1q_gate(self, position: int, gate: Op) -> MPSxMPO: "go," + last_bonds + "->" + new_bonds, gate_tensor, last_tensor, + options={"handle": self._lib.handle, "device_id": self._lib.device_id}, + optimize={"path": [(0, 1)]}, ) # Update the tensor @@ -385,7 +387,11 @@ def update_sweep_cache(pos: int, direction: DirectionMPS) -> None: interleaved_rep.append(["r", "R"] + result_bonds) # Contract and store - T = cq.contract(*interleaved_rep) + T = cq.contract( + *interleaved_rep, + options={"handle": self._lib.handle, "device_id": self._lib.device_id}, + optimize={"samples": 1}, + ) if direction == DirectionMPS.LEFT: r_cached_tensors.append(T) elif direction == DirectionMPS.RIGHT: @@ -443,10 +449,25 @@ def update_variational_tensor( interleaved_rep.append(result_bonds) # Contract and store tensor - F = cq.contract(*interleaved_rep) + F = cq.contract( + *interleaved_rep, + options={"handle": self._lib.handle, "device_id": self._lib.device_id}, + optimize={"samples": 1}, + ) # Get the fidelity - optim_fidelity = complex(cq.contract("LRP,LRP->", F.conj(), F)) + optim_fidelity = complex( + cq.contract( + "LRP,LRP->", + F.conj(), + F, + options={ + "handle": self._lib.handle, + "device_id": self._lib.device_id, + }, + optimize={"path": [(0, 1)]}, + ) + ) assert np.isclose(optim_fidelity.imag, 0.0, atol=self._cfg._atol) optim_fidelity = float(optim_fidelity.real)