Skip to content

Commit

Permalink
Support ClExpr (#176)
Browse files Browse the repository at this point in the history
# Description

Support for `ClExpr`

# Related issues

#175 

# Checklist

- [x] I have performed a self-review of my code.
- [x] I have commented hard-to-understand parts of my code.
- [x] I have made corresponding changes to the public API documentation.
- [x] I have added tests that prove my fix is effective or that my
feature works.
- [x] I have updated the changelog with any user-facing changes.
  • Loading branch information
PabloAndresCQ authored Nov 12, 2024
1 parent 935f509 commit 1e12786
Show file tree
Hide file tree
Showing 4 changed files with 381 additions and 16 deletions.
6 changes: 6 additions & 0 deletions docs/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,12 @@
Changelog
~~~~~~~~~

Unreleased
----------

* Updated pytket version requirement to 1.34.
* Now supporting ``ClExpr`` operations (the new version of tket's ``ClassicalExpBox``).

0.10.0 (October 2024)
---------------------

Expand Down
126 changes: 126 additions & 0 deletions pytket/extensions/cutensornet/structured_state/classical.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,13 @@
SetBitsOp,
CopyBitsOp,
RangePredicateOp,
ClExprOp,
ClassicalExpBox,
LogicExp,
BitWiseOp,
RegWiseOp,
)
from pytket._tket.circuit import ClExpr, ClOp, ClBitVar, ClRegVar


ExtendedLogicExp = Union[LogicExp, Bit, BitRegister, int]
Expand Down Expand Up @@ -56,6 +58,40 @@ def apply_classical_command(
# Check that the value is in the range
bits_dict[res_bit] = val >= op.lower and val <= op.upper

elif isinstance(op, ClExprOp):
# Convert bit_posn to dictionary of `ClBitVar` index to its value
bitvar_val = {
var_id: int(bits_dict[args[bit_pos]])
for var_id, bit_pos in op.expr.bit_posn.items()
}
# Convert reg_posn to dictionary of `ClRegVar` index to its value
regvar_val = {
var_id: from_little_endian(
[bits_dict[args[bit_pos]] for bit_pos in reg_pos_list]
)
for var_id, reg_pos_list in op.expr.reg_posn.items()
}
# Identify number of bits on each register
regvar_size = {
var_id: len(reg_pos_list)
for var_id, reg_pos_list in op.expr.reg_posn.items()
}
# Identify number of bits in output register
output_size = len(op.expr.output_posn)
result = evaluate_clexpr(
op.expr.expr, bitvar_val, regvar_val, regvar_size, output_size
)

# The result is an int in little-endian encoding. We update the
# output register accordingly.
for bit_pos in op.expr.output_posn:
bits_dict[args[bit_pos]] = (result % 2) == 1
result = result >> 1
# If there has been overflow in the operations, error out.
# This can be detected if `result != 0`
if result != 0:
raise ValueError("Evaluation of the ClExpr resulted in overflow.")

elif isinstance(op, ClassicalExpBox):
the_exp = op.get_exp()
result = evaluate_logic_exp(the_exp, bits_dict)
Expand All @@ -74,6 +110,95 @@ def apply_classical_command(
raise NotImplementedError(f"Commands of type {op.type} are not supported.")


def evaluate_clexpr(
expr: ClExpr,
bitvar_val: dict[int, int],
regvar_val: dict[int, int],
regvar_size: dict[int, int],
output_size: int,
) -> int:
"""Recursive evaluation of a ClExpr."""

# Evaluate arguments to operation
args_val = []
for arg in expr.args:
if isinstance(arg, int):
value = arg
elif isinstance(arg, ClBitVar):
value = bitvar_val[arg.index]
elif isinstance(arg, ClRegVar):
value = regvar_val[arg.index]
elif isinstance(arg, ClExpr):
value = evaluate_clexpr(
arg, bitvar_val, regvar_val, regvar_size, output_size
)
else:
raise Exception(f"Unrecognised argument type of ClExpr: {type(arg)}.")

args_val.append(value)

# Apply the operation at the root of this ClExpr
if expr.op in [ClOp.BitAnd, ClOp.RegAnd]:
result = args_val[0] & args_val[1]
elif expr.op in [ClOp.BitOr, ClOp.RegOr]:
result = args_val[0] | args_val[1]
elif expr.op in [ClOp.BitXor, ClOp.RegXor]:
result = args_val[0] ^ args_val[1]
elif expr.op in [ClOp.BitEq, ClOp.RegEq]:
result = int(args_val[0] == args_val[1])
elif expr.op in [ClOp.BitNeq, ClOp.RegNeq]:
result = int(args_val[0] != args_val[1])
elif expr.op == ClOp.RegGeq:
result = int(args_val[0] >= args_val[1])
elif expr.op == ClOp.RegGt:
result = int(args_val[0] > args_val[1])
elif expr.op == ClOp.RegLeq:
result = int(args_val[0] <= args_val[1])
elif expr.op == ClOp.RegLt:
result = int(args_val[0] < args_val[1])
elif expr.op == ClOp.BitNot:
result = 1 - args_val[0]
elif expr.op == ClOp.RegNot: # Bit-wise NOT (flip all bits)
n_bits = regvar_size[expr.args[0].index] # type: ignore
result = (2**n_bits - 1) ^ args_val[0] # XOR with all 1s bitstring
elif expr.op in [ClOp.BitZero, ClOp.RegZero]:
result = 0
elif expr.op == ClOp.BitOne:
result = 1
elif expr.op == ClOp.RegOne: # All 1s bitstring
n_bits = output_size
result = 2**n_bits - 1
elif expr.op == ClOp.RegAdd:
result = args_val[0] + args_val[1]
elif expr.op == ClOp.RegSub:
if args_val[0] < args_val[1]:
raise NotImplementedError(
"Currently not supporting ClOp.RegSub where the outcome is negative."
)
result = args_val[0] - args_val[1]
elif expr.op == ClOp.RegMul:
result = args_val[0] * args_val[1]
elif expr.op == ClOp.RegDiv: # floor(a / b)
result = args_val[0] // args_val[1]
elif expr.op == ClOp.RegPow:
result = int(args_val[0] ** args_val[1])
elif expr.op == ClOp.RegLsh:
result = args_val[0] << args_val[1]
elif expr.op == ClOp.RegRsh:
result = args_val[0] >> args_val[1]
# elif expr.op == ClOp.RegNeg:
# result = -args_val[0]
else:
# TODO: Not supporting RegNeg because I do not know if we have agreed how to
# specify signed ints.
raise NotImplementedError(
f"Evaluation of {expr.op} not supported in ClExpr ",
"by pytket-cutensornet.",
)

return result


def evaluate_logic_exp(exp: ExtendedLogicExp, bits_dict: dict[Bit, bool]) -> int:
"""Recursive evaluation of a LogicExp."""

Expand Down Expand Up @@ -132,4 +257,5 @@ def evaluate_logic_exp(exp: ExtendedLogicExp, bits_dict: dict[Bit, bool]) -> int
def from_little_endian(bitstring: list[bool]) -> int:
"""Obtain the integer from the little-endian encoded bitstring (i.e. bitstring
[False, True] is interpreted as the integer 2)."""
# TODO: Assumes unisigned integer. What are the specs for signed integers?
return sum(1 << i for i, b in enumerate(bitstring) if b)
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
license="Apache 2",
packages=find_namespace_packages(include=["pytket.*"]),
include_package_data=True,
install_requires=["pytket >= 1.33.0", "networkx >= 2.8.8"],
install_requires=["pytket >= 1.34.0", "networkx >= 2.8.8"],
classifiers=[
"Environment :: Console",
"Programming Language :: Python :: 3.10",
Expand Down
Loading

0 comments on commit 1e12786

Please sign in to comment.