Equinox v0.11.3

github-actions released this 10 Jan 21:26

Features

  • Added equinox.nn.RMSNorm.
  • Added equinox.nn.WeightNorm.
  • equinox.tree_deserialise_leaves now treats jax.ShapeDtypeStructs in the same way as arrays. This makes it possible to avoid instantiating the initial model parameters only to throw them away again, by using equinox.filter_eval_shape:
    model = eqx.filter_eval_shape(Model, ...hyperparameters...)
    model = eqx.tree_deserialise_leaves(load_path, model)
    (#259)

Bugfixes

  • equinox.internal.noinline no longer initialises the JAX backend on use.
  • equinox.filter_jit(...).lower(..., some_kwarg=...) no longer crashes (#625, #627)
  • The state of equinox.nn.BatchNorm now uses the default floating-point dtype, rather than always using float32.
  • equinox.nn.MultiheadAttention should now perform the softmax in float32 even when the input has a lower-precision dtype. (This is important for numerical stability.)
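The MultiheadAttention fix follows a common mixed-precision pattern: upcast the attention logits to float32, take the softmax there, then cast the probabilities back to the input dtype. The sketch below shows that general technique in plain JAX; `softmax_in_float32` is a hypothetical helper, not Equinox's internal implementation.

```python
import jax
import jax.numpy as jnp


def softmax_in_float32(logits: jax.Array, axis: int = -1) -> jax.Array:
    # Remember the caller's dtype (e.g. float16 or bfloat16).
    orig_dtype = logits.dtype
    # Do the exponentials and the normalising sum in float32, where
    # rounding error is much smaller than in 16-bit formats.
    probs = jax.nn.softmax(logits.astype(jnp.float32), axis=axis)
    # Cast back so downstream computation stays in the input dtype.
    return probs.astype(orig_dtype)


logits = jnp.array([[1.0, 2.0, 3.0]], dtype=jnp.bfloat16)
probs = softmax_in_float32(logits)
```

Casting back at the end keeps the memory and throughput benefits of low precision for everything except the numerically sensitive reduction.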

Refactor

  • All the layers in equinox.nn.{Linear, MLP, ...} now consistently accept extra **kwargs and do not call super().__init__. The intention is that these layers be treated as final, i.e. not subclassed. (Previously this was inconsistent: some layers did this and some did not.)
  • Equinox should now be compatible with JAX_NUMPY_DTYPE_PROMOTION=strict and JAX_NUMPY_RANK_PROMOTION=raise, and this is checked in the tests.
  • Better error message when no kwargs are passed to filter_grad (thanks @knyazer! #589)
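The two JAX settings mentioned above reject implicit dtype promotion and implicit rank broadcasting. Besides the environment variables, JAX exposes them as the context managers `jax.numpy_dtype_promotion` and `jax.numpy_rank_promotion`. The sketch below shows the kind of code each mode rejects and an explicit alternative that it accepts; `check_strict_modes` is a hypothetical helper written for illustration, and it relies on both errors being subclasses of `ValueError`.

```python
import jax
import jax.numpy as jnp


def check_strict_modes():
    f32 = jnp.ones(3, dtype=jnp.float32)
    i32 = jnp.ones(3, dtype=jnp.int32)
    mat = jnp.ones((2, 3), dtype=jnp.float32)
    results = {}

    with jax.numpy_dtype_promotion("strict"):
        try:
            f32 + i32  # implicit int32 -> float32 promotion
            results["dtype"] = "allowed"
        except ValueError:  # TypePromotionError subclasses ValueError
            results["dtype"] = "raised"
        explicit_cast = f32 + i32.astype(jnp.float32)  # explicit cast is accepted

    with jax.numpy_rank_promotion("raise"):
        try:
            mat + f32  # implicit rank-1 -> rank-2 broadcast
            results["rank"] = "allowed"
        except ValueError:
            results["rank"] = "raised"
        explicit_reshape = mat + f32.reshape(1, 3)  # equal ranks broadcast fine

    return results, explicit_cast, explicit_reshape


results, explicit_cast, explicit_reshape = check_strict_modes()
```

Library code that works under both modes simply has to spell out every cast and reshape, which is what the compatibility check in the tests enforces.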

Internal features

These are undocumented internal features that may be changed at any time.

  • Added EQX_GETKEY_SEED for use with equinox.internal.GetKey.
  • equinox.internal.while_loop now has its runtime errors removed. This should improve compatibility with TPUs. (#628)


Full Changelog: v0.11.2...v0.11.3