Equinox v0.11.3
Features
- Added `equinox.nn.RMSNorm`.
- Added `equinox.nn.WeightNorm`.
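  Both are used like the existing normalisation layers. A minimal usage sketch (the exact argument names here are illustrative, not a documented guarantee):

  ```python
  import jax.random as jr
  import equinox as eqx

  norm = eqx.nn.RMSNorm(shape=(64,))  # normalises input to unit root-mean-square
  x = jr.normal(jr.PRNGKey(0), (64,))
  y = norm(x)

  # WeightNorm wraps an existing layer, reparameterising its weight.
  linear = eqx.nn.Linear(64, 64, key=jr.PRNGKey(1))
  wn_linear = eqx.nn.WeightNorm(linear)
  y2 = wn_linear(x)
  ```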
- `equinox.tree_deserialise_leaves` now treats `jax.ShapeDtypeStruct`s in the same way as arrays. This makes it possible to avoid instantiating the initial model parameters only to throw them away again, by using `equinox.filter_eval_shape`: (#259)

  ```python
  model = eqx.filter_eval_shape(Model, ...hyperparameters...)
  model = eqx.tree_deserialise_leaves(load_path, model)
  ```
Bugfixes
- `equinox.internal.noinline` no longer initialises the JAX backend on use.
- `equinox.filter_jit(...).lower(..., some_kwarg=...)` no longer crashes. (#625, #627)
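  For example, this pattern now lowers successfully (`scale` is an illustrative keyword argument):

  ```python
  import equinox as eqx
  import jax.numpy as jnp

  @eqx.filter_jit
  def f(x, *, scale=1.0):
      return x * scale

  # Passing a keyword argument to .lower() previously crashed.
  lowered = f.lower(jnp.arange(3.0), scale=2.0)
  ```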
- The state of `equinox.nn.BatchNorm` now uses the default floating-point dtype, rather than always using `float32`.
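  A sketch of what this means in practice (the default float dtype in JAX is controlled by the 64-bit flag):

  ```python
  import jax

  jax.config.update("jax_enable_x64", True)  # default float dtype becomes float64
  # BatchNorm state created after this point should be float64 rather than float32.
  ```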
- `equinox.nn.MultiheadAttention` should now perform the softmax in `float32` even when the input is of lower dtype. (This is important for numerical stability.)
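  This is the usual mixed-precision pattern, sketched here for illustration (the general technique, not Equinox's internal code):

  ```python
  import jax
  import jax.numpy as jnp

  def stable_softmax(logits):
      # Upcast to float32, take the softmax, then cast back to the input dtype.
      out = jax.nn.softmax(logits.astype(jnp.float32), axis=-1)
      return out.astype(logits.dtype)
  ```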
Refactor
- All the layers in `equinox.nn.{Linear, MLP, ...}` now standardise on accepting extra `**kwargs` and not calling `super().__init__`. The intention is that these layers be treated as final, i.e. not subclassable. (Previously things were inconsistent: some did this and some did not.)
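  In practice this means composing the built-in layers rather than subclassing them. A minimal sketch (`ScaledLinear` is an illustrative name, not part of the library):

  ```python
  import equinox as eqx
  import jax.random as jr

  class ScaledLinear(eqx.Module):
      linear: eqx.nn.Linear
      scale: float

      def __init__(self, in_features, out_features, scale, *, key):
          # Wrap the final eqx.nn.Linear instead of subclassing it.
          self.linear = eqx.nn.Linear(in_features, out_features, key=key)
          self.scale = scale

      def __call__(self, x):
          return self.scale * self.linear(x)

  model = ScaledLinear(8, 4, scale=2.0, key=jr.PRNGKey(0))
  ```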
- Should now be compatible with `JAX_NUMPY_DTYPE_PROMOTION=strict` and `JAX_NUMPY_RANK_PROMOTION=raise`, and this is checked in tests.
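  These modes can be enabled as environment variables or from Python:

  ```python
  import jax

  # Equivalent to JAX_NUMPY_DTYPE_PROMOTION=strict and JAX_NUMPY_RANK_PROMOTION=raise.
  jax.config.update("jax_numpy_dtype_promotion", "strict")
  jax.config.update("jax_numpy_rank_promotion", "raise")
  ```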
- Better error message when no kwargs are passed to `filter_grad`. (Thanks @knyazer! #589)
Internal features
These are undocumented internal features, which may be changed at any time.

- Added `EQX_GETKEY_SEED` for use with `equinox.internal.GetKey`.
- `equinox.internal.while_loop` now has its runtime errors removed. This should help with compatibility with TPUs. (#628)
New Contributors
- @haydn-jones made their first contribution in #608
Full Changelog: v0.11.2...v0.11.3