There seems to be a bug in how eqx.Module interacts with jax.tree.map
Here is the repro:
```python
import jax
import equinox as eqx
import jax.numpy as jnp
from flax import linen as nn

class Identity:
    def __init__(self):
        pass

    def __call__(self, x):
        return x

class IdentityEqx(eqx.Module):
    def __init__(self):
        pass

    def __call__(self, x):
        return x

class IdentityFlax(nn.Module):
    def __init__(self):
        pass

    def __call__(self, x):
        return x

# everything is fine without equinox
tst_fxns = [Identity(), Identity()]
tst_states = [jnp.zeros(2), jnp.zeros(2)]
print(jax.tree.map(lambda f, x: f(x), tst_fxns, tst_states))

# everything is fine in flax (inheritance is not a problem)
tst_fxns_flax = [IdentityFlax(), IdentityFlax()]
print(jax.tree.map(lambda f, x: f(x), tst_fxns_flax, tst_states))

# with equinox, this raises: Custom node type mismatch: expected type: <class '__main__.IdentityEqx'>, value: Array([0., 0.], dtype=float32).
tst_fxns_eqx = [IdentityEqx(), IdentityEqx()]
print(jax.tree.map(lambda f, x: f(x), tst_fxns_eqx, tst_states))
```
We would really like to be able to apply different functions across a list or array of inputs. We'd love to do so in equinox especially. The behavior seems like an equinox bug because we don't have this problem when we inherit from flax. Do you have any suggestions for a workaround? Do you know what the nature of this bug is and how it could be fixed?
```python
print(jax.tree.map(lambda x, f: f(x), tst_states, tst_fxns_eqx))
```
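This isn't a bug. Equinox modules are pytrees. jax.tree.map requires that the first tree (in your case tst_fxns_eqx, with structure [IdentityEqx(), IdentityEqx()]) be a prefix of all later trees (in this case tst_states, with structure [*, *]). In particular, note that your IdentityEqx class is an 'empty' pytree (it has no fields, hence no leaves), much like an empty list []. Passing tst_states first instead works, because its structure [*, *] is a prefix of [IdentityEqx(), IdentityEqx()]: each whole module is handed to the function as a single argument.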
If this seems confusing, here's a simpler, non-Equinox equivalent that also raises an error:
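A minimal sketch of such an equivalent, assuming empty lists as stand-ins for the field-less `IdentityEqx` modules (both are pytrees with zero leaves). This may not be the exact snippet that was originally posted. For contrast, it also maps over a plain, unregistered Python class (like `Identity` in the repro), which JAX treats as a single leaf. As far as I know, Flax linen modules are not registered as pytrees either, which would explain why the Flax version of the repro runs.

```python
import jax
import jax.numpy as jnp


class Identity:
    # A plain (unregistered) class: JAX treats each instance as a single leaf.
    def __call__(self, x):
        return x


states = [jnp.zeros(2), jnp.zeros(2)]

# Empty lists stand in for the field-less IdentityEqx modules: zero leaves each.
empty_fxns = [[], []]
try:
    jax.tree.map(lambda f, x: x, empty_fxns, states)
    raised = False
except (ValueError, TypeError) as err:
    # Flattening `states` up to [[], []] fails: an array sits where a list is expected.
    raised = True
    print(type(err).__name__, err)

# Swapping the order works, because [*, *] is a prefix of [[], []].
swapped = jax.tree.map(lambda x, f: x, states, empty_fxns)

# Plain objects are leaves, so [Identity(), Identity()] has structure [*, *]
# and the original argument order already works.
plain = jax.tree.map(lambda f, x: f(x), [Identity(), Identity()], states)
print(len(swapped), len(plain))
```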