Add exact spectral norm feature #831

Open · wants to merge 6 commits into main
116 changes: 90 additions & 26 deletions equinox/nn/_spectral_norm.py
@@ -6,18 +6,18 @@
import jax.random as jr
from jaxtyping import Array, Float, PRNGKeyArray

from .._eval_shape import filter_eval_shape
from .._module import field
from .._tree import tree_at
from ._sequential import StatefulLayer
from ._stateful import State, StateIndex


def _power_iteration(weight, u, v, eps):
u = weight @ v
u_norm = jnp.sqrt(jnp.sum(u**2))
u = u / jnp.maximum(eps, u_norm)

v = weight.T @ u
def _power_iteration(forward, transpose, v_prev, eps):
_, tangents_out = jax.jvp(forward, (v_prev,), (v_prev,))
u_norm = jnp.sqrt(jnp.sum(tangents_out**2))
u = tangents_out / jnp.maximum(eps, u_norm)
_, v = jax.jvp(lambda x: transpose(x)[0], (u,), (u,))
v_norm = jnp.sqrt(jnp.sum(v**2))
v = v / jnp.maximum(eps, v_norm)

@@ -42,6 +42,12 @@ class SpectralNorm(StatefulLayer, Generic[_Layer], strict=True):
[Spectral Normalization for Generative Adversarial Networks](https://arxiv.org/abs/1802.05957)
for more details and motivation.

Standard approaches to spectral normalization rely on an inexact approximation to the
spectral norm, although this approximation often performs better in practice; see
[Why Spectral Normalization Stabilizes GANs: Analysis and Improvements](https://arxiv.org/abs/2009.02773)
and [Generalizable Adversarial Training via Spectral Normalization](https://arxiv.org/abs/1811.07457).
Equinox offers both the exact and the approximate spectral norm.

!!! example

See [this example](../../examples/stateful.ipynb) for example usage.
@@ -53,6 +59,7 @@ class SpectralNorm(StatefulLayer, Generic[_Layer], strict=True):
""" # noqa: E501

layer: _Layer
exact: bool
weight_name: str = field(static=True)
uv_index: StateIndex[tuple[Float[Array, " u_size"], Float[Array, " v_size"]]]
num_power_iterations: int = field(static=True)
@@ -66,6 +73,8 @@ def __init__(
num_power_iterations: int = 1,
eps: float = 1e-12,
inference: bool = False,
exact: bool = False,
input_shape: Optional[jax.ShapeDtypeStruct] = None,
*,
key: PRNGKeyArray,
):
@@ -81,6 +90,11 @@ def __init__(
- `inference`: Whether this is in inference mode, at which time no power
iterations are performed. This may be toggled with
[`equinox.nn.inference_mode`][].
- `exact`: Whether to compute the exact spectral norm, by running power iteration
against the layer's true linear transpose. The traditional approach instead reshapes
a >2D weight into a 2D matrix, rather than transposing the >2D linear operator itself.
- `input_shape`: If `exact` is true, the structure of the layer's input must be
specified as a `jax.ShapeDtypeStruct`.
- `key`: A `jax.random.PRNGKey` used to provide randomness for initialisation.
(Keyword only argument.)

@@ -90,6 +104,13 @@ def __init__(
The `dtype` of the weight array of the `layer` input is applied to all
parameters in this layer.


!!! Caution

If `exact` is true, the transpose is computed via `jax.linear_transpose` of the
layer. This traces every operation in the layer's call, so for layers with a
bias the result can be an incorrect spectral value.
Comment on lines +108 to +112
Owner


Maybe we should follow JAX's lead here and transpose the tangent pass of jax.jvp?

Contributor (Author)


Hmmm, I tried to implement what I thought you meant. This also means we could remove the "weight" flag for the exact case (maybe?) since we basically "determine" the weight through the jvp?

Let me know if this is what you had in mind, or if I was totally off. Does seem like a lot of jvps tho.
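(Editorial aside, not part of this thread or diff: a minimal standalone sketch of the matrix-free power iteration being discussed, using `jax.jvp` for the forward product and `jax.linear_transpose` for the adjoint. The dense matrix `A` and the helper names are purely illustrative.)

```python
import jax
import jax.numpy as jnp

# Illustrative only: a dense matrix standing in for an arbitrary linear layer.
A = jax.random.normal(jax.random.PRNGKey(0), (6, 5))

def forward(v):
    return A @ v

# Adjoint of `forward`, built from its shape signature rather than from A directly.
transpose = jax.linear_transpose(forward, jax.ShapeDtypeStruct((5,), A.dtype))

def power_iteration(v, eps=1e-12):
    # The tangent output of jax.jvp is a Jacobian-vector product, so it only
    # sees the linear part of the map (a bias term would not show up here).
    _, u = jax.jvp(forward, (v,), (v,))
    u = u / jnp.maximum(eps, jnp.linalg.norm(u))
    (v,) = transpose(u)  # v = A.T @ u
    v = v / jnp.maximum(eps, jnp.linalg.norm(v))
    return u, v

v = jax.random.normal(jax.random.PRNGKey(1), (5,))
for _ in range(100):
    u, v = power_iteration(v)

_, Av = jax.jvp(forward, (v,), (v,))
sigma = jnp.sum(u * Av)
print(sigma, jnp.linalg.svd(A, compute_uv=False)[0])  # both ≈ the largest singular value
```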


"""
self.layer = layer
self.weight_name = weight_name
@@ -98,17 +119,36 @@ def __init__(
self.inference = inference

weight = getattr(layer, weight_name)
if weight.ndim < 2:
raise ValueError("`weight` must be at least two-dimensional")
weight = jnp.reshape(weight, (weight.shape[0], -1))
dtype = weight.dtype
u_len, v_len = weight.shape
ukey, vkey = jr.split(key)
u0 = jr.normal(ukey, (u_len,), dtype=dtype)
v0 = jr.normal(vkey, (v_len,), dtype=dtype)
for _ in range(15):
u0, v0 = _power_iteration(weight, u0, v0, eps)

if not callable(self.layer):
raise ValueError("`layer` must be callable.")

if exact:
if input_shape is None:
raise ValueError(
"Must specify `input_shape` to use exact spectral norm!"
)
u_shape = filter_eval_shape(self.layer, input_shape)
u0 = jr.normal(ukey, u_shape.shape, dtype=u_shape.dtype)
v0 = jr.normal(vkey, input_shape.shape, dtype=input_shape.dtype)
reverse = jax.linear_transpose(self.layer, input_shape)
for _ in range(15):
u0, v0 = _power_iteration(self.layer, reverse, v0, self.eps)
else:
if weight.ndim < 2:
raise ValueError("`weight` must be at least two-dimensional")
weight = jnp.reshape(weight, (weight.shape[0], -1))
dtype = weight.dtype
u_len, v_len = weight.shape
u0 = jr.normal(ukey, (u_len,), dtype=dtype)
v0 = jr.normal(vkey, (v_len,), dtype=dtype)
for _ in range(15):
u0, v0 = _power_iteration(
lambda y: weight @ y, lambda z: (weight.T @ z,), v0, self.eps
)
self.uv_index = StateIndex((u0, v0))
self.exact = exact

@jax.named_scope("eqx.nn.SpectralNorm")
def __call__(
@@ -141,17 +181,41 @@ def __call__(

u, v = state.get(self.uv_index)
weight = getattr(self.layer, self.weight_name)
weight_shape = weight.shape
weight = jnp.reshape(weight, (weight.shape[0], -1))
if inference is None:
inference = self.inference
if not inference:
stop_weight = lax.stop_gradient(weight)
for _ in range(self.num_power_iterations):
u, v = _power_iteration(stop_weight, u, v, self.eps)
state = state.set(self.uv_index, (u, v))
σ = jnp.einsum("i,ij,j->", u, weight, v)
σ_weight = jnp.reshape(weight / σ, weight_shape)
if self.exact:
if inference is None:
inference = self.inference
if not inference:
stop_weight = lax.stop_gradient(weight)
layer = tree_at(
lambda l: getattr(l, self.weight_name), self.layer, stop_weight
)
reverse = jax.linear_transpose(layer, x)
for _ in range(self.num_power_iterations):
u, v = _power_iteration(layer, reverse, v, self.eps)
state = state.set(self.uv_index, (u, v))
else:
layer = self.layer
assert callable(layer) # checked in __init__ but pyright wants it here too
_, tangents_out = jax.jvp(layer, (v,), (v,))
σ = jnp.sum(u * tangents_out)
σ_weight = weight / σ
else:
weight_shape = weight.shape
weight = jnp.reshape(weight, (weight.shape[0], -1))
if inference is None:
inference = self.inference
if not inference:
stop_weight = lax.stop_gradient(weight)
for _ in range(self.num_power_iterations):
u, v = _power_iteration(
lambda y: stop_weight @ y,
lambda z: (stop_weight.T @ z,),
v,
self.eps,
)
state = state.set(self.uv_index, (u, v))
σ = jnp.einsum("i,ij,j->", u, weight, v)
σ_weight = jnp.reshape(weight / σ, weight_shape)
layer = tree_at(lambda l: getattr(l, self.weight_name), self.layer, σ_weight)
out = layer(x)
return out, state
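Aside (not part of the diff): a rough sketch of why the reshape-based approximation and the exact operator norm can disagree for convolutions. It reuses the same jvp/linear_transpose power iteration as above on a small `eqx.nn.Conv2d`; the channel counts, input size, and iteration count are illustrative assumptions.

```python
import equinox as eqx
import jax
import jax.numpy as jnp

# Illustrative shapes: a small bias-free convolution on (3, 16, 16) inputs.
conv = eqx.nn.Conv2d(3, 4, 3, use_bias=False, key=jax.random.PRNGKey(0))

# Approximate spectral norm: top singular value of the kernel reshaped to 2D
# (this is what the reshape-based branch normalises by).
kernel_2d = jnp.reshape(conv.weight, (conv.weight.shape[0], -1))
approx = jnp.linalg.svd(kernel_2d, compute_uv=False)[0]

# Exact spectral norm of the convolution as a linear operator, via the same
# jvp/linear_transpose power iteration.
x_struct = jax.ShapeDtypeStruct((3, 16, 16), conv.weight.dtype)
transpose = jax.linear_transpose(conv, x_struct)
v = jax.random.normal(jax.random.PRNGKey(1), x_struct.shape, x_struct.dtype)
for _ in range(100):
    _, u = jax.jvp(conv, (v,), (v,))   # u = conv(v), since conv is linear
    u = u / jnp.linalg.norm(u)
    (v,) = transpose(u)                # v = conv^T(u)
    v = v / jnp.linalg.norm(v)
_, conv_v = jax.jvp(conv, (v,), (v,))
exact = jnp.sum(u * conv_v)

print(approx, exact)  # the two typically differ
```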
58 changes: 57 additions & 1 deletion tests/test_nn.py
@@ -1036,7 +1036,7 @@ def λ1():
eqx.nn.Linear(5, 6, key=getkey()), "weight", key=getkey()
)
state = eqx.nn.State(spectral)
for _ in range(100):
for _ in range(200):
_, state = spectral(x, state)
assert jnp.allclose(λ1(), 1)

@@ -1069,6 +1069,62 @@ def λ1():
assert out.shape == (4, 6, 6, 6)


def test_spectral_norm_exact(getkey):
def λ1():
u, v = state.get(spectral.uv_index)
_, tangents_out = jax.jvp(spectral.layer, (v,), (v,))
σ = jnp.sum(u * tangents_out)
_, s, _ = jnp.linalg.svd(spectral.layer.weight / σ) # pyright: ignore
return s[0]

x = jrandom.normal(getkey(), (5,))
spectral = eqx.nn.SpectralNorm(
eqx.nn.Linear(5, 6, key=getkey(), use_bias=True),
"weight",
exact=True,
input_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
key=getkey(),
)
state = eqx.nn.State(spectral)
for _ in range(200):
_, state = spectral(x, state)
assert jnp.allclose(λ1(), 1)

# "gradient descent"
spectral = eqx.tree_at(
lambda s: s.layer.weight, spectral, spectral.layer.weight + 1
)
assert not jnp.allclose(λ1(), 1)
for _ in range(100):
_, state = spectral(x, state)
assert jnp.allclose(λ1(), 1)

# Test not updated at inference time
spectral = eqx.tree_at(
lambda s: s.layer.weight, spectral, spectral.layer.weight + 1
)
spectral = eqx.nn.inference_mode(spectral, value=True)
assert not jnp.allclose(λ1(), 1)
for _ in range(100):
_, state = spectral(x, state)
assert not jnp.allclose(λ1(), 1)

# Test >2 dimensional input

x = jrandom.normal(getkey(), (5, 8, 8, 8))
conv = eqx.nn.Conv3d(5, 4, 3, key=getkey(), use_bias=False)
spectral = eqx.nn.SpectralNorm(
conv,
"weight",
exact=True,
input_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
key=getkey(),
)
state = eqx.nn.State(spectral)
out, _ = spectral(x, state)
assert out.shape == (4, 6, 6, 6)
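
One possible cross-check for the exact path, sketched here rather than taken from the PR: for small inputs, the convolution's full Jacobian can be materialised with `jax.jacfwd` and its top singular value compared against the power-iteration estimate. `conv` and `x` are assumed to be the bias-free `Conv3d` and input from the test above.

```python
import jax
import jax.numpy as jnp

# Materialise the full Jacobian of the (linear, bias-free) conv at x.
jac = jax.jacfwd(conv)(x)                       # shape: (*output_shape, *input_shape)
mat = jnp.reshape(jac, (conv(x).size, x.size))  # flatten to a 2D operator
true_sigma = jnp.linalg.svd(mat, compute_uv=False)[0]
# `true_sigma` should roughly match the σ estimated by the exact power iteration.
```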


def test_weight_norm(getkey):
# Linear
linear = eqx.nn.Linear(4, 4, key=getkey())