Skip to content

Commit

Permalink
Updating tensorflow tests to work with lab dev (#533)
Browse files Browse the repository at this point in the history
**Context:** Optimizer and training tests need to be updated for the new
`lab_dev` classes.

**Description of the Change:** Updated `test_opt.py` to use `lab_dev`
including any necessary fixes
  • Loading branch information
apchytr authored and ziofil committed Jan 22, 2025
1 parent 9e1dcb7 commit 5bf10ce
Show file tree
Hide file tree
Showing 12 changed files with 646 additions and 47 deletions.
2 changes: 1 addition & 1 deletion mrmustard/lab_dev/states/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from .bargmann_eigenstate import BargmannEigenstate
from .coherent import Coherent
from .displaced_squeezed import DisplacedSqueezed
from .gstate import GKet, GDM
from .gaussian_state import GKet, GDM
from .number import Number
from .quadrature_eigenstate import QuadratureEigenstate
from .squeezed_vacuum import SqueezedVacuum
Expand Down
File renamed without changes.
16 changes: 9 additions & 7 deletions mrmustard/lab_dev/transformations/rgate.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class Rgate(Unitary):
>>> import numpy as np
>>> from mrmustard.lab_dev import Rgate
>>> unitary = Rgate(modes=[1, 2], phi=0.1)
>>> unitary = Rgate(modes=[1, 2], theta=0.1)
>>> assert unitary.modes == [1, 2]
Args:
Expand All @@ -55,17 +55,19 @@ class Rgate(Unitary):
def __init__(
self,
modes: Sequence[int],
phi: float | Sequence[float] = 0.0,
phi_trainable: bool = False,
phi_bounds: tuple[float | None, float | None] = (0.0, None),
theta: float | Sequence[float] = 0.0,
theta_trainable: bool = False,
theta_bounds: tuple[float | None, float | None] = (0.0, None),
):
super().__init__(name="Rgate")
(phis,) = list(reshape_params(len(modes), phi=phi))
self.parameters.add_parameter(make_parameter(phi_trainable, phis, "phi", phi_bounds))
(thetas,) = list(reshape_params(len(modes), theta=theta))
self.parameters.add_parameter(
make_parameter(theta_trainable, thetas, "theta", theta_bounds)
)
self._representation = self.from_ansatz(
modes_in=modes,
modes_out=modes,
ansatz=PolyExpAnsatz.from_function(
fn=triples.rotation_gate_Abc, theta=self.parameters.phi
fn=triples.rotation_gate_Abc, theta=self.parameters.theta
),
).representation
14 changes: 13 additions & 1 deletion mrmustard/training/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,10 @@
update_symplectic,
update_unitary,
)
from mrmustard.lab import Circuit

import mrmustard.lab as mrml

from mrmustard.lab_dev import Circuit

__all__ = ["Optimizer"]

Expand Down Expand Up @@ -188,6 +191,11 @@ def _get_trainable_params(trainable_items, root_tag: str = "optimized"):
for i, item in enumerate(trainable_items):
owner_tag = f"{root_tag}[{i}]"
if isinstance(item, Circuit):
for j, op in enumerate(item.components):
tag = f"{owner_tag}:{item.__class__.__qualname__}/_ops[{j}]"
tagged_vars = op.parameters.tagged_variables(tag)
trainables.append(tagged_vars.items())
elif isinstance(item, mrml.Circuit):
for j, op in enumerate(item.ops):
tag = f"{owner_tag}:{item.__class__.__qualname__}/_ops[{j}]"
tagged_vars = op.parameter_set.tagged_variables(tag)
Expand All @@ -196,6 +204,10 @@ def _get_trainable_params(trainable_items, root_tag: str = "optimized"):
tag = f"{owner_tag}:{item.__class__.__qualname__}"
tagged_vars = item.parameter_set.tagged_variables(tag)
trainables.append(tagged_vars.items())
elif hasattr(item, "parameters"):
tag = f"{owner_tag}:{item.__class__.__qualname__}"
tagged_vars = item.parameters.tagged_variables(tag)
trainables.append(tagged_vars.items())
elif math.from_backend(item) and math.is_trainable(item):
# the created parameter is wrapped into a list because the case above
# returns a list, hence ensuring we have a list of lists
Expand Down
2 changes: 1 addition & 1 deletion mrmustard/training/parameter_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from mrmustard.utils.typing import Tensor

from mrmustard import math
from .parameter import Trainable
from ..utils.typing import Trainable


def update_symplectic(grads_and_vars: Sequence[tuple[Tensor, Trainable]], symplectic_lr: float):
Expand Down
26 changes: 13 additions & 13 deletions tests/test_lab_dev/test_transformations/test_rgate.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,21 +29,21 @@ class TestRgate:
"""

modes = [[0], [1, 2], [7, 9]]
phis = [[1], 1, [1, 2]]
thetas = [[1], 1, [1, 2]]

@pytest.mark.parametrize("modes,phi", zip(modes, phis))
def test_init(self, modes, phi):
gate = Rgate(modes, phi)
@pytest.mark.parametrize("modes,theta", zip(modes, thetas))
def test_init(self, modes, theta):
gate = Rgate(modes, theta)

assert gate.name == "Rgate"
assert gate.modes == [modes] if not isinstance(modes, list) else sorted(modes)

def test_init_error(self):
with pytest.raises(ValueError, match="phi"):
Rgate(modes=[0, 1], phi=[2, 3, 4])
with pytest.raises(ValueError, match="theta"):
Rgate(modes=[0, 1], theta=[2, 3, 4])

def test_representation(self):
rep1 = Rgate(modes=[0], phi=0.1).ansatz
rep1 = Rgate(modes=[0], theta=0.1).ansatz
assert math.allclose(
rep1.A,
[
Expand All @@ -56,7 +56,7 @@ def test_representation(self):
assert math.allclose(rep1.b, np.zeros((1, 2)))
assert math.allclose(rep1.c, [1.0 + 0.0j])

rep2 = Rgate(modes=[0, 1], phi=[0.1, 0.3]).ansatz
rep2 = Rgate(modes=[0, 1], theta=[0.1, 0.3]).ansatz
assert math.allclose(
rep2.A,
[
Expand All @@ -71,7 +71,7 @@ def test_representation(self):
assert math.allclose(rep2.b, np.zeros((1, 4)))
assert math.allclose(rep2.c, [1.0 + 0.0j])

rep3 = Rgate(modes=[1], phi=0.1).ansatz
rep3 = Rgate(modes=[1], theta=0.1).ansatz
assert math.allclose(
rep3.A,
[
Expand All @@ -89,11 +89,11 @@ def test_trainable_parameters(self):
gate2 = Rgate([0], 1, True, (-2, 2))

with pytest.raises(AttributeError):
gate1.parameters.phi.value = 3
gate1.parameters.theta.value = 3

gate2.parameters.phi.value = 2
assert gate2.parameters.phi.value == 2
gate2.parameters.theta.value = 2
assert gate2.parameters.theta.value == 2

def test_representation_error(self):
with pytest.raises(ValueError):
Rgate(modes=[0], phi=[0.1, 0.2]).ansatz
Rgate(modes=[0], theta=[0.1, 0.2]).ansatz
19 changes: 9 additions & 10 deletions tests/test_training/test_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,7 @@
import tensorflow as tf

from mrmustard import math, settings
from mrmustard.lab.circuit import Circuit
from mrmustard.lab.gates import BSgate, S2gate
from mrmustard.lab.states import Vacuum
from mrmustard.lab_dev import Circuit, BSgate, S2gate, Vacuum
from mrmustard.training import Optimizer, TensorboardCallback

from ..conftest import skip_np
Expand All @@ -33,25 +31,26 @@ def test_tensorboard_callback(tmp_path):
settings.SEED = 42
i, k = 2, 3
r = np.arcsinh(1.0)
state_in = Vacuum((0, 1, 2, 3))
s2_0, s2_1, bs = (
S2gate(r=r, phi=0.0, phi_trainable=True)[0, 1],
S2gate(r=r, phi=0.0, phi_trainable=True)[2, 3],
S2gate((0, 1), r=r, phi=0.0, phi_trainable=True),
S2gate((2, 3), r=r, phi=0.0, phi_trainable=True),
BSgate(
(1, 2),
theta=np.arccos(np.sqrt(k / (i + k))) + 0.1 * settings.rng.normal(),
phi=settings.rng.normal(),
theta_trainable=True,
phi_trainable=True,
)[1, 2],
),
)
circ = Circuit([s2_0, s2_1, bs])
state_in = Vacuum(num_modes=4)
circ = Circuit([state_in, s2_0, s2_1, bs])
cutoff = 1 + i + k

free_var = math.new_variable([1.1, -0.2], None, "free_var")

def cost_fn():
return tf.abs(
(state_in >> circ).ket(cutoffs=[cutoff] * 4)[i, 1, i + k - 1, k]
circ.contract().fock_array((cutoff,) * 4)[i, 1, i + k - 1, k]
) ** 2 + tf.reduce_sum(free_var**2)

tbcb = TensorboardCallback(
Expand All @@ -64,7 +63,7 @@ def cost_fn():
opt = Optimizer(euclidean_lr=0.01)
opt.minimize(cost_fn, by_optimizing=[circ, free_var], max_steps=300, callbacks={"tb": tbcb})

assert np.allclose(np.cos(bs.theta.value) ** 2, k / (i + k), atol=1e-2)
assert np.allclose(np.cos(bs.parameters.theta.value) ** 2, k / (i + k), atol=1e-2)
assert tbcb.logdir.exists()
assert len(list(tbcb.writter_logdir.glob("events*"))) > 0
assert len(opt.callback_history["tb"]) == (len(opt.opt_history) - 1) // tbcb.steps_per_call
File renamed without changes.
Loading

0 comments on commit 5bf10ce

Please sign in to comment.