Skip to content

Commit

Permalink
Cleaned up code around bonds_from_q0_to_ancestor
Browse files Browse the repository at this point in the history
  • Loading branch information
PabloAndresCQ committed Feb 1, 2024
1 parent 2f86a77 commit 96cca32
Showing 1 changed file with 24 additions and 23 deletions.
47 changes: 24 additions & 23 deletions pytket/extensions/cutensornet/tnstate/ttn_gate.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,26 @@ def _apply_2q_gate(self, q0: Qubit, q1: Qubit, gate: Op) -> TTNxGate:
# Once `msg_tensor` is directly connected to the leaf node containing `q1`, we
# just need to contract them, connecting `b` to `q1`, with `B` becoming the
# new physical bond.
#
bonds_to_q0 = [ # Bonds in the "arc" from the common ancestor to `q0`
path_q0[:i] for i in range(len(common_path) + 1, len(path_q0) + 1)
]
# Sanity checks:
assert all(
len(bond_address) != len(common_path) for bond_address in bonds_to_q0
)
assert len(bonds_to_q0) == 1 or len(bonds_to_q0[0]) < len(bonds_to_q0[1])
assert len(bonds_to_q0[-1]) == len(path_q0)

bonds_to_q1 = [ # Bonds in the "arc" from the common ancestor to `q1`
path_q1[:i] for i in range(len(common_path) + 1, len(path_q1) + 1)
]
# Sanity checks:
assert all(
len(bond_address) != len(common_path) for bond_address in bonds_to_q1
)
assert len(bonds_to_q1) == 1 or len(bonds_to_q1[0]) < len(bonds_to_q1[1])
assert len(bonds_to_q1[-1]) == len(path_q1)

# The `msg_tensor` has four bonds. Our convention will be that the first bond
# always corresponds to `B`, the second bond is `b`, the third bond connects
# it to the TTN in the child direction and the fourth connects it to the TTN
Expand Down Expand Up @@ -221,12 +240,7 @@ def _apply_2q_gate(self, q0: Qubit, q1: Qubit, gate: Op) -> TTNxGate:

# We must push the `msg_tensor` all the way to the common ancestor
# of `q0` and `q1`.
bond_addresses = [
path_q0[:i] for i in reversed(range(len(common_path) + 1, len(path_q0) + 1))
]
# Sanity checks:
assert all(len(root_path) != len(common_path) for root_path in bond_addresses)
assert len(bond_addresses[0]) == len(path_q0)
bond_addresses = reversed(bonds_to_q0) # From `q0` to the ancestor

# For all of these nodes; push `msg_tensor` through to their parent bond
for child_bond in bond_addresses[:-1]: # Doesn't do it on common ancestor!
Expand Down Expand Up @@ -273,7 +287,6 @@ def _apply_2q_gate(self, q0: Qubit, q1: Qubit, gate: Op) -> TTNxGate:
f"({common_ancestor_node.tensor.nbytes // 2**20} MiB) at {parent_bond}."
)


# Apply the contraction followed by a QR decomposition
common_ancestor_node.tensor, msg_tensor = contract_decompose(
f"{node_bonds},{msg_bonds}->{Q_bonds},{R_bonds}",
Expand All @@ -291,12 +304,7 @@ def _apply_2q_gate(self, q0: Qubit, q1: Qubit, gate: Op) -> TTNxGate:

# We must push the `msg_tensor` from the common ancestor to the leaf node
# containing `q1`.
bond_addresses = [
path_q1[:i] for i in range(len(common_path) + 1, len(path_q1) + 1)
]
# Sanity checks:
assert all(len(root_path) != len(common_path) for root_path in bond_addresses)
assert len(bond_addresses[-1]) == len(path_q1)
bond_addresses = bonds_to_q1 # From ancestor to `q1`

# For all of these nodes; push `msg_tensor` through to their child bond
for child_bond in bond_addresses[1:]: # Skip common ancestor: already pushed
Expand Down Expand Up @@ -351,23 +359,16 @@ def _apply_2q_gate(self, q0: Qubit, q1: Qubit, gate: Op) -> TTNxGate:
# Truncate (if needed) bonds along the arc from `q1` to `q0`.
# We truncate in this direction to take advantage of the canonicalisation
# of the TTN we achieved while pushing the `msg_tensor` from `q0` to `q1`.
bonds_from_q1_to_ancestor = [
path_q1[:i] for i in reversed(range(len(common_path) + 1, len(path_q1) + 1))
]
bonds_from_ancestor_to_q0 = [
path_q0[:i] for i in range(len(common_path) + 1, len(path_q0) + 1)
]

if self._cfg.truncation_fidelity < 1:
# Truncate as much as possible before violating the truncation fidelity
self._fidelity_bound_sequential_weighted_truncation(
bonds_from_q1_to_ancestor, bonds_from_ancestor_to_q0
reversed(bonds_to_q1), bonds_to_q0
)

else:
# Truncate so that all bonds have dimension less or equal to chi
self._chi_sequential_truncation(
bonds_from_q1_to_ancestor, bonds_from_ancestor_to_q0
reversed(bonds_to_q1), bonds_to_q0
)

return self
Expand Down

0 comments on commit 96cca32

Please sign in to comment.