diff --git a/docs/source/JAX_for_PyTorch_users.ipynb b/docs/source/JAX_for_PyTorch_users.ipynb index adfb32b..ccd97ad 100644 --- a/docs/source/JAX_for_PyTorch_users.ipynb +++ b/docs/source/JAX_for_PyTorch_users.ipynb @@ -10,15 +10,18 @@ "\n", "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/JAX_for_PyTorch_users.ipynb)\n", "\n", - "This is a quick overview of JAX and the JAX AI stack written for those who are famiilar with PyTorch.\n", + "This tutorial provides a quick overview of JAX and JAX-based libraries (the JAX AI stack) for PyTorch users. You will learn how to:\n", "\n", - "First, we cover how to manipulate JAX Arrays following the [well-known PyTorch's tensors tutorial](https://pytorch.org/tutorials/beginner/basics/tensorqs_tutorial.html). Next, we explore automatic differentiation with JAX, followed by how to build a model and optimize its parameters.\n", - "Finally, we will introduce `jax.jit` and compare it to its PyTorch counterpart `torchscript`.\n", + "- Manipulate [`jax.Array`s](https://jax.readthedocs.io/en/latest/key-concepts.html#jax-arrays-jax-array), following PyTorch's [Tensors](https://pytorch.org/tutorials/beginner/basics/tensorqs_tutorial.html) tutorial.\n", + "- Explore automatic differentiation in JAX with [`jax.grad`](https://jax.readthedocs.io/en/latest/automatic-differentiation.html), following the [autodiff PyTorch tutorial](https://pytorch.org/tutorials/beginner/basics/autogradqs_tutorial.html).\n", + "- Build a neural network and optimize its parameters with [Flax](https://flax.readthedocs.io/en/latest/) and [Optax](https://optax.readthedocs.io/en/latest/).\n", + "- Use [`jax.jit`](https://jax.readthedocs.io/en/latest/jit-compilation.html) for just-in-time compilation, comparing it to PyTorch's `torchscript`.\n", "\n", "## Setup\n", "\n", - "Let's get started by importing JAX and checking the installed version.\n", - "For details on how to install JAX check [installation guide](https://jax.readthedocs.io/en/latest/installation.html)." + "JAX installation is covered in [this guide](https://jax.readthedocs.io/en/latest/installation.html) on the JAX documentation site.\n", + "\n", + "Import JAX and JAX NumPy, and check the installed version." ] }, { @@ -52,17 +55,17 @@ "id": "LNBvB_hRDteB" }, "source": [ - "## JAX Arrays manipulation\n", + "## `jax.Array` manipulation vs `torch.Tensor`s\n", + "\n", + "This section covers [`jax.Array`s](https://jax.readthedocs.io/en/latest/_autosummary/jax.Array.html) - JAX's primary array object - and how to manipulate them. `jax.Array` is the JAX counterpart of [PyTorch's `torch.Tensor`](https://pytorch.org/tutorials/beginner/basics/tensorqs_tutorial.html).\n", "\n", - "In this section, we will learn about JAX Arrays and how to manipulate them compared to PyTorch tensors.\n", + "### Initializing a `jax.Array`\n", "\n", - "### Initializing a JAX Array\n", + "Similar to `torch.Tensor`s, [`jax.Array`](https://jax.readthedocs.io/en/latest/_autosummary/jax.Array.html) objects are never constructed directly.
Instead, they are constructed via array creation APIs that populate the new array with constant numbers, random numbers, or data drawn from Python lists, NumPy arrays, `torch.Tensor`s, and so on.\n", "\n", - "The primary array object in JAX is the `jax.Array`, which is the JAX counterpart of `torch.Tensor`.\n", - "As with `torch.Tensor`, `jax.Array` objects are never constructed directly, but rather constructed via array creation APIs that populate the new array with constant numbers, random numbers, or data drawn from lists, numpy arrays, torch tensors, and more.\n", - "Let's see some examples of this.\n", + "Let's go through some examples.\n", "\n", - "To initialize an array from a Python data:" + "To initialize an array from Python data:" ] }, { @@ -86,7 +89,7 @@ } ], "source": [ - "# From data\n", + "# From data:\n", "data = [[1, 2, 3], [3, 4, 5]]\n", "x_array = jnp.array(data)\n", "assert isinstance(x_array, jax.Array)\n", @@ -97,7 +100,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Or from an existing NumPy array:" + "To initialize from an existing NumPy array:" ] }, { @@ -127,14 +130,14 @@ "x_np = jnp.array(np_array)\n", "assert isinstance(x_np, jax.Array)\n", "print(x_np, x_np.shape, x_np.dtype)\n", - "# x_np is a copy of np_array" + "# x_np is a copy of np_array." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "You can create arrays with the same shape and `dtype` as existing JAX Arrays:" + "We can create arrays with the same shape and `dtype` as existing `jax.Array`s:" ] }, { @@ -171,7 +174,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "You can even initialize arrays with constants or random values. For example:" + "We can also initialize arrays with constants or random values. For example:" ] }, { @@ -221,8 +224,8 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "JAX avoids implicit global random state and instead tracks state explicitly via a random `key`.\n", - "If we create two random arrays using the same `key` we will obtain two identical random arrays.\n", + "JAX avoids implicit global random state and instead [tracks state explicitly via a random `key`](https://jax.readthedocs.io/en/latest/random-numbers.html#explicit-random-state).\n", + "If we create two random arrays using the same `key`, we will obtain two identical random arrays.\n", "We can also split the random `key` into multiple keys to create two different random arrays." ] }, @@ -248,14 +251,14 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "For further discussion on random numbers in NumPy and JAX check [this tutorial](https://jax.readthedocs.io/en/latest/random-numbers.html)." + "**Note:** Learn more about [`jax.random`](https://jax.readthedocs.io/en/latest/jax.random.html#module-jax.random) and pseudorandom number generation (PRNG) in JAX in [this tutorial](https://jax.readthedocs.io/en/latest/random-numbers.html) on the JAX documentation site." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Finally, if you have a PyTorch tensor, you can use it to initialize a JAX Array:" + "And if you have a PyTorch `tensor`, you can also use it to initialize a `jax.Array`:" ] }, { @@ -287,12 +290,13 @@ "\n", "x_torch = torch.rand(3, 4)\n", "\n", - "# Create JAX Array as a copy of x_torch tensor\n", + "# Create a `jax.Array` as a copy of the `x_torch` tensor\n", + "# using `jax.numpy.asarray()`.\n", "x_jax = jnp.asarray(x_torch)\n", "assert isinstance(x_jax, jax.Array)\n", "print(x_jax, x_jax.shape, x_jax.dtype)\n", "\n", - "# Use dlpack to create JAX Array without copying\n", + "# Use `jax.dlpack.from_dlpack()` to create a `jax.Array` without copying.\n", "x_jax = jax.dlpack.from_dlpack(x_torch.to(device=\"cuda\"), copy=False)\n", "print(x_jax, x_jax.shape, x_jax.dtype)" ] @@ -303,10 +307,10 @@ "id": "oTXSGITNNnuY" }, "source": [ - "### Attributes of a JAX Array\n", + "### Attributes of a `jax.Array`\n", "\n", "\n", - "Similarly to PyTorch tensors, JAX Array attributes describe the array's shape, dtype and device:" + "Similar to `torch.Tensor`s, `jax.Array` attributes describe the array's shape, dtype and device:" ] }, { @@ -343,10 +347,10 @@ "id": "S4CxmHSaKz-r" }, "source": [ - "However, there are some notable differences between PyTorch tensors and JAX Arrays:\n", - "- JAX Arrays are immutable\n", - "- The default integer and float dtypes are int32 and float32\n", - "- The default device corresponds to the available accelerator, e.g. cuda:0 if one or multiple GPUs are available." + "However, there are some notable differences between `torch.Tensor`s and `jax.Array`s:\n", + "- `jax.Array`s are immutable.\n", + "- The default integer and float dtypes are `int32` and `float32`.\n", + "- The default device corresponds to the available accelerator, e.g. `cuda:0` if one or multiple GPUs are available." ] }, { @@ -394,7 +398,7 @@ "id": "1o1u2N1lL7I3" }, "source": [ - "For some discussion of JAX's alternative to in-place mutation, refer to https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html." + "To learn more about JAX's alternative to in-place mutation, refer to the [`jax.numpy.ndarray.at`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html) API documentation on JAX's site." ] }, { @@ -406,7 +410,7 @@ "source": [ "### Devices and accelerators\n", "\n", - "Using the PyTorch API, we can check whether we have GPU accelerators available with `torch.cuda.is_available()`. In JAX, we can check available devices as follows:" + "Using the PyTorch API, we can check whether we have GPU accelerators available with `torch.cuda.is_available()`. 
In JAX, we can check available devices using [`jax.devices`](https://jax.readthedocs.io/en/latest/_autosummary/jax.devices.html#jax.devices):" ] }, { @@ -431,7 +435,7 @@ ], "source": [ "print(f\"Available devices given a backend (gpu or tpu or cpu): {jax.devices()}\")\n", - "# Define CPU and CUDA devices\n", + "# Define CPU and CUDA devices.\n", "cpu_device = jax.devices(\"cpu\")[0]\n", "cuda_device = jax.devices(\"cuda\")[0]\n", "print(cpu_device, cuda_device)" ] }, { @@ -467,10 +471,10 @@ } ], "source": [ - "# create an array on CPU and check the device\n", + "# Create an array on CPU and check the device.\n", "x_cpu = jnp.ones((3, 4), device=cpu_device)\n", "print(x_cpu.device, )\n", - "# create an array on GPU\n", + "# Create an array on GPU.\n", "x_gpu = jnp.ones((3, 4), device=cuda_device)\n", "print(x_gpu.device)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "In PyTorch we are used to device placement always being explicit. JAX can operate this way via explicit device placement as above, but unless the device is specified the array will remain *uncommitted*: i.e. it will be stored on the default device, but allow implicit movement to other devices when necessary:" + "In PyTorch, device placement is explicit. JAX can operate this way via _explicit device placement_ as above. But unless the device is specified, the array will remain *uncommitted*, which means it will be stored on the _default device_, while allowing implicit movement to other devices when necessary:" ] }, { @@ -508,7 +512,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "However, if we make a computation with two arrays with explicitly specified devices, e.g. CPU and CUDA, similarly to PyTorch, an error will be raised." + "However, similar to PyTorch, an error will be raised if we perform a computation with two arrays committed to different devices, such as CPU and CUDA." ] }, { @@ -527,7 +531,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "To move from one device to another, we can use `jax.device_put` function:" + "To move an array from one device to another, we can use the [`jax.device_put`](https://jax.readthedocs.io/en/latest/_autosummary/jax.device_put.html#jax.device_put) function:" ] }, { @@ -556,7 +560,7 @@ "id": "o2481PfoPdFG" }, "source": [ - "### Operations on JAX Arrays\n", + "### Operations on `jax.Array`s\n", "\n", "There is a large list of operations (arithmetics, linear algebra, matrix manipulation, etc) that can be directly performed on JAX Arrays. JAX API contains important modules:\n", "- `jax.numpy` provides NumPy-like functions\n", "\n", "More details on available ops can be found in the [API reference](https://jax.readthedocs.io/en/latest/jax.html).\n", "\n", "All operations can be run on CPUs, GPUs or TPUs. By default, `jax.Array`s are created on an accelerated device, while PyTorch `torch.Tensor`s are created on CPUs.\n", "\n", "We can now try out some array operations and check for similarities between the JAX, NumPy and PyTorch APIs."
] @@ -609,7 +613,7 @@ "print(f\"First column: {tensor[:, 0]}\")\n", "print(f\"Last column: {tensor[..., -1]}\")\n", "\n", - "# Equivalent PyTorch op: tensor[:, 1] = 0\n", + "# Equivalent PyTorch op: `tensor[:, 1] = 0`.\n", "tensor = tensor.at[:, 1].set(0)\n", "\n", "print(tensor)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "We would like to note particular out-of-bounds indexing behaviour in JAX. In JAX the index is clamped to the bounds of the array in the indexing operations." + "Note the particular out-of-bounds indexing behaviour in JAX: the index is clamped to the bounds of the array in indexing operations." ] }, { @@ -643,7 +647,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Join arrays similar to `torch.cat`. Note the kwarg name: `axis` vs `dim`." + "**Joining arrays (similar to `torch.cat`):** Note the kwarg name: `axis` vs `dim`." ] }, { @@ -676,7 +680,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Arithmetic operations. Operations below compute the matrix multiplication between two tensors. y1, y2 will have the same value." + "**Arithmetic operations:** The operations below compute the matrix multiplication between two tensors. `y1`, `y2` will have the same value:" ] }, { "cell_type": "code", "metadata": { "id": "P8jcElVyYTp7" }, "outputs": [], "source": [ - "# ``tensor.T`` returns the transpose of a tensor\n", + "# `tensor.T` returns the transpose of a tensor.\n", "y1 = tensor @ tensor.T\n", "y2 = jnp.matmul(tensor, tensor.T)\n", "\n", "assert (y1 == y2).all()\n", "\n", - "# This computes the element-wise product. z1, z2 will have the same value\n", + "# This computes the element-wise product. `z1`, `z2` will have the same value.\n", "z1 = tensor * tensor\n", "z2 = jnp.multiply(tensor, tensor)\n", "\n", @@ -704,7 +708,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Single-element arrays. If you have a one-element array, for example by aggregating all values of a tensor into one value, you can convert it to a Python numerical value using `.item()`:" + "**Single-element arrays:** If you have a one-element array, for example after aggregating all values of a tensor into one value, you can convert it to a Python numerical value using `.item()`:" ] }, { @@ -759,11 +763,11 @@ "source": [ "## Automatic differentiation with JAX\n", "\n", - "In this section, we will learn about the fundamental applications of automatic differentiation (autodiff) in JAX. JAX has a pretty general autodiff system, and its API has inspired the `torch.func` module in PyTorch, previously known as “functorch” (JAX-like composable function transforms for PyTorch).\n", + "This section covers the basics of [automatic differentiation (autodiff)](https://jax.readthedocs.io/en/latest/automatic-differentiation.html) in JAX. JAX has a pretty general autodiff system, and its API has inspired the `torch.func` module in PyTorch, previously known as “functorch” (JAX-like composable function transforms for PyTorch).\n", "\n", - "In PyTorch, there is an API to turn on the automatic operations graph recording (e.g., `required_grad` argument and `tensor.backward()`), but in JAX, automatic differentiation is a functional operation, i.e., there is no need to mark arrays with a flag to enable gradient tracking.\n", + "In PyTorch, there is an API to turn on automatic operation graph recording (for example, the `requires_grad` argument and `tensor.backward()`).
In JAX, automatic differentiation is a functional operation, which means there is no need to mark arrays with a flag to enable gradient tracking.\n", "\n", - "Let us follow [autodiff PyTorch tutorial](https://pytorch.org/tutorials/beginner/basics/autogradqs_tutorial.html) and consider the simplest one-layer neural network, with input `x`, parameters `w` and `b`, and some loss function. In JAX, this can be defined in the following way:" + "Let's follow PyTorch's [Automatic differentiation with `torch.autograd`](https://pytorch.org/tutorials/beginner/basics/autogradqs_tutorial.html) tutorial and consider the simplest one-layer neural network, with input `x`, parameters `w` and `b`, and some loss function. In JAX, this can be defined in the following way:" ] }, { @@ -790,12 +794,12 @@ "import jax.numpy as jnp\n", "\n", "\n", - "# Input tensor\n", + "# Input tensor.\n", "x = jnp.ones(5)\n", - "# Target output\n", + "# Target output.\n", "y_true = jnp.zeros(3)\n", "\n", - "# Initialize random parameters\n", + "# Initialize random parameters.\n", "seed = 123\n", "key = jax.random.key(seed)\n", "key, w_key, b_key = jax.random.split(key, 3)\n", "w = jax.random.normal(w_key, (5, 3))\n", "b = jax.random.normal(b_key, (3, ))\n", "\n", "\n", - "# model function\n", + "# Model function.\n", "def predict(x, w, b):\n", " return jnp.matmul(x, w) + b\n", "\n", - "# Criterion or loss function\n", + "# Criterion or loss function.\n", "def compute_loss(w, b, x, y_true):\n", " y_pred = predict(x, w, b)\n", " return jnp.mean((y_true - y_pred) ** 2)\n", @@ -823,7 +827,7 @@ "id": "k3r8od6LtD_9" }, "source": [ - "In our example network, `w` and `b` are parameters to optimize and we need to be able to compute the gradients of the loss function with respect to those variables. In order to do that, we use [`jax.grad`](https://jax.readthedocs.io/en/latest/_autosummary/jax.grad.html#jax.grad) function on `compute_loss` function:" + "In our example network, `w` and `b` are parameters to optimize, and we need to be able to compute the gradients of the loss function with respect to those variables. To do that, we can use the [`jax.grad`](https://jax.readthedocs.io/en/latest/_autosummary/jax.grad.html#jax.grad) transformation on the `compute_loss` function as follows:" ] }, { @@ -851,7 +855,7 @@ } ], "source": [ - "# Differentiate `compute_loss` with respect to the 0 and 1 positional arguments:\n", + "# Differentiate `compute_loss` with respect to the `0` and `1` positional arguments.\n", "w_grad, b_grad = jax.grad(compute_loss, argnums=(0, 1))(w, b, x, y_true)\n", "print(f'{w_grad=}')\n", "print(f'{b_grad=}')" ] }, { @@ -878,7 +882,7 @@ } ], "source": [ - "# Compute w_grad, b_grad and loss value:\n", + "# Compute `w_grad`, `b_grad` and `loss_value`.\n", "loss_value, (w_grad, b_grad) = jax.value_and_grad(compute_loss, argnums=(0, 1))(w, b, x, y_true)\n", "print(f'{w_grad=}')\n", "print(f'{b_grad=}')\n", @@ -890,10 +894,12 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### `jax.grad` and PyTrees\n", + "### `jax.grad` and pytrees\n", + "\n", "\n", + "JAX uses the [pytree abstraction](https://jax.readthedocs.io/en/latest/working-with-pytrees.html#working-with-pytrees), which provides a uniform system for handling nested containers of array values built from Python containers such as dicts, tuples, and lists. JAX's functional API works easily with these containers.\n", - "JAX introduced the [PyTree abstraction](https://jax.readthedocs.io/en/latest/working-with-pytrees.html#working-with-pytrees)(e.g.
Python containers like dicts, tuples, lists, etc which provides a uniform system for handling nested containers of array values) and its functional API works easily on these containers. Let us consider an example where we gathered our example network parameters into a dictionary:" + "Consider an example where we gather our example network parameters into a dictionary:" ] }, { @@ -945,18 +951,22 @@ "id": "65a935c3", "metadata": {}, "source": [ - "The functional API in JAX easily allows us to compute higher order gradients by calling `jax.grad` multiple times on the function. We will not cover this topic in this tutorial, for more details we suggest reading [JAX automatic differentiation tutorial](https://jax.readthedocs.io/en/latest/automatic-differentiation.html).\n", + "The functional API in JAX easily allows us to compute higher-order gradients by calling `jax.grad` multiple times on the function.\n", + "\n", + "### Further reading\n", + "\n", + "For more details, check out the [Automatic differentiation](https://jax.readthedocs.io/en/latest/automatic-differentiation.html) tutorial on JAX's documentation site.\n", "\n", "## Build and train a model\n", "\n", "\n", - "In this section we will learn how to build a simple model using Flax ([`flax.nnx` API](https://flax.readthedocs.io/en/latest/nnx_basics.html)) and optimize its parameters using training data provided by PyTorch dataloader.\n", + "In this section, we will learn how to build a simple model using [Flax NNX](https://flax.readthedocs.io/en/latest/nnx_basics.html) and optimize its parameters using training data provided by PyTorch `DataLoader`s.\n", "\n", "\n", - "Model creation with Flax is very similar to PyTorch using the `torch.nn` module. In this example, we will build the ResNet18 model.\n", + "Model creation with the Flax API (`nnx.Module`) is very similar to PyTorch's `torch.nn` module. In this example, we will build the ResNet18 model.\n", "\n", "\n", - "### Build ResNet18 model" + "### Build a ResNet18 model" ] }, { @@ -965,7 +975,7 @@ "metadata": {}, "outputs": [], "source": [ - "# To install Flax: `pip install -U flax treescope optax`\n", + "# To install Flax, use `pip install -U flax treescope optax`\n", "import jax\n", "import jax.numpy as jnp\n", "from flax import nnx\n", @@ -1090,7 +1100,7 @@ "source": [ "model = ResNet18(10, rngs=nnx.Rngs(0))\n", "\n", - "# Visualize the model architecture\n", + "# Visualize the model architecture.\n", "nnx.display(model)" ] }, @@ -1132,15 +1142,15 @@ "Note that the input array is explicitly in the channels-last memory format. In PyTorch, the typical input tensor to a neural network has channels-first memory format and has shape `(4, 3, 32, 32)` by default.\n", "\n", "\n", - "### Dataflow using Torchvision and PyTorch data loaders\n", + "### Data flow using `torchvision` and PyTorch `DataLoader`s\n", "\n", "\n", "Let us now set up training and test data using the CIFAR10 dataset from `torchvision`.\n", - "We will create torch dataloaders with collate functions returning NumPy Arrays instead of PyTorch tensors.\n", - "Since JAX is a multithreaded framework, using it in multiple processes can cause issues. For this reason, we will avoid creating JAX Arrays in the dataloaders.\n", + "We will create PyTorch `DataLoader`s with collate functions that return NumPy arrays instead of PyTorch `Tensor`s.\n", + "Since JAX is a multithreaded framework, using it in multiple processes can cause issues.
For this reason, we will avoid creating `jax.Array`s in the `DataLoader`s.\n", "\n", "\n", - "As an alternative, one can use [grain](https://github.com/google/grain/tree/main) for data loading and [PIX](https://github.com/google-deepmind/dm_pix) for image data augmentations." + "Alternatively, we can use [grain](https://github.com/google/grain/tree/main) for data loading and [PIX](https://github.com/google-deepmind/dm_pix) for image data augmentations." ] }, { @@ -1157,7 +1167,7 @@ } ], "source": [ - "# CIFAR10 training/testing datasets setup\n", + "# Setting up the CIFAR10 training/testing datasets.\n", "import numpy as np\n", "\n", "from torchvision.transforms import v2 as T\n", @@ -1198,7 +1208,7 @@ "metadata": {}, "outputs": [], "source": [ - "# Data loaders setup\n", + "# Setting up the `DataLoader`s.\n", "from torch.utils.data import DataLoader\n", "\n", "\n", @@ -1241,7 +1251,7 @@ } ], "source": [ - "# Let us check training dataloader:\n", + "# Check the training `DataLoader`:\n", "trl_iter = iter(train_loader)\n", "batch = next(trl_iter)\n", "print(batch[0].shape, batch[0].dtype, batch[1].shape, batch[1].dtype)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Note: when executing the code above you may see this warning: `RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.`. This warning can be ignored as dataloaders are not using JAX in forked processes.\n", + "**Note**: When executing the code above you may get a warning message that says: `RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.`. This warning can be ignored as `DataLoader`s are not using JAX in forked processes.\n", "\n", "\n", "### Model training\n", "\n", - "\n", - "Let us now define the optimizer, loss function, train and test steps using Flax API.\n", - "PyTorch users can find the code using Flax NNX API very similar to PyTorch." + "Let's define the optimizer, the loss function, and custom `train_step` and `eval_step` functions using the Flax API.\n", + "PyTorch users may find the Flax API code below similar to PyTorch."
] }, { @@ -1327,7 +1336,7 @@ "metadata": {}, "outputs": [], "source": [ - "# Define helper object to compute train/test metrics\n", + "# Define a helper object to compute train/test metrics.\n", "metrics = nnx.MultiMetric(\n", " accuracy=nnx.metrics.Accuracy(),\n", " loss=nnx.metrics.Average('loss'),\n", @@ -1409,7 +1418,7 @@ } ], "source": [ - "# Start the training\n", + "# Start the training.\n", "\n", "num_epochs = 3\n", "\n", @@ -1457,18 +1466,15 @@ "source": [ "### Further reading\n", "\n", - "More details about Flax NNX API, how to save and load the model's state and about available optimizers, we suggest to check out the links below:\n", - "- [FLAX NNX](https://flax.readthedocs.io/en/latest/nnx_basics.html)\n", - "- [Save & Load model's state](https://flax.readthedocs.io/en/latest/guides/checkpointing.html)\n", - "- [Optax](https://optax.readthedocs.io/en/latest/)\n", + "You can learn more about Flax and Optax in:\n", "\n", - "\n", - "Other AI/ML tutorials to check out:\n", - "- [JAX AI Stack tutorials](https://jax-ai-stack.readthedocs.io/en/latest/tutorials.html)\n", + "- [Flax NNX basics](https://flax.readthedocs.io/en/latest/nnx_basics.html)\n", + "- [Save and load checkpoints](https://flax.readthedocs.io/en/latest/guides/checkpointing.html) on Flax's documentation site\n", + "- [Optax's documentation site](https://optax.readthedocs.io/en/latest/)\n", + "- [JAX AI Stack tutorials](https://jax-ai-stack.readthedocs.io/en/latest/tutorials.html) for additional machine learning content\n", "\n", "## Just-In-Time (JIT) compilation in JAX\n", "\n", - "\n", "PyTorch users know very well about the eager mode execution of the operations in PyTorch, e.g. the operations are executed one by one without any high-level optimizations on sets of operations. Similarly, almost everywhere in this tutorial we used JAX in the eager mode as well.\n", "\n", "\n", @@ -1554,7 +1560,7 @@ "source": [ "### Further reading\n", "\n", - "- [JAX documentation on Just-in-time compilation](https://jax.readthedocs.io/en/latest/jit-compilation.html)" + "- [Just-in-time compilation](https://jax.readthedocs.io/en/latest/jit-compilation.html) with `jax.jit` on JAX's documentation site." ] } ], diff --git a/docs/source/JAX_for_PyTorch_users.md b/docs/source/JAX_for_PyTorch_users.md index 7e2cef7..5774570 100644 --- a/docs/source/JAX_for_PyTorch_users.md +++ b/docs/source/JAX_for_PyTorch_users.md @@ -18,15 +18,18 @@ kernelspec: [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/JAX_for_PyTorch_users.ipynb) -This is a quick overview of JAX and the JAX AI stack written for those who are famiilar with PyTorch. +This tutorial provides a quick overview of JAX and JAX-based libraries (the JAX AI stack) for PyTorch users. You will learn how to: -First, we cover how to manipulate JAX Arrays following the [well-known PyTorch's tensors tutorial](https://pytorch.org/tutorials/beginner/basics/tensorqs_tutorial.html). Next, we explore automatic differentiation with JAX, followed by how to build a model and optimize its parameters. -Finally, we will introduce `jax.jit` and compare it to its PyTorch counterpart `torchscript`. +- Manipulate [`jax.Array`s](https://jax.readthedocs.io/en/latest/key-concepts.html#jax-arrays-jax-array), following PyTorch's [Tensors](https://pytorch.org/tutorials/beginner/basics/tensorqs_tutorial.html) tutorial.
+- Explore automatic differentiation in JAX with [`jax.grad`](https://jax.readthedocs.io/en/latest/automatic-differentiation.html), following the [autodiff PyTorch tutorial](https://pytorch.org/tutorials/beginner/basics/autogradqs_tutorial.html). +- Build a neural network and optimize its parameters with [Flax](https://flax.readthedocs.io/en/latest/) and [Optax](https://optax.readthedocs.io/en/latest/). +- Use [`jax.jit`](https://jax.readthedocs.io/en/latest/jit-compilation.html) for just-in-time compilation, comparing it to PyTorch's `torchscript`. ## Setup -Let's get started by importing JAX and checking the installed version. -For details on how to install JAX check [installation guide](https://jax.readthedocs.io/en/latest/installation.html). +JAX installation is covered in [this guide](https://jax.readthedocs.io/en/latest/installation.html) on the JAX documentation site. + +Import JAX and JAX NumPy, and check the installed version. ```{code-cell} ipython3 --- colab: id: 9J4m79evD0fJ outputId: 7b8196fb-4f16-4c26-864f-e3c08697fe19 --- print(jax.__version__) +++ {"id": "LNBvB_hRDteB"} -## JAX Arrays manipulation +## `jax.Array` manipulation vs `torch.Tensor`s + +This section covers [`jax.Array`s](https://jax.readthedocs.io/en/latest/_autosummary/jax.Array.html) - JAX's primary array object - and how to manipulate them. `jax.Array` is the JAX counterpart of [PyTorch's `torch.Tensor`](https://pytorch.org/tutorials/beginner/basics/tensorqs_tutorial.html). -In this section, we will learn about JAX Arrays and how to manipulate them compared to PyTorch tensors. +### Initializing a `jax.Array` -### Initializing a JAX Array +Similar to `torch.Tensor`s, [`jax.Array`](https://jax.readthedocs.io/en/latest/_autosummary/jax.Array.html) objects are never constructed directly. Instead, they are constructed via array creation APIs that populate the new array with constant numbers, random numbers, or data drawn from Python lists, NumPy arrays, `torch.Tensor`s, and so on. -The primary array object in JAX is the `jax.Array`, which is the JAX counterpart of `torch.Tensor`. -As with `torch.Tensor`, `jax.Array` objects are never constructed directly, but rather constructed via array creation APIs that populate the new array with constant numbers, random numbers, or data drawn from lists, numpy arrays, torch tensors, and more. -Let's see some examples of this. +Let's go through some examples. -To initialize an array from a Python data: +To initialize an array from Python data: ```{code-cell} ipython3 --- colab: id: 9J4m79evD0fJ outputId: 7b8196fb-4f16-4c26-864f-e3c08697fe19 --- -# From data +# From data: data = [[1, 2, 3], [3, 4, 5]] x_array = jnp.array(data) assert isinstance(x_array, jax.Array) print(x_array, x_array.shape, x_array.dtype) ``` -Or from an existing NumPy array: +To initialize from an existing NumPy array: ```{code-cell} ipython3 --- colab: id: fYrtmEoGD0zh outputId: 77e54a90-7e1a-4b9f-c1e5-ae8303a20bb4 --- np_array = np.array(data) x_np = jnp.array(np_array) assert isinstance(x_np, jax.Array) print(x_np, x_np.shape, x_np.dtype) -# x_np is a copy of np_array +# x_np is a copy of np_array. ``` -You can create arrays with the same shape and `dtype` as existing JAX Arrays: +We can create arrays with the same shape and `dtype` as existing `jax.Array`s: ```{code-cell} ipython3 --- colab: id: CBIzRmfNEiPq outputId: 3b34473e-0d61-4cc0-d1e2-a967fbc61d07 --- x_ones = jnp.ones_like(x_array) print(x_ones, x_ones.shape, x_ones.dtype) x_zeros = jnp.zeros_like(x_array) print(x_zeros, x_zeros.shape, x_zeros.dtype) ``` -You can even initialize arrays with constants or random values. For example: +We can also initialize arrays with constants or random values.
For example: ```{code-cell} ipython3 --- @@ -124,8 +127,8 @@ print(f"Ones Tensor: \n {ones_tensor} \n") print(f"Zeros Tensor: \n {zeros_tensor}") ``` -JAX avoids implicit global random state and instead tracks state explicitly via a random `key`. -If we create two random arrays using the same `key` we will obtain two identical random arrays. +JAX avoids implicit global random state and instead [tracks state explicitly via a random `key`](https://jax.readthedocs.io/en/latest/random-numbers.html#explicit-random-state). +If we create two random arrays using the same `key`, we will obtain two identical random arrays. We can also split the random `key` into multiple keys to create two different random arrays. ```{code-cell} ipython3 @@ -141,11 +144,11 @@ rand_tensor2 = jax.random.uniform(k2, (2, 3)) assert (rand_tensor1 != rand_tensor2).all() ``` -For further discussion on random numbers in NumPy and JAX check [this tutorial](https://jax.readthedocs.io/en/latest/random-numbers.html). +**Note:** Learn more about [`jax.random`](https://jax.readthedocs.io/en/latest/jax.random.html#module-jax.random) and pseudorandom number generation (PRNG) in JAX in [this tutorial](https://jax.readthedocs.io/en/latest/random-numbers.html) on the JAX documentation site. +++ -Finally, if you have a PyTorch tensor, you can use it to initialize a JAX Array: +And if you have a PyTorch `tensor`, you can also use it to initialize a `jax.Array`: ```{code-cell} ipython3 --- @@ -158,22 +161,23 @@ import torch x_torch = torch.rand(3, 4) -# Create JAX Array as a copy of x_torch tensor +# Create a `jax.Array` as a copy of the `x_torch` tensor +# using `jax.numpy.asarray()`. x_jax = jnp.asarray(x_torch) assert isinstance(x_jax, jax.Array) print(x_jax, x_jax.shape, x_jax.dtype) -# Use dlpack to create JAX Array without copying +# Use `jax.dlpack.from_dlpack()` to create a `jax.Array` without copying. x_jax = jax.dlpack.from_dlpack(x_torch.to(device="cuda"), copy=False) print(x_jax, x_jax.shape, x_jax.dtype) ``` +++ {"id": "oTXSGITNNnuY"} -### Attributes of a JAX Array +### Attributes of a `jax.Array` -Similarly to PyTorch tensors, JAX Array attributes describe the array's shape, dtype and device: +Similar to `torch.Tensor`s, `jax.Array` attributes describe the array's shape, dtype and device: ```{code-cell} ipython3 --- @@ -190,10 +194,10 @@ print(f"Device tensor is stored on: {x_jax.device}") +++ {"id": "S4CxmHSaKz-r"} -However, there are some notable differences between PyTorch tensors and JAX Arrays: -- JAX Arrays are immutable -- The default integer and float dtypes are int32 and float32 -- The default device corresponds to the available accelerator, e.g. cuda:0 if one or multiple GPUs are available. +However, there are some notable differences between `torch.Tensor`s and `jax.Array`s: +- `jax.Array`s are immutable. +- The default integer and float dtypes are `int32` and `float32`. +- The default device corresponds to the available accelerator, e.g. `cuda:0` if one or multiple GPUs are available. ```{code-cell} ipython3 --- @@ -220,13 +224,13 @@ print(f"Default devices, PyTorch: {x_torch.device} and Jax: {x_jax.device}") +++ {"id": "1o1u2N1lL7I3"} -For some discussion of JAX's alternative to in-place mutation, refer to https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html. +To learn more about JAX's alternative to in-place mutation, refer to the [`jax.numpy.ndarray.at`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html) API documentation on JAX's site. 
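+
+As a quick, minimal sketch of the `.at` syntax (the indexing example later in this tutorial uses the same pattern), note that the "in-place" update returns a new array:
+
+```{code-cell} ipython3
+x = jnp.zeros((2, 3))
+# Functional equivalent of the in-place PyTorch op `x[0, 0] = 1.0`:
+y = x.at[0, 0].set(1.0)
+print(y)
+```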
+++ {"id": "_x7LutxoC3Eq"} ### Devices and accelerators -Using the PyTorch API, we can check whether we have GPU accelerators available with `torch.cuda.is_available()`. In JAX, we can check available devices as follows: +Using the PyTorch API, we can check whether we have GPU accelerators available with `torch.cuda.is_available()`. In JAX, we can check available devices using [`jax.devices`](https://jax.readthedocs.io/en/latest/_autosummary/jax.devices.html#jax.devices): ```{code-cell} ipython3 --- @@ -236,7 +240,7 @@ id: s0k84UEMQSwL outputId: 71b3658b-45ee-4fe1-a972-986f2e0da950 --- print(f"Available devices given a backend (gpu or tpu or cpu): {jax.devices()}") -# Define CPU and CUDA devices +# Define CPU and CUDA devices. cpu_device = jax.devices("cpu")[0] cuda_device = jax.devices("cuda")[0] print(cpu_device, cuda_device) @@ -253,15 +257,15 @@ colab: id: 9ZSyG8Q6Q5Jw outputId: 68da1b2b-6803-4382-c851-b1c1062c60dd --- -# create an array on CPU and check the device +# Create an array on CPU and check the device. x_cpu = jnp.ones((3, 4), device=cpu_device) print(x_cpu.device, ) -# create an array on GPU +# Create an array on GPU. x_gpu = jnp.ones((3, 4), device=cuda_device) print(x_gpu.device) ``` -In PyTorch we are used to device placement always being explicit. JAX can operate this way via explicit device placement as above, but unless the device is specified the array will remain *uncommitted*: i.e. it will be stored on the default device, but allow implicit movement to other devices when necessary: +In PyTorch device placement is explicit. JAX can operate this way via _explicit device placement_ as above. But, unless the device is specified, the array will remain *uncommitted*, which means it will be stored on the _default device_, while allowing implicit movement to other devices when necessary: ```{code-cell} ipython3 x = jnp.ones((3, 4)) @@ -269,7 +273,7 @@ x = jnp.ones((3, 4)) x.device, (x_cpu + x).device ``` -However, if we make a computation with two arrays with explicitly specified devices, e.g. CPU and CUDA, similarly to PyTorch, an error will be raised. +However, if we make a computation with two arrays with explicitly specified devices, such as CPU and CUDA, similar to PyTorch, an error will be raised. ```{code-cell} ipython3 try: @@ -278,7 +282,7 @@ except ValueError as e: print(e) ``` -To move from one device to another, we can use `jax.device_put` function: +To move from one `device` to another, we can use the [`jax.device_put`](https://jax.readthedocs.io/en/latest/_autosummary/jax.device_put.html#jax.device_put) function: ```{code-cell} ipython3 x = jnp.ones((3, 4)) @@ -289,7 +293,7 @@ print(f"{x.device} -> {x_cpu.device} -> {x_cuda.device}") +++ {"id": "o2481PfoPdFG"} -### Operations on JAX Arrays +### Operations on `jax.Array`s There is a large list of operations (arithmetics, linear algebra, matrix manipulation, etc) that can be directly performed on JAX Arrays. JAX API contains important modules: - `jax.numpy` provides NumPy-like functions @@ -300,7 +304,7 @@ There is a large list of operations (arithmetics, linear algebra, matrix manipul More details on available ops can be found in the [API reference](https://jax.readthedocs.io/en/latest/jax.html). -All operations can be run on CPUs, GPUs or TPUs. By default, JAX Arrays are created on an accelerated device, while PyTorch tensors are created on CPUs. +All operations can be run on CPUs, GPUs or TPUs. By default, `jaxArray`s are created on an accelerated device, while PyTorch `torch.Tensor`s are created on CPUs. 
We can now try out some array operations and check for similarities between the JAX, NumPy and PyTorch APIs. ```{code-cell} ipython3 --- colab: id: 561otOgkZQL2 outputId: 906b3e44-f0bb-462a-ca2e-def49f4da4e5 --- tensor = jnp.ones((4, 4)) print(f"First row: {tensor[0]}") print(f"First column: {tensor[:, 0]}") print(f"Last column: {tensor[..., -1]}") -# Equivalent PyTorch op: tensor[:, 1] = 0 +# Equivalent PyTorch op: `tensor[:, 1] = 0`. tensor = tensor.at[:, 1].set(0) print(tensor) ``` -We would like to note particular out-of-bounds indexing behaviour in JAX. In JAX the index is clamped to the bounds of the array in the indexing operations. +Note the particular out-of-bounds indexing behaviour in JAX: the index is clamped to the bounds of the array in indexing operations. ```{code-cell} ipython3 print(jnp.arange(10)[11]) ``` -Join arrays similar to `torch.cat`. Note the kwarg name: `axis` vs `dim`. +**Joining arrays (similar to `torch.cat`):** Note the kwarg name: `axis` vs `dim`. ```{code-cell} ipython3 --- colab: id: GOdDsYWZg1dN outputId: 4a2e84a6-8EOE --- t1 = jnp.concat([tensor, tensor, tensor], axis=1) print(t1) ``` -Arithmetic operations. Operations below compute the matrix multiplication between two tensors. y1, y2 will have the same value. +**Arithmetic operations:** The operations below compute the matrix multiplication between two tensors. `y1`, `y2` will have the same value: ```{code-cell} ipython3 :id: P8jcElVyYTp7 -# ``tensor.T`` returns the transpose of a tensor +# `tensor.T` returns the transpose of a tensor. y1 = tensor @ tensor.T y2 = jnp.matmul(tensor, tensor.T) assert (y1 == y2).all() -# This computes the element-wise product. z1, z2 will have the same value +# This computes the element-wise product. `z1`, `z2` will have the same value. z1 = tensor * tensor z2 = jnp.multiply(tensor, tensor) assert (z1 == z2).all() ``` -Single-element arrays. If you have a one-element array, for example by aggregating all values of a tensor into one value, you can convert it to a Python numerical value using `.item()`: +**Single-element arrays:** If you have a one-element array, for example after aggregating all values of a tensor into one value, you can convert it to a Python numerical value using `.item()`: ```{code-cell} ipython3 --- colab: id: vNhZvycZbHRW outputId: 7c91b646-a76f-4cbf-a897-398c4dbd8b7a --- agg = tensor.sum() agg_value = agg.item() print(agg_value, isinstance(agg_value, float), isinstance(agg, jax.Array)) ``` tensor.sigmoid(), tensor.softmax(dim=1), tensor.sin(), # ... ## Automatic differentiation with JAX -In this section, we will learn about the fundamental applications of automatic differentiation (autodiff) in JAX. JAX has a pretty general autodiff system, and its API has inspired the `torch.func` module in PyTorch, previously known as “functorch” (JAX-like composable function transforms for PyTorch). +This section covers the basics of [automatic differentiation (autodiff)](https://jax.readthedocs.io/en/latest/automatic-differentiation.html) in JAX. JAX has a pretty general autodiff system, and its API has inspired the `torch.func` module in PyTorch, previously known as “functorch” (JAX-like composable function transforms for PyTorch). -In PyTorch, there is an API to turn on the automatic operations graph recording (e.g., `required_grad` argument and `tensor.backward()`), but in JAX, automatic differentiation is a functional operation, i.e., there is no need to mark arrays with a flag to enable gradient tracking. +In PyTorch, there is an API to turn on automatic operation graph recording (for example, the `requires_grad` argument and `tensor.backward()`). In JAX, automatic differentiation is a functional operation, which means there is no need to mark arrays with a flag to enable gradient tracking.
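+
+As a minimal sketch of this functional style, a gradient function is obtained by transforming a plain Python function:
+
+```{code-cell} ipython3
+# No `requires_grad` flag anywhere: `jax.grad` transforms the function itself.
+df = jax.grad(lambda x: x ** 2)
+print(df(3.0))  # 6.0
+```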
-Let us follow [autodiff PyTorch tutorial](https://pytorch.org/tutorials/beginner/basics/autogradqs_tutorial.html) and consider the simplest one-layer neural network, with input `x`, parameters `w` and `b`, and some loss function. In JAX, this can be defined in the following way: +Let's follow PyTorch's [Automatic differentiation with `torch.autograd`](https://pytorch.org/tutorials/beginner/basics/autogradqs_tutorial.html) tutorial and consider the simplest one-layer neural network, with input `x`, parameters `w` and `b`, and some loss function. In JAX, this can be defined in the following way: ```{code-cell} ipython3 --- colab: id: Tv3wcDquvZV5 outputId: 4e0c5ff2-8d5c-4d31-f0e6-1f1004d9ec8a --- import jax import jax.numpy as jnp -# Input tensor +# Input tensor. x = jnp.ones(5) -# Target output +# Target output. y_true = jnp.zeros(3) -# Initialize random parameters +# Initialize random parameters. seed = 123 key = jax.random.key(seed) key, w_key, b_key = jax.random.split(key, 3) w = jax.random.normal(w_key, (5, 3)) b = jax.random.normal(b_key, (3, )) -# model function +# Model function. def predict(x, w, b): return jnp.matmul(x, w) + b -# Criterion or loss function +# Criterion or loss function. def compute_loss(w, b, x, y_true): y_pred = predict(x, w, b) return jnp.mean((y_true - y_pred) ** 2) loss = compute_loss(w, b, x, y_true) print(loss) +++ {"id": "k3r8od6LtD_9"} -In our example network, `w` and `b` are parameters to optimize and we need to be able to compute the gradients of the loss function with respect to those variables. In order to do that, we use [`jax.grad`](https://jax.readthedocs.io/en/latest/_autosummary/jax.grad.html#jax.grad) function on `compute_loss` function: +In our example network, `w` and `b` are parameters to optimize, and we need to be able to compute the gradients of the loss function with respect to those variables. To do that, we can use the [`jax.grad`](https://jax.readthedocs.io/en/latest/_autosummary/jax.grad.html#jax.grad) transformation on the `compute_loss` function as follows: ```{code-cell} ipython3 --- colab: id: glqqg02VvP-W outputId: 549934e4-823b-48e1-c7f0-04ecd2b6c0d5 --- -# Differentiate `compute_loss` with respect to the 0 and 1 positional arguments: +# Differentiate `compute_loss` with respect to the `0` and `1` positional arguments. w_grad, b_grad = jax.grad(compute_loss, argnums=(0, 1))(w, b, x, y_true) print(f'{w_grad=}') print(f'{b_grad=}') ``` ```{code-cell} ipython3 -# Compute w_grad, b_grad and loss value: +# Compute `w_grad`, `b_grad` and `loss_value`. loss_value, (w_grad, b_grad) = jax.value_and_grad(compute_loss, argnums=(0, 1))(w, b, x, y_true) print(f'{w_grad=}') print(f'{b_grad=}') print(f'{loss_value=}') print(f'{compute_loss(w, b, x, y_true)=}') ``` -### `jax.grad` and PyTrees +### `jax.grad` and pytrees + +JAX uses the [pytree abstraction](https://jax.readthedocs.io/en/latest/working-with-pytrees.html#working-with-pytrees), which provides a uniform system for handling nested containers of array values built from Python containers such as dicts, tuples, and lists. JAX's functional API works easily with these containers. -JAX introduced the [PyTree abstraction](https://jax.readthedocs.io/en/latest/working-with-pytrees.html#working-with-pytrees)(e.g. Python containers like dicts, tuples, lists, etc which provides a uniform system for handling nested containers of array values) and its functional API works easily on these containers.
Let us consider an example where we gathered our example network parameters into a dictionary: +Consider an example where we gather our example network parameters into a dictionary: ```{code-cell} ipython3 net_params = { "weights": w, "bias": b, } def compute_loss2(net_params, x, y_true): y_pred = predict(x, net_params["weights"], net_params["bias"]) return jnp.mean((y_true - y_pred) ** 2) jax.value_and_grad(compute_loss2, argnums=0)({"weights": w, "bias": b}, x, y_true) ``` -The functional API in JAX easily allows us to compute higher order gradients by calling `jax.grad` multiple times on the function. We will not cover this topic in this tutorial, for more details we suggest reading [JAX automatic differentiation tutorial](https://jax.readthedocs.io/en/latest/automatic-differentiation.html). +The functional API in JAX easily allows us to compute higher-order gradients by calling `jax.grad` multiple times on the function. + +### Further reading + +For more details, check out the [Automatic differentiation](https://jax.readthedocs.io/en/latest/automatic-differentiation.html) tutorial on JAX's documentation site. ## Build and train a model -In this section we will learn how to build a simple model using Flax ([`flax.nnx` API](https://flax.readthedocs.io/en/latest/nnx_basics.html)) and optimize its parameters using training data provided by PyTorch dataloader. +In this section, we will learn how to build a simple model using [Flax NNX](https://flax.readthedocs.io/en/latest/nnx_basics.html) and optimize its parameters using training data provided by PyTorch `DataLoader`s. -Model creation with Flax is very similar to PyTorch using the `torch.nn` module. In this example, we will build the ResNet18 model. +Model creation with the Flax API (`nnx.Module`) is very similar to PyTorch's `torch.nn` module. In this example, we will build the ResNet18 model. -### Build ResNet18 model +### Build a ResNet18 model ```{code-cell} ipython3 -# To install Flax: `pip install -U flax treescope optax` +# To install Flax, use `pip install -U flax treescope optax` import jax import jax.numpy as jnp from flax import nnx @@ -596,7 +606,7 @@ class ResNet18(nnx.Module): model = ResNet18(10, rngs=nnx.Rngs(0)) -# Visualize the model architecture +# Visualize the model architecture. nnx.display(model) ``` @@ -613,18 +623,18 @@ y_pred.shape Note that the input array is explicitly in the channels-last memory format. In PyTorch, the typical input tensor to a neural network has channels-first memory format and has shape `(4, 3, 32, 32)` by default. -### Dataflow using Torchvision and PyTorch data loaders +### Data flow using `torchvision` and PyTorch `DataLoader`s Let us now set up training and test data using the CIFAR10 dataset from `torchvision`. -We will create torch dataloaders with collate functions returning NumPy Arrays instead of PyTorch tensors. -Since JAX is a multithreaded framework, using it in multiple processes can cause issues. For this reason, we will avoid creating JAX Arrays in the dataloaders. +We will create PyTorch `DataLoader`s with collate functions that return NumPy arrays instead of PyTorch `Tensor`s. +Since JAX is a multithreaded framework, using it in multiple processes can cause issues. For this reason, we will avoid creating `jax.Array`s in the `DataLoader`s. -As an alternative, one can use [grain](https://github.com/google/grain/tree/main) for data loading and [PIX](https://github.com/google-deepmind/dm_pix) for image data augmentations.
+Alternatively, we can use [grain](https://github.com/google/grain/tree/main) for data loading and [PIX](https://github.com/google-deepmind/dm_pix) for image data augmentations. ```{code-cell} ipython3 -# CIFAR10 training/testing datasets setup +# Setting up the CIFAR10 training/testing datasets. import numpy as np from torchvision.transforms import v2 as T @@ -660,7 +670,7 @@ test_dataset = CIFAR10("./data", train=True, download=False, transform=test_tran ``` ```{code-cell} ipython3 -# Data loaders setup +# Setting up the `DataLoader`s. from torch.utils.data import DataLoader @@ -682,20 +692,19 @@ test_loader = DataLoader( ``` ```{code-cell} ipython3 -# Let us check training dataloader: +# Check the training `DataLoader`: trl_iter = iter(train_loader) batch = next(trl_iter) print(batch[0].shape, batch[0].dtype, batch[1].shape, batch[1].dtype) ``` -Note: when executing the code above you may see this warning: `RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.`. This warning can be ignored as dataloaders are not using JAX in forked processes. +**Note**: When executing the code above you may get a warning message that says: `RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.`. This warning can be ignored as `DataLoader`s are not using JAX in forked processes. ### Model training - -Let us now define the optimizer, loss function, train and test steps using Flax API. -PyTorch users can find the code using Flax NNX API very similar to PyTorch. +Let's define the optimizer, the loss function, and custom `train_step` and `eval_step` functions using the Flax API. +PyTorch users may find the Flax API code below similar to PyTorch. ```{code-cell} ipython3 import optax @@ -743,7 +752,7 @@ def eval_step(model: nnx.Module, metrics: nnx.MultiMetric, batch): Readers may note the [`nnx.jit`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit) decorator of `train_step` and `eval_step` methods which is used to jit-compile the functions. JIT compilation in JAX is explored in the last section of this tutorial. ```{code-cell} ipython3 -# Define helper object to compute train/test metrics +# Define a helper object to compute train/test metrics. metrics = nnx.MultiMetric( accuracy=nnx.metrics.Accuracy(), loss=nnx.metrics.Average('loss'), @@ -758,7 +767,7 @@ metrics_history = { ``` ```{code-cell} ipython3 -# Start the training +# Start the training. 
num_epochs = 3 @@ -802,18 +811,15 @@ for epoch in range(num_epochs): ### Further reading -More details about Flax NNX API, how to save and load the model's state and about available optimizers, we suggest to check out the links below: -- [FLAX NNX](https://flax.readthedocs.io/en/latest/nnx_basics.html) -- [Save & Load model's state](https://flax.readthedocs.io/en/latest/guides/checkpointing.html) -- [Optax](https://optax.readthedocs.io/en/latest/) +You can learn more about Flax and Optax in: - -Other AI/ML tutorials to check out: -- [JAX AI Stack tutorials](https://jax-ai-stack.readthedocs.io/en/latest/tutorials.html) +- [Flax NNX basics](https://flax.readthedocs.io/en/latest/nnx_basics.html) +- [Save and load checkpoints](https://flax.readthedocs.io/en/latest/guides/checkpointing.html) on Flax's documentation site +- [Optax's documentation site](https://optax.readthedocs.io/en/latest/) +- [JAX AI Stack tutorials](https://jax-ai-stack.readthedocs.io/en/latest/tutorials.html) for additional machine learning content ## Just-In-Time (JIT) compilation in JAX - PyTorch users know very well about the eager mode execution of the operations in PyTorch, e.g. the operations are executed one by one without any high-level optimizations on sets of operations. Similarly, almost everywhere in this tutorial we used JAX in the eager mode as well. @@ -853,4 +859,4 @@ jit_matmul_relu(x, y) ### Further reading -- [JAX documentation on Just-in-time compilation](https://jax.readthedocs.io/en/latest/jit-compilation.html) +- [Just-in-time compilation](https://jax.readthedocs.io/en/latest/jit-compilation.html) with `jax.jit` on JAX's documentation site.