chore(deps): update dependency jax to >=0.4.38, <=0.4.38 #940
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
This PR contains the following updates:
>=0.4.16, <=0.4.34
->>=0.4.38, <=0.4.38
Release Notes
jax-ml/jax (jax)
v0.4.38
Compare Source
Changes:
jax.tree.flatten_with_path
andjax.tree.map_with_path
are addedas shortcuts of the corresponding
tree_util
functions.Deprecations
jax.core
namespace have been deprecated.Most were no-ops, were little-used, or can be replaced by APIs of the same
name in {mod}
jax.extend.core
; see the documentation for {mod}jax.extend
for information on the compatibility guarantees of these semi-public extensions.
jax.core
:check_eqn
,check_type
,check_valid_jaxtype
, andnon_negative_dim
.jax.lib.xla_bridge
:xla_client
anddefault_backend
.jax.lib.xla_client
:_xla
andbfloat16
.jax.numpy
:round_
.New Features
jax.export.export
can be used for device-polymorphic export withshardings constructed with {func}
jax.sharding.AbstractMesh
.See the jax.export documentation.
jax.lax.split
. This is a primitive version of{func}
jax.numpy.split
, added because it yields a more compacttranspose during automatic differentiation.
v0.4.37
Compare Source
This is a patch release of jax 0.4.36. Only "jax" was released at this version.
jit
would error if an argument was namedf
(#25329).index out of range
error in{func}
jax.lax.while_loop
if the user register pytree node class withdifferent aux data for the flatten and flatten_with_path.
v0.4.36
Compare Source
Breaking Changes
This release lands "stackless", an internal change to JAX's tracing
machinery. We made trace dispatch purely a function of context rather than a
function of both context and data. This let us delete a lot of machinery for
managing data-dependent tracing: levels, sublevels,
post_process_call
,new_base_main
,custom_bind
, and so on. The change should only affectusers that use JAX internals.
If you do use JAX internals then you may need to
update your code (see
jax-ml/jax@c36e1f7
for clues about how to do this). There might also be version skew
issues with JAX libraries that do this. If you find this change breaks your
non-JAX-internals-using code then try the
config.jax_data_dependent_tracing_fallback
flag as a workaround, and ifyou need help updating your code then please file a bug.
{func}
jax.experimental.jax2tf.convert
withnative_serialization=False
or with
enable_xla=False
have been deprecated since July 2024, withJAX version 0.4.31. Now we removed support for these use cases.
jax2tf
with native serialization will still be supported.
In
jax.interpreters.xla
, thexb
,xc
, andxe
symbols have been removedafter being deprecated in JAX v0.4.31. Instead use
xb = jax.lib.xla_bridge
,xc = jax.lib.xla_client
, andxe = jax.lib.xla_extension
.The deprecated module
jax.experimental.export
has been removed. It was replacedby {mod}
jax.export
in JAX v0.4.30. See the migration guidefor information on migrating to the new API.
The
initial
argument to {func}jax.nn.softmax
and {func}jax.nn.log_softmax
has been removed, after being deprecated in v0.4.27.
Calling
np.asarray
on typed PRNG keys (i.e. keys produced by :func:jax.random.key
)now raises an error. Previously, this returned a scalar object array.
The following deprecated methods and functions in {mod}
jax.export
havebeen removed:
jax.export.DisabledSafetyCheck.shape_assertions
: it had no effectalready.
jax.export.Exported.lowering_platforms
: useplatforms
.jax.export.Exported.mlir_module_serialization_version
:use
calling_convention_version
.jax.export.Exported.uses_shape_polymorphism
:use
uses_global_constants
.lowering_platforms
kwarg for {func}jax.export.export
: useplatforms
instead.The kwargs
symbolic_scope
andsymbolic_constraints
from{func}
jax.export.symbolic_args_specs
have been removed. They weredeprecated in June 2024. Use
scope
andconstraints
instead.Hashing of tracers, which has been deprecated since version 0.4.30, now
results in a
TypeError
.Refactor: JAX build CLI (build/build.py) now uses a subcommand structure and
replaces previous build.py usage. Run
python build/build.py --help
formore details. Brief overview of the new subcommand options:
build
: Builds JAX wheel packages. For e.g.,python build/build.py build --wheels=jaxlib,jax-cuda-pjrt
requirements_update
: Updates requirements_lock.txt files.{func}
jax.scipy.linalg.toeplitz
now does implicit batching on multi-dimensionalinputs. To recover the previous behavior, you can call {func}
jax.numpy.ravel
on the function inputs.
{func}
jax.scipy.special.gamma
and {func}jax.scipy.special.gammasgn
nowreturn NaN for negative integer inputs, to match the behavior of Scihttps://github.com/scipy/scipy/pull/21827scipy/pull/21827.
jax.clear_backends
was removed after being deprecated in v0.4.26.We removed the custom call "__gpu$xla.gpu.triton" from the list of custom
call that we guarantee export stability. This is because this custom call
relies on Triton IR, which is not guaranteed to be stable. If you need
to export code that uses this custom call, you can use the
disabled_checks
parameter. See more details in the documentation.
New Features
jax.jit
got a newcompiler_options: dict[str, Any]
argument, forpassing compilation options to XLA. For the moment it's undocumented and
may be in flux.
jax.tree_util.register_dataclass
now allows metadata fields to bedeclared inline via {func}
dataclasses.field
. See the function documentationfor examples.
jax.numpy.put_along_axis
.jax.lax.linalg.eig
and the relatedjax.numpy
functions({func}
jax.numpy.linalg.eig
and {func}jax.numpy.linalg.eigvals
) are nowsupported on GPU. See {jax-issue}
#24663
for more details.jax_exec_time_optimization_effort
andjax_memory_fitting_effort
, to control the amount of effort the compiler spends minimizing execution time and memory usage, respectively. Valid values are between -1.0 and 1.0, default is 0.0.Bug fixes
result in an indexing overflow for batch sizes close to int32 max. See
{jax-issue}
#24843
for more details.Deprecations
jax.lib.xla_extension.ArrayImpl
andjax.lib.xla_client.ArrayImpl
are deprecated;use
jax.Array
instead.jax.lib.xla_extension.XlaRuntimeError
is deprecated; usejax.errors.JaxRuntimeError
instead.
v0.4.35
Compare Source
Breaking Changes
jax.numpy.isscalar
now returns True for any array-like object withzero dimensions. Previously it only returned True for zero-dimensional
array-like objects with a weak dtype.
jax.experimental.host_callback
has been deprecated since March 2024, withJAX version 0.4.26. Now we removed it.
See {jax-issue}
#20385
for a discussion of alternatives.Changes:
jax.lax.FftType
was introduced as a public name for the enum of FFToperations. The semi-public API
jax.lib.xla_client.FftType
has beendeprecated.
libtpu
package rather thanlibtpu-nightly
. For the next few releases JAX will pin an empty version oflibtpu-nightly
as well aslibtpu
to ease the transition; that dependencywill be removed in Q1 2025.
Deprecations:
jax.lib.xla_client.PaddingType
has been deprecated.No JAX APIs consume this type, so there is no replacement.
jax.pure_callback
and{func}
jax.extend.ffi.ffi_call
undervmap
has been deprecated and so hasthe
vectorized
parameter to those functions. Thevmap_method
parametershould be used instead for better defined behavior. See the discussion in
{jax-issue}
#23881
for more details.jax.lib.xla_client.register_custom_call_target
hasbeen deprecated. Use the JAX FFI instead.
jax.lib.xla_client.dtype_to_etype
,jax.lib.xla_client.ops
,jax.lib.xla_client.shape_from_pyval
,jax.lib.xla_client.PrimitiveType
,jax.lib.xla_client.Shape
,jax.lib.xla_client.XlaBuilder
, andjax.lib.xla_client.XlaComputation
have been deprecated. Use StableHLOinstead.
Configuration
📅 Schedule: Branch creation - At any time (no schedule defined), Automerge - At any time (no schedule defined).
🚦 Automerge: Disabled by config. Please merge this manually once you are satisfied.
♻ Rebasing: Whenever PR becomes conflicted, or you tick the rebase/retry checkbox.
🔕 Ignore: Close this PR and you won't be reminded about this update again.
This PR was generated by Mend Renovate. View the repository job log.