diff --git a/experiments/mnist/mnist_classifier_from_scratch.py b/experiments/mnist/mnist_classifier_from_scratch.py
index 5381951..ccf4268 100644
--- a/experiments/mnist/mnist_classifier_from_scratch.py
+++ b/experiments/mnist/mnist_classifier_from_scratch.py
@@ -104,7 +104,7 @@ def data_stream():
     params = init_random_params(param_scale, layer_sizes)
     # Transform parameters to `ScaledArray` and proper dtype.
     params = jsa.as_scaled_array(params, scale=scale_dtype(param_scale))
-    params = jax.tree_map(lambda v: v.astype(training_dtype), params, is_leaf=jsa.core.is_scaled_leaf)
+    params = jax.tree_util.tree_map(lambda v: v.astype(training_dtype), params, is_leaf=jsa.core.is_scaled_leaf)

     @jit
     @jsa.autoscale
@@ -118,7 +118,7 @@ def update(params, batch):
             batch = next(batches)
             # Scaled micro-batch + training dtype cast.
             batch = jsa.as_scaled_array(batch, scale=scale_dtype(1))
-            batch = jax.tree_map(lambda v: v.astype(training_dtype), batch, is_leaf=jsa.core.is_scaled_leaf)
+            batch = jax.tree_util.tree_map(lambda v: v.astype(training_dtype), batch, is_leaf=jsa.core.is_scaled_leaf)

             with jsa.AutoScaleConfig(rounding_mode=jsa.Pow2RoundMode.DOWN, scale_dtype=scale_dtype):
                 params = update(params, batch)
diff --git a/experiments/mnist/mnist_classifier_from_scratch_fp8.py b/experiments/mnist/mnist_classifier_from_scratch_fp8.py
index ff0c327..4a84055 100644
--- a/experiments/mnist/mnist_classifier_from_scratch_fp8.py
+++ b/experiments/mnist/mnist_classifier_from_scratch_fp8.py
@@ -131,7 +131,7 @@ def data_stream():
     params = init_random_params(param_scale, layer_sizes)
     # Transform parameters to `ScaledArray` and proper dtype.
     params = jsa.as_scaled_array(params, scale=scale_dtype(param_scale))
-    params = jax.tree_map(lambda v: v.astype(training_dtype), params, is_leaf=jsa.core.is_scaled_leaf)
+    params = jax.tree_util.tree_map(lambda v: v.astype(training_dtype), params, is_leaf=jsa.core.is_scaled_leaf)

     @jit
     @jsa.autoscale
@@ -145,7 +145,7 @@ def update(params, batch):
             batch = next(batches)
             # Scaled micro-batch + training dtype cast.
             batch = jsa.as_scaled_array(batch, scale=scale_dtype(1))
-            batch = jax.tree_map(lambda v: v.astype(training_dtype), batch, is_leaf=jsa.core.is_scaled_leaf)
+            batch = jax.tree_util.tree_map(lambda v: v.astype(training_dtype), batch, is_leaf=jsa.core.is_scaled_leaf)

             with jsa.AutoScaleConfig(rounding_mode=jsa.Pow2RoundMode.DOWN, scale_dtype=scale_dtype):
                 params = update(params, batch)
diff --git a/experiments/mnist/optax_cifar_training.py b/experiments/mnist/optax_cifar_training.py
index dab0e59..eedb8c7 100644
--- a/experiments/mnist/optax_cifar_training.py
+++ b/experiments/mnist/optax_cifar_training.py
@@ -118,7 +118,7 @@ def data_stream():
     batches = data_stream()

     params = init_random_params(param_scale, layer_sizes)
-    params = jax.tree_map(lambda v: v.astype(training_dtype), params)
+    params = jax.tree_util.tree_map(lambda v: v.astype(training_dtype), params)
     # Transform parameters to `ScaledArray` and proper dtype.
     optimizer = optax.adam(learning_rate=lr, eps=1e-5)
     opt_state = optimizer.init(params)
@@ -126,7 +126,7 @@ def data_stream():

     if use_autoscale:
         params = jsa.as_scaled_array(params, scale=scale_dtype(param_scale))
-        params = jax.tree_map(lambda v: v.astype(training_dtype), params, is_leaf=jsa.core.is_scaled_leaf)
+        params = jax.tree_util.tree_map(lambda v: v.astype(training_dtype), params, is_leaf=jsa.core.is_scaled_leaf)

     @jit
     @autoscale
@@ -143,7 +143,7 @@ def update(params, batch, opt_state):
             # Scaled micro-batch + training dtype cast.
             if use_autoscale:
                 batch = jsa.as_scaled_array(batch, scale=scale_dtype(param_scale))
-                batch = jax.tree_map(lambda v: v.astype(training_dtype), batch, is_leaf=jsa.core.is_scaled_leaf)
+                batch = jax.tree_util.tree_map(lambda v: v.astype(training_dtype), batch, is_leaf=jsa.core.is_scaled_leaf)

             with jsa.AutoScaleConfig(rounding_mode=jsa.Pow2RoundMode.DOWN, scale_dtype=scale_dtype):
                 params, opt_state = update(params, batch, opt_state)
diff --git a/jax_scaled_arithmetics/core/datatype.py b/jax_scaled_arithmetics/core/datatype.py
index bd2e056..7dd3862 100644
--- a/jax_scaled_arithmetics/core/datatype.py
+++ b/jax_scaled_arithmetics/core/datatype.py
@@ -216,7 +216,7 @@ def as_scaled_array(val: Any, scale: Optional[ArrayLike] = None) -> ScaledArray:
     Returns:
         Scaled array instance.
     """
-    return jax.tree_map(lambda x: as_scaled_array_base(x, scale), val, is_leaf=is_scaled_leaf)
+    return jax.tree_util.tree_map(lambda x: as_scaled_array_base(x, scale), val, is_leaf=is_scaled_leaf)


 def asarray_base(val: Any, dtype: DTypeLike = None) -> GenericArray:
@@ -239,7 +239,7 @@ def asarray(val: Any, dtype: DTypeLike = None) -> GenericArray:
     Args:
         dtype: Optional dtype of the final array.
     """
-    return jax.tree_map(lambda x: asarray_base(x, dtype), val, is_leaf=is_scaled_leaf)
+    return jax.tree_util.tree_map(lambda x: asarray_base(x, dtype), val, is_leaf=is_scaled_leaf)


 def is_numpy_scalar_or_array(val):
diff --git a/jax_scaled_arithmetics/core/interpreters.py b/jax_scaled_arithmetics/core/interpreters.py
index acf15ce..4fc1050 100644
--- a/jax_scaled_arithmetics/core/interpreters.py
+++ b/jax_scaled_arithmetics/core/interpreters.py
@@ -313,7 +313,7 @@ def wrapped(*args, **kwargs):
         if len(kwargs) > 0:
             raise NotImplementedError("`autoscale` JAX interpreter not supporting named tensors at present.")

-        aval_args = jax.tree_map(_get_aval, args, is_leaf=is_scaled_leaf)
+        aval_args = jax.tree_util.tree_map(_get_aval, args, is_leaf=is_scaled_leaf)
         # Get jaxpr of unscaled/normal graph. Getting output Pytree shape as well.
         closed_jaxpr, outshape = jax.make_jaxpr(fun, return_shape=True)(*aval_args, **kwargs)
         out_leaves, out_pytree = jax.tree_util.tree_flatten(outshape)
diff --git a/tests/core/test_interpreter.py b/tests/core/test_interpreter.py
index abf9006..e83dda1 100644
--- a/tests/core/test_interpreter.py
+++ b/tests/core/test_interpreter.py
@@ -242,7 +242,7 @@ def test__autoscale_decorator__proper_graph_transformation_and_result(self, fn,
         scaled_fn = self.variant(autoscale(fn))
         scaled_output = scaled_fn(*inputs)
         # Normal JAX path, without scaled arrays.
-        raw_inputs = jax.tree_map(np.asarray, inputs, is_leaf=is_scaled_leaf)
+        raw_inputs = jax.tree_util.tree_map(np.asarray, inputs, is_leaf=is_scaled_leaf)
         expected_output = self.variant(fn)(*raw_inputs)

         # Do we re-construct properly the output type (i.e. handling Pytree properly)?
diff --git a/tests/core/test_pow2.py b/tests/core/test_pow2.py
index c5a48c5..7a1fdb6 100644
--- a/tests/core/test_pow2.py
+++ b/tests/core/test_pow2.py
@@ -132,7 +132,6 @@ def test__get_mantissa__proper_value__multi_dtypes(self, val_mant, dtype):
         assert val_mant.dtype == val.dtype
         assert val_mant.shape == ()
         assert type(val_mant) in {type(val), np.ndarray}
-        print(mant, val_mant, dtype)
         npt.assert_equal(val_mant, mant)
         # Should be consistent with `pow2_round_down`. bitwise, not approximation.
         npt.assert_equal(mant * pow2_round_down(val), val)
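Note on the change above: `jax.tree_map` is a deprecated alias of `jax.tree_util.tree_map`; both accept the same `is_leaf` predicate, so the call sites only need the module path updated. A minimal standalone sketch (not part of the patch; the `params` dict and float16 cast are made up for illustration, and the `is_leaf=jsa.core.is_scaled_leaf` argument used throughout the patch is omitted here):

    import jax
    import numpy as np

    # Toy parameter pytree for the example.
    params = {"w": np.ones((2, 2), dtype=np.float32), "b": np.zeros((2,), dtype=np.float32)}

    # Deprecated spelling: params = jax.tree_map(lambda v: v.astype(np.float16), params)
    # Current spelling, same signature and same result:
    params = jax.tree_util.tree_map(lambda v: v.astype(np.float16), params)
    print({k: v.dtype for k, v in params.items()})  # {'w': dtype('float16'), 'b': dtype('float16')}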