Skip to content

Commit

Permalink
Improve contract performance (#36)
Browse files Browse the repository at this point in the history
* Updated the way vdot and get_amplitude are calculated. Added library handle to all calls to cq.contract.

* Now explicitly providing contraction path for simple cases.

* Now vdot, get_statevector and get_amplitude all use cq.contract with a predefined (generally optimal) path.
  • Loading branch information
PabloAndresCQ authored Oct 26, 2023
1 parent 8327268 commit d0886c0
Show file tree
Hide file tree
Showing 3 changed files with 163 additions and 47 deletions.
171 changes: 128 additions & 43 deletions pytket/extensions/cutensornet/mps/mps.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,7 +450,13 @@ def canonicalise_tensor(self, pos: int, form: DirectionMPS) -> None:

# Contract R into Tnext
subscripts = R_bonds + "," + Tnext_bonds + "->" + result_bonds
result = cq.contract(subscripts, R, Tnext)
result = cq.contract(
subscripts,
R,
Tnext,
options=options,
optimize={"path": [(0, 1)]},
)
self._logger.debug(f"Contraction with {next_pos} applied.")

# Update self.tensors
Expand Down Expand Up @@ -502,38 +508,69 @@ def vdot(self, other: MPS) -> complex:
self._flush()
other._flush()

# Special case if only one tensor remains
if len(self) == 1:
self._logger.debug("Applying trivial vdot on single tensor MPS.")
result = cq.contract("LRp,lrp->", self.tensors[0].conj(), other.tensors[0])

else:
self._logger.debug("Applying vdot between two MPS.")

# The two MPS will be contracted from left to right, storing the
# ``partial_result`` tensor.
partial_result = cq.contract(
"LRp,lrp->Rr", self.tensors[0].conj(), other.tensors[0]
)
# Contract all tensors in the middle
for pos in range(1, len(self) - 1):
partial_result = cq.contract(
"Ll,LRp,lrp->Rr",
partial_result,
self.tensors[pos].conj(),
other.tensors[pos],
)
# Finally, contract the last tensor
result = cq.contract(
"Ll,LRp,lrp->", # R and r are dim 1, so they are omitted; scalar result
partial_result,
self.tensors[-1].conj(),
other.tensors[-1],
)
self._logger.debug("Applying vdot between two MPS.")

# We convert both MPS to their interleaved representation and
# contract them using cuQuantum.
mps1 = self._get_interleaved_representation(conj=True)
mps2 = other._get_interleaved_representation(conj=False)
interleaved_rep = mps1 + mps2
interleaved_rep.append([]) # Discards dim=1 bonds with []

# We define the contraction path ourselves
end_mps1 = len(self) - 1 # Rightmost tensor of mps1 in interleaved_rep
end_mps2 = len(self) + len(other) - 1 # Rightmost tensor of mps2
contraction_path = [(end_mps1, end_mps2)] # Contract ends of mps1 and mps2
for _ in range(len(self) - 1):
# Update the position markers
end_mps1 -= 1 # One tensor was removed from mps1
end_mps2 -= 2 # One tensor removed from mps1 and another from mps2
# Contract the result from last iteration with the ends of mps1 and mps2
contraction_path.append((end_mps2, end_mps2 + 1)) # End of mps2 and result
contraction_path.append((end_mps1, end_mps2)) # End of mps1 and ^ outcome

# Apply the contraction
result = cq.contract(
*interleaved_rep,
options={"handle": self._lib.handle, "device_id": self._lib.device_id},
optimize={"path": contraction_path},
)

self._logger.debug(f"Result from vdot={result}")
return complex(result)

def _get_interleaved_representation(
self, conj: bool = False
) -> list[Union[cp.ndarray, str]]:
"""Returns the interleaved representation of the MPS used by cuQuantum.
Args:
conj: If True, all tensors are conjugated and bonds IDs are prefixed
with * (except physical bonds). Defaults to False.
"""
self._logger.debug("Creating interleaved representation...")

# Auxiliar dictionary of physical bonds to qubit IDs
qubit_id = {location: qubit for qubit, location in self.qubit_position.items()}

interleaved_rep = []
for i, t in enumerate(self.tensors):
# Append the tensor
if conj:
interleaved_rep.append(t.conj())
else:
interleaved_rep.append(t)

# Create the ID for the bonds involved
bonds = [str(i), str(i + 1), str(qubit_id[i])]
if conj:
bonds[0] = "*" + bonds[0]
bonds[1] = "*" + bonds[1]
interleaved_rep.append(bonds)
self._logger.debug(f"Bond IDs: {bonds}")

return interleaved_rep

def sample(self) -> dict[Qubit, int]:
"""Returns a sample from a Z measurement applied on every qubit.
Expand Down Expand Up @@ -602,11 +639,13 @@ def measure(self, qubits: set[Qubit]) -> dict[Qubit, int]:
# Since the MPS is in canonical form, this corresponds to the probability
# if we were to take all of the other tensors into account.
prob = cq.contract(
"lrp,lrP,p,P->", # No open bonds remain; this is just a scalar
"lrp,p,lrP,P->", # No open bonds remain; this is just a scalar
self.tensors[pos].conj(),
self.tensors[pos],
zero_tensor,
self.tensors[pos],
zero_tensor,
options={"handle": self._lib.handle, "device_id": self._lib.device_id},
optimize={"path": [(0, 1), (0, 1), (0, 1)]},
)

# Throw a coin to decide measurement outcome
Expand Down Expand Up @@ -687,6 +726,8 @@ def _postselect_qubit(self, qubit: Qubit, postselection_tensor: cp.ndarray) -> N
"lrp,p->lr",
self.tensors[pos],
postselection_tensor,
options={"handle": self._lib.handle, "device_id": self._lib.device_id},
optimize={"path": [(0, 1)]},
)

# Glossary of bond IDs:
Expand All @@ -704,13 +745,17 @@ def _postselect_qubit(self, qubit: Qubit, postselection_tensor: cp.ndarray) -> N
"sv,VsP->VvP",
self.tensors[pos],
self.tensors[pos - 1],
options={"handle": self._lib.handle, "device_id": self._lib.device_id},
optimize={"path": [(0, 1)]},
)
self.canonical_form[pos - 1] = None
else: # There are no tensors on the left, contract with the one on the right
self.tensors[pos + 1] = cq.contract(
"vs,sVP->vVP",
self.tensors[pos],
self.tensors[pos + 1],
options={"handle": self._lib.handle, "device_id": self._lib.device_id},
optimize={"path": [(0, 1)]},
)
self.canonical_form[pos + 1] = None

Expand Down Expand Up @@ -758,7 +803,14 @@ def expectation_value(self, pauli_string: QubitPauliString) -> float:

# Contract the Pauli to the MPS tensor of the corresponding qubit
mps_copy.tensors[pos] = cq.contract(
"lrp,Pp->lrP", mps_copy.tensors[pos], pauli_tensor
"lrp,Pp->lrP",
mps_copy.tensors[pos],
pauli_tensor,
options={
"handle": self._lib.handle,
"device_id": self._lib.device_id,
},
optimize={"path": [(0, 1)]},
)

# Obtain the inner product
Expand Down Expand Up @@ -796,8 +848,22 @@ def get_statevector(self) -> np.ndarray:
output_bonds.append("p" + str(self.qubit_position[q]))
interleaved_rep.append(output_bonds)

# We define the contraction path ourselves
end_mps = len(self) - 1
contraction_path = [(end_mps - 1, end_mps)] # Contract the last two tensors
end_mps -= 2 # Two tensors removed from the MPS
for _ in range(len(self) - 2):
# Contract the result from last iteration and the last tensor in the MPS
contraction_path.append((end_mps, end_mps + 1))
# Update the position marker
end_mps -= 1 # One tensor was removed from the MPS

# Contract
result_tensor = cq.contract(*interleaved_rep)
result_tensor = cq.contract(
*interleaved_rep,
options={"handle": self._lib.handle, "device_id": self._lib.device_id},
optimize={"path": contraction_path},
)

# Convert to numpy vector and flatten
statevector: np.ndarray = cp.asnumpy(result_tensor).flatten()
Expand All @@ -818,6 +884,9 @@ def get_amplitude(self, state: int) -> complex:
The amplitude of the computational state in the MPS.
"""

# Auxiliar dictionary of physical bonds to qubit IDs
qubit_id = {location: qubit for qubit, location in self.qubit_position.items()}

# Find out what the map MPS_position -> bit value is
ilo_qubits = sorted(self.qubit_position.keys())
mps_pos_bitvalue = dict()
Expand All @@ -827,21 +896,37 @@ def get_amplitude(self, state: int) -> complex:
bitvalue = 1 if state & 2 ** (len(self) - i - 1) else 0
mps_pos_bitvalue[pos] = bitvalue

# Carry out the contraction, starting from a dummy tensor
result_tensor = cp.ones(1, dtype=self._cfg._complex_t) # rank-1, dimension 1

# Create the interleaved representation including all postselection tensors
interleaved_rep = self._get_interleaved_representation()
for pos in range(len(self)):
postselection_tensor = cp.zeros(2, dtype=self._cfg._complex_t)
postselection_tensor[mps_pos_bitvalue[pos]] = 1
# Contract postselection with qubit into the result_tensor
result_tensor = cq.contract(
"l,lrp,p->r", result_tensor, self.tensors[pos], postselection_tensor
)
interleaved_rep.append(postselection_tensor)
interleaved_rep.append([str(qubit_id[pos])])
# Append [] so that all dim=1 bonds are ignored in the result of contract
interleaved_rep.append([])

# We define the contraction path ourselves
end_mps = len(self) - 1 # Rightmost tensor of MPS in interleaved_rep
end_rep = 2 * len(self) - 1 # Last position in the representation
contraction_path = [(end_mps, end_rep)] # Contract ends
for _ in range(len(self) - 1):
# Update the position markers
end_mps -= 1 # One tensor was removed from mps
end_rep -= 2 # One tensor removed from mps and another from postselect
# Contract the result from last iteration with the ends
contraction_path.append((end_mps, end_rep + 1)) # End of mps and result
contraction_path.append((end_rep - 1, end_rep)) # End of mps1 and ^ outcome

# Apply the contraction
result = cq.contract(
*interleaved_rep,
options={"handle": self._lib.handle, "device_id": self._lib.device_id},
optimize={"samples": 1},
)

assert result_tensor.shape == (1,)
result = complex(result_tensor[0])
self._logger.debug(f"Amplitude of state {state} is {result}.")
return result
return complex(result)

def get_qubits(self) -> set[Qubit]:
"""Returns the set of qubits that this MPS is defined on."""
Expand Down
12 changes: 11 additions & 1 deletion pytket/extensions/cutensornet/mps/mps_gate.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ def _apply_1q_gate(self, position: int, gate: Op) -> MPSxGate:
gate_bonds + "," + T_bonds + "->" + result_bonds,
gate_tensor,
self.tensors[position],
options={"handle": self._lib.handle, "device_id": self._lib.device_id},
optimize={"path": [(0, 1)]},
)

# Update ``self.tensors``
Expand Down Expand Up @@ -145,6 +147,8 @@ def _apply_2q_gate(self, positions: tuple[int, int], gate: Op) -> MPSxGate:
gate_tensor,
self.tensors[l_pos],
self.tensors[r_pos],
options={"handle": self._lib.handle, "device_id": self._lib.device_id},
optimize={"path": [(0, 1), (0, 1)]},
)
self._logger.debug(f"Intermediate tensor of size (MiB)={T.nbytes / 2**20}")

Expand Down Expand Up @@ -223,7 +227,13 @@ def _apply_2q_gate(self, positions: tuple[int, int], gate: Op) -> MPSxGate:
# Use some einsum index magic: since the virtual bond "s" appears in the
# list of bonds of the output, it is not summed over.
# This causes S to act as the intended diagonal matrix.
L = cq.contract("asL,s->asL", L, S)
L = cq.contract(
"asL,s->asL",
L,
S,
options={"handle": self._lib.handle, "device_id": self._lib.device_id},
optimize={"path": [(0, 1)]},
)

# We multiply the fidelity of the current step to the overall fidelity
# to keep track of a lower bound for the fidelity.
Expand Down
27 changes: 24 additions & 3 deletions pytket/extensions/cutensornet/mps/mps_mpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,8 @@ def _apply_1q_gate(self, position: int, gate: Op) -> MPSxMPO:
"go," + last_bonds + "->" + new_bonds,
gate_tensor,
last_tensor,
options={"handle": self._lib.handle, "device_id": self._lib.device_id},
optimize={"path": [(0, 1)]},
)

# Update the tensor
Expand Down Expand Up @@ -385,7 +387,11 @@ def update_sweep_cache(pos: int, direction: DirectionMPS) -> None:
interleaved_rep.append(["r", "R"] + result_bonds)

# Contract and store
T = cq.contract(*interleaved_rep)
T = cq.contract(
*interleaved_rep,
options={"handle": self._lib.handle, "device_id": self._lib.device_id},
optimize={"samples": 1},
)
if direction == DirectionMPS.LEFT:
r_cached_tensors.append(T)
elif direction == DirectionMPS.RIGHT:
Expand Down Expand Up @@ -443,10 +449,25 @@ def update_variational_tensor(
interleaved_rep.append(result_bonds)

# Contract and store tensor
F = cq.contract(*interleaved_rep)
F = cq.contract(
*interleaved_rep,
options={"handle": self._lib.handle, "device_id": self._lib.device_id},
optimize={"samples": 1},
)

# Get the fidelity
optim_fidelity = complex(cq.contract("LRP,LRP->", F.conj(), F))
optim_fidelity = complex(
cq.contract(
"LRP,LRP->",
F.conj(),
F,
options={
"handle": self._lib.handle,
"device_id": self._lib.device_id,
},
optimize={"path": [(0, 1)]},
)
)
assert np.isclose(optim_fidelity.imag, 0.0, atol=self._cfg._atol)
optim_fidelity = float(optim_fidelity.real)

Expand Down

0 comments on commit d0886c0

Please sign in to comment.