diff --git a/pytket/extensions/cutensornet/backends/cutensornet_backend.py b/pytket/extensions/cutensornet/backends/cutensornet_backend.py index b6e13757..d86a3f67 100644 --- a/pytket/extensions/cutensornet/backends/cutensornet_backend.py +++ b/pytket/extensions/cutensornet/backends/cutensornet_backend.py @@ -238,8 +238,6 @@ def process_circuits( handle_list.append(handle) return handle_list - # TODO: this should be optionally parallelised with MPI - # (both wrt Pauli strings and contraction itself). def get_operator_expectation_value( self, state_circuit: Circuit, @@ -261,13 +259,46 @@ def get_operator_expectation_value( Returns: Expectation value. """ + return self.get_matrix_element( + state_circuit, + state_circuit, + operator, + post_selection=post_selection, + valid_check=valid_check, + ).real + + # TODO: this should be optionally parallelised with MPI + # (both wrt Pauli strings and contraction itself). + def get_matrix_element( + self, + circuit_bra: Circuit, + circuit_ket: Circuit, + operator: QubitPauliOperator, + post_selection: Optional[dict[Qubit, int]] = None, + valid_check: bool = True, + ) -> float: + """Calculates a general matrix element using cuTensorNet contraction. + + Has an option to do post selection on an ancilla register. + + Args: + circuit_bra: Circuit representing bra state. + circuit_ket: Circuit representing ket state. + operator: Operator which matrix element is to be calculated. + valid_check: Whether to perform circuit validity check. + post_selection: Dictionary of qubits to post select where the key is + qubit and the value is bit outcome. + + Returns: + Matrix element. + """ if valid_check: - self._check_all_circuits([state_circuit]) + self._check_all_circuits([circuit_bra, circuit_ket]) - expectation = 0 + element = 0 - ket_network = TensorNetwork(state_circuit) - bra_network = ket_network.dagger() + ket_network = TensorNetwork(circuit_ket) + bra_network = TensorNetwork(circuit_bra).dagger() if post_selection is not None: post_select_qubits = list(post_selection.keys()) @@ -281,18 +312,18 @@ def get_operator_expectation_value( ) # This needed because dagger does not work with post selection for qos, coeff in operator._dict.items(): - expectation_value_network = ExpectationValueTensorNetwork( + element_network = ExpectationValueTensorNetwork( bra_network, qos, ket_network ) if isinstance(coeff, Expr): numeric_coeff = complex(coeff.evalf()) # type: ignore else: numeric_coeff = complex(coeff) # type: ignore - expectation_term = numeric_coeff * cq.contract( - *expectation_value_network.cuquantum_interleaved + element_term = numeric_coeff * cq.contract( + *element_network.cuquantum_interleaved ) - expectation += expectation_term - return expectation.real + element += element_term + return element def get_circuit_overlap( self,