Adding fwd/bwd cast methods compatible with FP8. (#120)
Allow casts to be applied only on the forward or the backward pass, respectively.
This makes it easier to build explicit FP8 code.
balancap authored Jun 28, 2024
1 parent c327a2f commit 1db8b33
Showing 8 changed files with 141 additions and 68 deletions.
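
As context for the diff below, a rough sketch of how the renamed ops are meant to compose in user code. It mirrors the MNIST example in this commit; the `fp8_dot` helper, the toy shapes, and the `import jax_scalify as jsa` alias are illustrative assumptions, not part of the change itself.

import jax
import jax.numpy as jnp
import ml_dtypes
import numpy as np

import jax_scalify as jsa  # assumed alias, as used in the examples below


def fp8_dot(lhs, rhs):
    # "Fake" FP8-E4M3 cast on the forward pass only; identity on the backward pass.
    lhs = jsa.ops.reduce_precision_on_forward(lhs, ml_dtypes.float8_e4m3fn)
    rhs = jsa.ops.reduce_precision_on_forward(rhs, ml_dtypes.float8_e4m3fn)
    out = jnp.dot(lhs, rhs)
    # "Fake" FP8-E5M2 cast of the incoming gradient on the backward pass only; identity on forward.
    return jsa.ops.reduce_precision_on_backward(out, ml_dtypes.float8_e5m2)


lhs = jnp.asarray(np.random.randn(16, 32), dtype=jnp.float16)
rhs = jnp.asarray(np.random.randn(32, 8), dtype=jnp.float16)
out = fp8_dot(lhs, rhs)  # forward values rounded to the FP8 grids
grads = jax.grad(lambda a: fp8_dot(a, rhs).sum())(lhs)  # cotangent of `out` rounded to E5M2 first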
16 changes: 8 additions & 8 deletions examples/mnist/mnist_classifier_from_scratch_fp8.py
@@ -59,30 +59,30 @@ def init_random_params(scale, layer_sizes, rng=npr.RandomState(0)):


 def predict(params, inputs, use_fp8=True):
-    reduce_precision_dtype = jsa.ops.reduce_precision_dtype if use_fp8 else lambda x, d: x
-    reduce_precision_dtype_grad = jsa.ops.reduce_precision_dtype_grad if use_fp8 else lambda x, d: x
+    reduce_precision_on_forward = jsa.ops.reduce_precision_on_forward if use_fp8 else lambda x, d: x
+    reduce_precision_on_backward = jsa.ops.reduce_precision_on_backward if use_fp8 else lambda x, d: x

     activations = inputs
     for w, b in params[:-1]:
         # Forward FP8 casting.
-        w = reduce_precision_dtype(w, ml_dtypes.float8_e4m3fn)
-        activations = reduce_precision_dtype(activations, ml_dtypes.float8_e4m3fn)
+        w = reduce_precision_on_forward(w, ml_dtypes.float8_e4m3fn)
+        activations = reduce_precision_on_forward(activations, ml_dtypes.float8_e4m3fn)
         # Matmul
         outputs = jnp.dot(activations, w)
         # Backward FP8 casting
-        outputs = reduce_precision_dtype_grad(outputs, ml_dtypes.float8_e5m2)
+        outputs = reduce_precision_on_backward(outputs, ml_dtypes.float8_e5m2)

         # Bias + relu
         outputs = outputs + b
         activations = jnp.maximum(outputs, 0)

     final_w, final_b = params[-1]
     # Forward FP8 casting.
-    # final_w = jsa.ops.reduce_precision_dtype(final_w, ml_dtypes.float8_e4m3fn)
-    activations = reduce_precision_dtype(activations, ml_dtypes.float8_e4m3fn)
+    # final_w = jsa.ops.reduce_precision_on_forward(final_w, ml_dtypes.float8_e4m3fn)
+    activations = reduce_precision_on_forward(activations, ml_dtypes.float8_e4m3fn)
     logits = jnp.dot(activations, final_w)
     # Backward FP8 casting
-    logits = reduce_precision_dtype_grad(logits, ml_dtypes.float8_e5m2)
+    logits = reduce_precision_on_backward(logits, ml_dtypes.float8_e5m2)

     logits = logits + final_b

4 changes: 2 additions & 2 deletions examples/scalify-quickstart.ipynb
@@ -279,12 +279,12 @@
 "source": [
 "import ml_dtypes\n",
 "# Minimal FP8 simulated support is provided using jax.lax.reduce_precision and ml_dtypes.\n",
-"# Similarly to `dynamic_rescale`, `reduce_precision_dtype(_grad)` are available to cast in forward and backward passes\n",
+"# Similarly to `dynamic_rescale`, `reduce_precision_on_forward(_grad)` are available to cast in forward and backward passes\n",
 "sc = jsa.as_scaled_array(np.array([17., 19.]), scale=np.float32(2))\n",
 "\n",
 "@jsa.scalify\n",
 "def cast_fn(v):\n",
-"    return jsa.ops.reduce_precision_dtype(v, ml_dtypes.float8_e4m3fn)\n",
+"    return jsa.ops.reduce_precision_on_forward(v, ml_dtypes.float8_e4m3fn)\n",
 "\n",
 "sc_fp8 = cast_fn(sc)\n",
 "print(\"Scaled input in FP32:\", sc)\n",
15 changes: 8 additions & 7 deletions jax_scalify/core/typing.py
@@ -4,16 +4,17 @@
 # import chex
 import jax
 import jax.numpy as jnp
-import jaxlib
 import numpy as np

 # Type aliasing. To be compatible with JAX 0.3 as well.
-if jax.__version_info__[1] > 3:
-    Array = jax.Array
-    ArrayTypes = (jax.Array, jax.stages.ArgInfo)
-else:
-    Array = jaxlib.xla_extension.DeviceArray
-    ArrayTypes = (jaxlib.xla_extension.DeviceArray, jax.interpreters.partial_eval.DynamicJaxprTracer)
+try:
+    from jax import Array
+
+    ArrayTypes = (Array, jax.stages.ArgInfo)
+except ImportError:
+    from jaxlib.xla_extension import DeviceArray as Array
+
+    ArrayTypes = (Array, jax.interpreters.partial_eval.DynamicJaxprTracer)


 def get_numpy_api(val: Any) -> Any:
7 changes: 6 additions & 1 deletion jax_scalify/ops/__init__.py
@@ -1,5 +1,10 @@
 # Copyright (c) 2023 Graphcore Ltd. All rights reserved.
-from .cast import reduce_precision_dtype, reduce_precision_dtype_grad  # noqa: F401
+from .cast import (  # noqa: F401
+    cast_on_backward,
+    cast_on_forward,
+    reduce_precision_on_backward,
+    reduce_precision_on_forward,
+)
 from .debug import debug_callback, debug_callback_grad, debug_print, debug_print_grad  # noqa: F401
 from .rescaling import (  # noqa: F401
     dynamic_rescale_l1,
26 changes: 21 additions & 5 deletions jax_scalify/ops/cast.py
@@ -6,7 +6,7 @@

 from jax_scalify.core import Array, DTypeLike

-from .rescaling import fn_bwd_identity_fwd, fn_fwd_identity_bwd
+from .utils import map_on_backward, map_on_forward


 def reduce_precision_dtype_base(arr: Array, dtype: DTypeLike) -> Array:
@@ -15,11 +15,27 @@ def reduce_precision_dtype_base(arr: Array, dtype: DTypeLike) -> Array:
     return jax.lax.reduce_precision(arr, exponent_bits=info.nexp, mantissa_bits=info.nmant)


-def reduce_precision_dtype(arr: Array, dtype: DTypeLike) -> Array:
+def reduce_precision_on_forward(arr: Array, dtype: DTypeLike) -> Array:
     """`Fake` cast to an ML dtype, on the forward pass (no-op on backward pass)."""
-    return partial(fn_fwd_identity_bwd, lambda v: reduce_precision_dtype_base(v, dtype))(arr)
+    return partial(map_on_forward, lambda v: reduce_precision_dtype_base(v, dtype))(arr)


-def reduce_precision_dtype_grad(arr: Array, dtype: DTypeLike) -> Array:
+def reduce_precision_on_backward(arr: Array, dtype: DTypeLike) -> Array:
     """`Fake` cast to an ML dtype on the backward pass (no-op on forward pass)."""
-    return partial(fn_bwd_identity_fwd, lambda v: reduce_precision_dtype_base(v, dtype))(arr)
+    return partial(map_on_backward, lambda v: reduce_precision_dtype_base(v, dtype))(arr)
+
+
+def cast_on_forward(arr: Array, dtype: DTypeLike) -> Array:
+    """Cast the input array only on the forward pass (no-op on the backward pass).
+    Useful for implementing `DenseGeneral` FP8 matmuls.
+    """
+    return partial(map_on_forward, lambda v: jax.lax.convert_element_type(v, dtype))(arr)
+
+
+def cast_on_backward(arr: Array, dtype: DTypeLike) -> Array:
+    """Cast the input array only on the backward pass (no-op on the forward pass).
+    Useful for implementing `DenseGeneral` FP8 matmuls.
+    """
+    return partial(map_on_backward, lambda v: jax.lax.convert_element_type(v, dtype))(arr)
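
For reference, a quick sketch of the intended asymmetry between the two new real-cast helpers. float16 is used here since the FP8 JAX dtypes need a recent JAX version (as noted in the tests below); the variable names are illustrative only.

import jax.numpy as jnp
import numpy as np

import jax_scalify as jsa

x = jnp.asarray(np.array([17.0, -17.0, 8.0, 1.0], np.float32))

y = jsa.ops.cast_on_forward(x, jnp.float16)   # real cast on forward: y.dtype == float16
z = jsa.ops.cast_on_backward(x, jnp.float16)  # no-op on forward: z.dtype == float32

On the backward pass the roles swap: `cast_on_backward` casts the incoming gradient to the target dtype while `cast_on_forward` lets it pass through unchanged, which is what `test__cast_on_backward__grad__proper_results` below checks.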
46 changes: 7 additions & 39 deletions jax_scalify/ops/rescaling.py
@@ -7,39 +7,7 @@
 from jax_scalify.core import ScaledArray, pow2_round, pow2_round_down
 from jax_scalify.lax import get_data_scale, rebalance

-
-@partial(jax.custom_vjp, nondiff_argnums=(0,))
-def fn_fwd_identity_bwd(f, arg):
-    """Function with identity bwd/grad."""
-    return f(arg)
-
-
-def fn_fwd_identity_bwd_fwd(f, arg):
-    return arg, None
-
-
-def fn_fwd_identity_bwd_bwd(f, _, grad):
-    return (grad,)
-
-
-fn_fwd_identity_bwd.defvjp(fn_fwd_identity_bwd_fwd, fn_fwd_identity_bwd_bwd)
-
-
-@partial(jax.custom_vjp, nondiff_argnums=(0,))
-def fn_bwd_identity_fwd(f, arg):
-    """Apply a function on the gradient/backward pass."""
-    return arg
-
-
-def fn_bwd_identity_fwd_fwd(f, arg):
-    return arg, None
-
-
-def fn_bwd_identity_fwd_bwd(f, _, grad):
-    return (f(grad),)
-
-
-fn_bwd_identity_fwd.defvjp(fn_bwd_identity_fwd_fwd, fn_bwd_identity_fwd_bwd)
+from .utils import map_on_backward, map_on_forward


 def dynamic_rescale_max_base(arr: ScaledArray) -> ScaledArray:
@@ -97,11 +65,11 @@ def dynamic_rescale_l2_base(arr: ScaledArray) -> ScaledArray:


 # Dynamic rescale on fwd arrays.
-dynamic_rescale_max = partial(fn_fwd_identity_bwd, dynamic_rescale_max_base)
-dynamic_rescale_l1 = partial(fn_fwd_identity_bwd, dynamic_rescale_l1_base)
-dynamic_rescale_l2 = partial(fn_fwd_identity_bwd, dynamic_rescale_l2_base)
+dynamic_rescale_max = partial(map_on_forward, dynamic_rescale_max_base)
+dynamic_rescale_l1 = partial(map_on_forward, dynamic_rescale_l1_base)
+dynamic_rescale_l2 = partial(map_on_forward, dynamic_rescale_l2_base)

 # Dynamic rescale on gradients.
-dynamic_rescale_max_grad = partial(fn_bwd_identity_fwd, dynamic_rescale_max_base)
-dynamic_rescale_l1_grad = partial(fn_bwd_identity_fwd, dynamic_rescale_l1_base)
-dynamic_rescale_l2_grad = partial(fn_bwd_identity_fwd, dynamic_rescale_l2_base)
+dynamic_rescale_max_grad = partial(map_on_backward, dynamic_rescale_max_base)
+dynamic_rescale_l1_grad = partial(map_on_backward, dynamic_rescale_l1_base)
+dynamic_rescale_l2_grad = partial(map_on_backward, dynamic_rescale_l2_base)
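
Note that the public `dynamic_rescale_*` API is unchanged by this refactor; only the underlying fwd/bwd wrappers are renamed. A minimal sketch of the pre-existing usage, assuming the `jsa` alias and the scaled-array values from the quickstart notebook; `rescale_fn` is a made-up name.

import numpy as np

import jax_scalify as jsa

sc = jsa.as_scaled_array(np.array([17.0, 19.0]), scale=np.float32(2))


@jsa.scalify
def rescale_fn(v):
    # Forward-only dynamic rescaling, now routed through `map_on_forward` internally.
    return jsa.ops.dynamic_rescale_l1(v)


out = rescale_fn(sc)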
38 changes: 38 additions & 0 deletions jax_scalify/ops/utils.py
@@ -0,0 +1,38 @@
+# Copyright (c) 2024 Graphcore Ltd. All rights reserved.
+from functools import partial
+
+import jax
+
+
+@partial(jax.custom_vjp, nondiff_argnums=(0,))
+def map_on_forward(f, arg):
+    """Map a function on a forward pass only. No-op/identity on backward pass."""
+    return f(arg)
+
+
+def map_on_forward_fwd(f, arg):
+    return arg, None
+
+
+def map_on_forward_bwd(f, _, grad):
+    return (grad,)
+
+
+map_on_forward.defvjp(map_on_forward_fwd, map_on_forward_bwd)
+
+
+@partial(jax.custom_vjp, nondiff_argnums=(0,))
+def map_on_backward(f, arg):
+    """Map a function on the gradient/backward pass. No-op/identity on forward."""
+    return arg
+
+
+def map_on_backward_fwd(f, arg):
+    return arg, None
+
+
+def map_on_backward_bwd(f, _, grad):
+    return (f(grad),)
+
+
+map_on_backward.defvjp(map_on_backward_fwd, map_on_backward_bwd)
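
These two helpers generalize beyond casting: any shape- and dtype-preserving function can be turned into a forward-only or backward-only transform. A rough sketch below; the clipping use case and the names `clip_on_forward`/`clip_grad` are purely illustrative, not part of this commit.

from functools import partial

import jax
import jax.numpy as jnp

from jax_scalify.ops.utils import map_on_backward, map_on_forward

# Clip activations on the forward pass only; gradients flow through untouched.
clip_on_forward = partial(map_on_forward, lambda v: jnp.clip(v, -6.0, 6.0))
# Clip gradients on the backward pass only; forward values are untouched.
clip_grad = partial(map_on_backward, lambda v: jnp.clip(v, -1.0, 1.0))

x = jnp.array([-10.0, 0.5, 10.0])
print(clip_on_forward(x))  # values clipped to [-6, 6]
print(jax.grad(lambda v: jnp.sum(clip_grad(v) ** 2))(x))  # gradient 2*v clipped to [-1, 1]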
57 changes: 51 additions & 6 deletions tests/ops/test_cast.py
@@ -2,25 +2,27 @@
 from functools import partial

 import chex
+import jax
+import jax.numpy as jnp
 import ml_dtypes
 import numpy as np
 import numpy.testing as npt
 from absl.testing import parameterized
 from numpy.typing import NDArray

 from jax_scalify.core import scaled_array, scalify
-from jax_scalify.ops import reduce_precision_dtype
+from jax_scalify.ops import cast_on_backward, cast_on_forward, reduce_precision_on_forward


-class CastMLDtypeTests(chex.TestCase):
+class ReducePrecisionDtypeTests(chex.TestCase):
     @parameterized.parameters(
         {"ml_dtype": ml_dtypes.float8_e4m3fn},
         {"ml_dtype": ml_dtypes.float8_e5m2},
     )
-    def test__reduce_precision_dtype__consistent_rounding_down(self, ml_dtype):
+    def test__reduce_precision_on_forward__consistent_rounding_down(self, ml_dtype):
         # Values potentially "problematic" in FP8.
         values: NDArray[np.float16] = np.array([17, -17, 8, 1, 9, 11, 18], np.float16)
-        out = reduce_precision_dtype(values, dtype=ml_dtype)
+        out = reduce_precision_on_forward(values, dtype=ml_dtype)
         expected_out = values.astype(ml_dtype)
         assert out.dtype == values.dtype
         npt.assert_array_equal(out, expected_out)
@@ -29,10 +31,53 @@ def test__reduce_precision_dtype__consistent_rounding_down(self, ml_dtype):
         {"ml_dtype": ml_dtypes.float8_e4m3fn},
         {"ml_dtype": ml_dtypes.float8_e5m2},
     )
-    def test__reduce_precision_dtype__scalify_compatiblity(self, ml_dtype):
+    def test__reduce_precision_on_forward__scalify_compatiblity(self, ml_dtype):
         values: NDArray[np.float16] = np.array([17, -17, 8, 1, 9, 11, 18], np.float16)
         arr = scaled_array(values, np.float32(1))
-        out = scalify(partial(reduce_precision_dtype, dtype=ml_dtype))(arr)
+        out = scalify(partial(reduce_precision_on_forward, dtype=ml_dtype))(arr)

         npt.assert_array_equal(out.scale, arr.scale)
         npt.assert_array_equal(out, np.asarray(arr.data).astype(ml_dtype))
+
+
+class CastOnForwardBackwardTests(chex.TestCase):
+    @chex.variants(with_jit=True, without_jit=True)
+    @parameterized.parameters(
+        {"dtype": jnp.float16},
+        # TODO: uncomment when JAX 0.4+ used
+        # {"dtype": jnp.float8_e4m3fn},
+        # {"dtype": jnp.float8_e5m2},
+    )
+    def test__cast_on_forward_backward__proper_results(self, dtype):
+        # Values potentially "problematic" in FP8.
+        values: NDArray[np.float16] = np.array([17, -17, 8, 1, 9, 11, 18], np.float16)
+        out_on_fwd = self.variant(partial(cast_on_forward, dtype=dtype))(values)
+        out_on_bwd = self.variant(partial(cast_on_backward, dtype=dtype))(values)
+
+        assert out_on_fwd.dtype == dtype
+        assert out_on_bwd.dtype == values.dtype
+        npt.assert_array_equal(out_on_fwd, jax.lax.convert_element_type(values, dtype))
+        npt.assert_array_equal(out_on_bwd, values)
+
+    @parameterized.parameters(
+        {"dtype": jnp.float16},
+        # TODO: uncomment when JAX 0.4+ used
+        # {"dtype": jnp.float8_e4m3fn},
+        # {"dtype": jnp.float8_e5m2},
+    )
+    def test__cast_on_backward__grad__proper_results(self, dtype):
+        def fn(val, with_cast):
+            if with_cast:
+                val = cast_on_backward(val, dtype=dtype)
+            val = val * val
+            return jax.lax.reduce_sum_p.bind(val, axes=(0,))
+
+        # Values potentially "problematic" in FP8.
+        values: NDArray[np.float32] = np.array([17, -17, 8, 1, 9, 11, 18], np.float32)
+        # Backward pass => gradient.
+        grads = jax.grad(partial(fn, with_cast=True))(values)
+        grads_ref = jax.grad(partial(fn, with_cast=False))(values)
+
+        assert grads.dtype == dtype
+        assert grads_ref.dtype == values.dtype
+        npt.assert_array_equal(grads, jax.lax.convert_element_type(grads_ref, dtype))
