diff --git a/pytket/extensions/cutensornet/structured_state/general.py b/pytket/extensions/cutensornet/structured_state/general.py index 903cfa3c..ff14727d 100644 --- a/pytket/extensions/cutensornet/structured_state/general.py +++ b/pytket/extensions/cutensornet/structured_state/general.py @@ -55,7 +55,6 @@ def __init__( float_precision: Type[Any] = np.float64, value_of_zero: float = 1e-16, leaf_size: int = 8, - use_kahypar: bool = False, k: int = 4, optim_delta: float = 1e-5, loglevel: int = logging.WARNING, @@ -92,9 +91,6 @@ def __init__( ``np.float64`` precision (default) and ``1e-7`` for ``np.float32``. leaf_size: For ``TTN`` simulation only. Sets the maximum number of qubits in a leaf node when using ``TTN``. Default is 8. - use_kahypar: Use KaHyPar for graph partitioning (used in ``TTN``) if this - is True. Otherwise, use NetworkX (worse, but easy to setup). Defaults - to False. k: For ``MPSxMPO`` simulation only. Sets the maximum number of layers the MPO is allowed to have before being contracted. Increasing this might increase fidelity, but it will also increase resource requirements @@ -161,7 +157,6 @@ def __init__( raise ValueError("Maximum allowed leaf_size is 65.") self.leaf_size = leaf_size - self.use_kahypar = use_kahypar self.k = k self.optim_delta = 1e-5 self.loglevel = loglevel diff --git a/pytket/extensions/cutensornet/structured_state/simulation.py b/pytket/extensions/cutensornet/structured_state/simulation.py index 855593c6..76dcb1fd 100644 --- a/pytket/extensions/cutensornet/structured_state/simulation.py +++ b/pytket/extensions/cutensornet/structured_state/simulation.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional +from typing import Optional, Any import warnings from enum import Enum @@ -26,7 +26,7 @@ except ImportError: warnings.warn("local settings failed to import kahypar", ImportWarning) -from pytket.circuit import Circuit, Command, Qubit +from pytket.circuit import Circuit, Command, OpType, Qubit from pytket.transform import Transform from pytket.architecture import Architecture from pytket.passes import DefaultMappingPass @@ -56,6 +56,7 @@ def simulate( circuit: Circuit, algorithm: SimulationAlgorithm, config: Config, + compilation_params: Optional[dict[str, Any]] = None, ) -> StructuredState: """Simulates the circuit and returns the ``StructuredState`` of the final state. @@ -73,6 +74,8 @@ def simulate( circuit: The pytket circuit to be simulated. algorithm: Choose between the values of the ``SimulationAlgorithm`` enum. config: The configuration object for simulation. + compilation_params: Experimental feature. Defaults to no compilation. + Parameters currently not documented. Returns: An instance of ``StructuredState`` for (an approximation of) the final state @@ -80,9 +83,10 @@ def simulate( """ logger = set_logger("Simulation", level=config.loglevel) - logger.info( - "Ordering the gates in the circuit to reduce canonicalisation overhead." - ) + if compilation_params is None: + compilation_params = dict() + + # Initialise the StructuredState if algorithm == SimulationAlgorithm.MPSxGate: state = MPSxGate( # type: ignore libhandle, @@ -90,8 +94,6 @@ def simulate( bits=circuit.bits, config=config, ) - # TODO: Currently deactivating gate sorting. Fix this before merging branch - sorted_gates = circuit.get_commands() # _get_sorted_gates(circuit, algorithm) elif algorithm == SimulationAlgorithm.MPSxMPO: state = MPSxMPO( # type: ignore @@ -100,12 +102,12 @@ def simulate( bits=circuit.bits, config=config, ) - # TODO: Currently deactivating gate sorting. Fix this before merging branch - sorted_gates = circuit.get_commands() # _get_sorted_gates(circuit, algorithm) elif algorithm == SimulationAlgorithm.TTNxGate: + use_kahypar_option: bool = compilation_params.get("use_kahypar", False) + qubit_partition = _get_qubit_partition( - circuit, config.leaf_size, config.use_kahypar + circuit, config.leaf_size, use_kahypar_option ) state = TTNxGate( # type: ignore libhandle, @@ -113,16 +115,27 @@ def simulate( bits=circuit.bits, config=config, ) - # TODO: Currently deactivating gate sorting. Fix this before merging branch - sorted_gates = ( - circuit.get_commands() - ) # _get_sorted_gates(circuit, algorithm, qubit_partition) + # If requested by the user, sort the gates to reduce canonicalisation overhead. + sort_gates_option: bool = compilation_params.get("sort_gates", False) + if sort_gates_option: + logger.info( + "Ordering the gates in the circuit to reduce canonicalisation overhead." + ) + + if algorithm == SimulationAlgorithm.TTNxGate: + commands = _get_sorted_gates(circuit, algorithm, qubit_partition) + else: + commands = _get_sorted_gates(circuit, algorithm) + else: + commands = circuit.get_commands() + + # Run the simulation logger.info("Running simulation...") # Apply the gates - for i, g in enumerate(sorted_gates): + for i, g in enumerate(commands): state.apply_gate(g) - logger.info(f"Progress... {(100*i) // len(sorted_gates)}%") + logger.info(f"Progress... {(100*i) // len(commands)}%") # Apply the batched operations that are left (if any) state._flush() @@ -323,6 +336,13 @@ def _get_sorted_gates( 2-qubit gates that are close together are applied one after the other. This reduces the overhead of canonicalisation during simulation. + Notes: + If the circuit has any command (other than measurement) acting on bits, this + function gives up trying to sort the gates, and simply returns the standard + `circuit.get_commands()`. It would be possible to update this function so that + it can manage these commands as well, but it is not clear that there is a strong + use case for this. + Args: circuit: The original circuit. algorithm: The simulation algorithm that will be used on this circuit. @@ -334,6 +354,11 @@ def _get_sorted_gates( The same gates, ordered in a beneficial way for the given algorithm. """ all_gates = circuit.get_commands() + + # Abort if there is classical logic or classical control in the circuit (see note) + if any(len(g.bits) != 0 and g.op.type is not OpType.Measure for g in all_gates): + return all_gates + sorted_gates = [] # Entries from `all_gates` that are not yet in `sorted_gates` remaining = set(range(len(all_gates)))