Bug in how eqx.Module interacts with jax.tree.map #889

Open
xaviergonzalez opened this issue Nov 1, 2024 · 2 comments
Labels
question User queries


@xaviergonzalez

xaviergonzalez commented Nov 1, 2024

There seems to be a bug in how eqx.Module interacts with jax.tree.map

Here is the repro:

import jax
import equinox as eqx
import jax.numpy as jnp
from flax import linen as nn

class Identity:

    def __init__(self):
        pass

    def __call__(self, x):
        return x

class IdentityEqx(eqx.Module):

    def __init__(self):
        pass

    def __call__(self, x):
        return x

class IdentityFlax(nn.Module):

    def __init__(self):
        pass

    def __call__(self, x):
        return x
       
# everything is fine without equinox
tst_fxns = [Identity(), Identity()]
tst_states = [jnp.zeros(2), jnp.zeros(2)]
print(jax.tree.map(lambda f,x : f(x), tst_fxns, tst_states))

# everything is fine in flax (inheritance is not a problem)
tst_fxns_flax = [IdentityFlax(), IdentityFlax()]
print(jax.tree.map(lambda f,x : f(x), tst_fxns_flax, tst_states))

# with equinox, this raises: Custom node type mismatch: expected type: <class '__main__.IdentityEqx'>, value: Array([0., 0.], dtype=float32).
tst_fxns_eqx = [IdentityEqx(), IdentityEqx()]
print(jax.tree.map(lambda f,x : f(x), tst_fxns_eqx, tst_states))

We would really like to be able to apply a different function to each element of a list or array of inputs, and we'd especially love to do so in Equinox. The behavior seems like an Equinox bug, because we don't have this problem when we inherit from Flax. Do you have any suggestions for a workaround? Do you know what the nature of this bug is and how it could be fixed?

@patrick-kidger
Owner

patrick-kidger commented Nov 1, 2024

Swap your final line for:

print(jax.tree.map(lambda x, f : f(x), tst_states, tst_fxns_eqx))

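This isn't a bug. Equinox modules are pytrees. jax.tree.map requires that the first tree (in your case tst_fxns_eqx, with structure [IdentityEqx(), IdentityEqx()]) be a prefix of all later trees (in this case tst_states, with structure [*, *]). In particular, note that your IdentityEqx class is an 'empty' pytree with no leaves, much like an empty list []. So wherever the first tree has an IdentityEqx node, the second tree has an array leaf instead, and JAX raises the type-mismatch error. With the arguments swapped, the first tree is [*, *], which is a prefix of [IdentityEqx(), IdentityEqx()], so each IdentityEqx() is passed whole to the function.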
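A minimal, Equinox-free sketch of the prefix rule described above. It assumes a hypothetical `EmptyModule` class, registered with `jax.tree_util.register_pytree_node_class` and flattening to zero children, as a stand-in for a field-less `eqx.Module`; `Identity` is the plain class from the repro. It compares leaf counts, shows the functions-first call failing, and shows two ways to make it work: the argument swap suggested above, and (not from this thread) the `is_leaf` argument of `jax.tree.map`, which keeps the original argument order.

```python
import jax
import jax.numpy as jnp


class Identity:
    # Not registered with JAX, so each instance is treated as a single leaf.
    def __call__(self, x):
        return x


# Hypothetical stand-in for a field-less eqx.Module: a registered pytree
# node with no children, i.e. an "empty" pytree much like [].
@jax.tree_util.register_pytree_node_class
class EmptyModule:
    def __call__(self, x):
        return x

    def tree_flatten(self):
        return (), None  # no children, no auxiliary data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls()


states = [jnp.zeros(2), jnp.zeros(2)]
plain_fns = [Identity(), Identity()]
empty_fns = [EmptyModule(), EmptyModule()]

# Plain objects are leaves; the registered nodes contribute no leaves at all.
print(len(jax.tree.leaves(plain_fns)))  # prints 2
print(len(jax.tree.leaves(empty_fns)))  # prints 0

# Functions first: each EmptyModule node is matched against an array leaf
# in `states`, so flattening `states` up to that structure fails.
error_message = None
try:
    jax.tree.map(lambda f, x: f(x), empty_fns, states)
except ValueError as err:
    error_message = str(err)

# States first: the structure is [*, *], so each whole EmptyModule subtree
# is passed to the lambda as `f`.
out_swapped = jax.tree.map(lambda x, f: f(x), states, empty_fns)

# Alternative: keep functions first, but stop recursing at EmptyModule
# instances so that they are treated as leaves of the first tree.
out_is_leaf = jax.tree.map(
    lambda f, x: f(x),
    empty_fns,
    states,
    is_leaf=lambda node: isinstance(node, EmptyModule),
)
```

With Equinox itself, the analogous predicate would presumably be `lambda node: isinstance(node, eqx.Module)`; that variant is not exercised in the sketch above.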

If this seems confusing, here's a simpler non-Equinox equivalent you can study, which also raises an error:

import jax

fn = lambda a, b: None
tree1 = [[], []]
tree2 = ['leaf1', 'leaf2']
jax.tree.map(fn, tree1, tree2)
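A sketch of the same list example with the error caught, followed by the call with the argument order swapped so that the tree with leaves comes first; `fn`, `tree1` and `tree2` are the names from the snippet above.

```python
import jax

fn = lambda a, b: None
tree1 = [[], []]             # two empty subtrees: no leaves at all
tree2 = ["leaf1", "leaf2"]   # strings are leaves, so the structure is [*, *]

# tree1 first: JAX expects a list where tree2 has the leaf "leaf1".
failed = False
try:
    jax.tree.map(fn, tree1, tree2)
except ValueError as err:
    failed = True
    print("ValueError:", err)

# tree2 first: [*, *] is a prefix of tree1, so each [] is passed whole as `b`.
result = jax.tree.map(fn, tree2, tree1)
print(result)  # prints [None, None]
```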

@patrick-kidger patrick-kidger added the question User queries label Nov 1, 2024
@xaviergonzalez
Author

Thank you so much for your very helpful and fast reply!
