Skip to content

Commit

Permalink
Sorting gates before simulation is now switched off by default. Can b…
Browse files Browse the repository at this point in the history
…e switched on via an undocumented flag.
  • Loading branch information
PabloAndresCQ committed Sep 6, 2024
1 parent 8caee93 commit 0b5667a
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 21 deletions.
5 changes: 0 additions & 5 deletions pytket/extensions/cutensornet/structured_state/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ def __init__(
float_precision: Type[Any] = np.float64,
value_of_zero: float = 1e-16,
leaf_size: int = 8,
use_kahypar: bool = False,
k: int = 4,
optim_delta: float = 1e-5,
loglevel: int = logging.WARNING,
Expand Down Expand Up @@ -92,9 +91,6 @@ def __init__(
``np.float64`` precision (default) and ``1e-7`` for ``np.float32``.
leaf_size: For ``TTN`` simulation only. Sets the maximum number of
qubits in a leaf node when using ``TTN``. Default is 8.
use_kahypar: Use KaHyPar for graph partitioning (used in ``TTN``) if this
is True. Otherwise, use NetworkX (worse, but easy to setup). Defaults
to False.
k: For ``MPSxMPO`` simulation only. Sets the maximum number of layers
the MPO is allowed to have before being contracted. Increasing this
might increase fidelity, but it will also increase resource requirements
Expand Down Expand Up @@ -161,7 +157,6 @@ def __init__(
raise ValueError("Maximum allowed leaf_size is 65.")

self.leaf_size = leaf_size
self.use_kahypar = use_kahypar
self.k = k
self.optim_delta = 1e-5
self.loglevel = loglevel
Expand Down
57 changes: 41 additions & 16 deletions pytket/extensions/cutensornet/structured_state/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
from typing import Optional, Any
import warnings
from enum import Enum

Expand All @@ -26,7 +26,7 @@
except ImportError:
warnings.warn("local settings failed to import kahypar", ImportWarning)

from pytket.circuit import Circuit, Command, Qubit
from pytket.circuit import Circuit, Command, OpType, Qubit
from pytket.transform import Transform
from pytket.architecture import Architecture
from pytket.passes import DefaultMappingPass
Expand Down Expand Up @@ -56,6 +56,7 @@ def simulate(
circuit: Circuit,
algorithm: SimulationAlgorithm,
config: Config,
compilation_params: Optional[dict[str, Any]] = None,
) -> StructuredState:
"""Simulates the circuit and returns the ``StructuredState`` of the final state.
Expand All @@ -73,25 +74,26 @@ def simulate(
circuit: The pytket circuit to be simulated.
algorithm: Choose between the values of the ``SimulationAlgorithm`` enum.
config: The configuration object for simulation.
compilation_params: Experimental feature. Defaults to no compilation.
Parameters currently not documented.
Returns:
An instance of ``StructuredState`` for (an approximation of) the final state
of the circuit. The instance be of the class matching ``algorithm``.
"""
logger = set_logger("Simulation", level=config.loglevel)

logger.info(
"Ordering the gates in the circuit to reduce canonicalisation overhead."
)
if compilation_params is None:
compilation_params = dict()

# Initialise the StructuredState
if algorithm == SimulationAlgorithm.MPSxGate:
state = MPSxGate( # type: ignore
libhandle,
circuit.qubits,
bits=circuit.bits,
config=config,
)
# TODO: Currently deactivating gate sorting. Fix this before merging branch
sorted_gates = circuit.get_commands() # _get_sorted_gates(circuit, algorithm)

elif algorithm == SimulationAlgorithm.MPSxMPO:
state = MPSxMPO( # type: ignore
Expand All @@ -100,29 +102,40 @@ def simulate(
bits=circuit.bits,
config=config,
)
# TODO: Currently deactivating gate sorting. Fix this before merging branch
sorted_gates = circuit.get_commands() # _get_sorted_gates(circuit, algorithm)

elif algorithm == SimulationAlgorithm.TTNxGate:
use_kahypar_option: bool = compilation_params.get("use_kahypar", False)

qubit_partition = _get_qubit_partition(
circuit, config.leaf_size, config.use_kahypar
circuit, config.leaf_size, use_kahypar_option
)
state = TTNxGate( # type: ignore
libhandle,
qubit_partition,
bits=circuit.bits,
config=config,
)
# TODO: Currently deactivating gate sorting. Fix this before merging branch
sorted_gates = (
circuit.get_commands()
) # _get_sorted_gates(circuit, algorithm, qubit_partition)

# If requested by the user, sort the gates to reduce canonicalisation overhead.
sort_gates_option: bool = compilation_params.get("sort_gates", False)
if sort_gates_option:
logger.info(
"Ordering the gates in the circuit to reduce canonicalisation overhead."
)

if algorithm == SimulationAlgorithm.TTNxGate:
commands = _get_sorted_gates(circuit, algorithm, qubit_partition)
else:
commands = _get_sorted_gates(circuit, algorithm)
else:
commands = circuit.get_commands()

# Run the simulation
logger.info("Running simulation...")
# Apply the gates
for i, g in enumerate(sorted_gates):
for i, g in enumerate(commands):
state.apply_gate(g)
logger.info(f"Progress... {(100*i) // len(sorted_gates)}%")
logger.info(f"Progress... {(100*i) // len(commands)}%")

# Apply the batched operations that are left (if any)
state._flush()
Expand Down Expand Up @@ -323,6 +336,13 @@ def _get_sorted_gates(
2-qubit gates that are close together are applied one after the other. This reduces
the overhead of canonicalisation during simulation.
Notes:
If the circuit has any command (other than measurement) acting on bits, this
function gives up trying to sort the gates, and simply returns the standard
`circuit.get_commands()`. It would be possible to update this function so that
it can manage these commands as well, but it is not clear that there is a strong
use case for this.
Args:
circuit: The original circuit.
algorithm: The simulation algorithm that will be used on this circuit.
Expand All @@ -334,6 +354,11 @@ def _get_sorted_gates(
The same gates, ordered in a beneficial way for the given algorithm.
"""
all_gates = circuit.get_commands()

# Abort if there is classical logic or classical control in the circuit (see note)
if any(len(g.bits) != 0 and g.op.type is not OpType.Measure for g in all_gates):
return all_gates

sorted_gates = []
# Entries from `all_gates` that are not yet in `sorted_gates`
remaining = set(range(len(all_gates)))
Expand Down

0 comments on commit 0b5667a

Please sign in to comment.