10 JAX Concepts, Explained Simply
Here’s everything you need to get started with JAX, Google’s high-performance numerical computing library.
JAX is a rapidly growing, high-performance numerical computing library for Python, developed by Google with contributions from NVIDIA and the open-source community.
There’s an evolving ecosystem of machine learning and deep learning tools around JAX, the most notable ones being:
Flax: A library for building neural networks
Optax: A gradient processing and optimization library
Equinox: A library with PyTorch-like syntax for building neural networks
Haiku: A library that provides simple, composable abstractions for building neural networks
Jraph: A library for building graph neural networks
RLax: A library of building blocks for reinforcement learning (RL) agents
Chex: A library of utilities for writing, testing, and debugging reliable JAX code
Orbax: Checkpointing and model persistence library for saving and loading the training state of ML models
I am sure that many of you have used NumPy and PyTorch for machine learning, but JAX might be unfamiliar to you.
Let’s change this today and help you get started.
Before we start, I want to introduce you to my book, ‘LLMs In 100 Images’.
It is a collection of 100 easy-to-follow visuals that explain the most important concepts you need to master to understand LLMs today.
Grab your copy today at a special 20% discount using this link.
1. JAX Works With Pure Functions
JAX requires the functions it transforms to be pure. This means that they must have no side effects and must always produce the same output for the same input.
This concept comes from a programming paradigm called Functional Programming.
This is necessary for the transformations that JAX performs internally. Using an impure function might throw an error or, worse, silently fail and produce incorrect results, as the sketch after the code below demonstrates.
# Pure function for adding two numbers
def pure_addition(a, b):
    return a + b

# Impure function for adding two numbers (won't work correctly with JAX transformations)
counter = 0

def impure_addition(a, b):
    global counter
    counter += 1
    return a + b
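To make the failure mode concrete, here is a minimal sketch (the function and variable names are made up for illustration; jax.jit is covered in section 3) of how a side effect can be silently dropped once JAX transforms a function:

from jax import jit

calls = 0

@jit
def impure_double(x):
    global calls
    calls += 1      # Side effect: only executes while JAX traces the function
    return x * 2

print(impure_double(1.0))   # First call: traced and compiled, calls becomes 1
print(impure_double(2.0))   # Second call: cached compiled code runs, calls stays 1
print(calls)                # Prints 1, not 2 -- the side effect was silently dropped

Nothing raises an error here; the program simply stops doing what the Python code appears to say, which is exactly the kind of silent failure that pure functions avoid.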
2. JAX’s NumPy is Similar to NumPy (But Not Always)
JAX provides a NumPy-like interface for computations that can automatically and efficiently run on CPU, GPU, or TPU, in local or distributed settings.
For this, JAX uses the XLA (Accelerated Linear Algebra) compiler to translate JAX code into optimized machine code for different hardware.
This differs from NumPy, whose operations always run on the CPU.
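As a quick check of where your computations will run, you can ask JAX which backend and devices it picked up (a minimal sketch; the exact output depends on your installation):

import jax
import jax.numpy as jnp

# The backend JAX selected: 'cpu', 'gpu', or 'tpu'
print(jax.default_backend())

# The devices JAX can place arrays and computations on
print(jax.devices())   # e.g. [CpuDevice(id=0)] on a CPU-only machine

# Arrays created with jax.numpy live on one of these devices
x = jnp.ones((2, 2))
print(x.devices())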
The syntax of JAX NumPy is very similar to NumPy.
# An example with NumPy
import numpy as np
print(np.sqrt(4))

# Similar example with JAX NumPy
import jax.numpy as jnp
print(jnp.sqrt(4))

Some other examples that look similar in both JAX NumPy and NumPy are shown below.
import numpy as np
import jax.numpy as jnp
# Create array
np_a = np.array([1.0, 2.0, 3.0])
jnp_a = jnp.array([1.0, 2.0, 3.0])
# Element-wise arithmetic operations
print(np_a + 2)
print(jnp_a + 2)
# Broadcasting
np_b = np.array([[1, 2, 3]])
jnp_b = jnp.array([[1, 2, 3]])
print(np_b + np.arange(3))
print(jnp_b + jnp.arange(3))
# Sum
print(np.sum(np_a))
print(jnp.sum(jnp_a))
# Mean
print(np.mean(np_a))
print(jnp.mean(jnp_a))
# Linear algebra (Dot product)
print(np.dot(np_a, np_a))
print(jnp.dot(jnp_a, jnp_a))

However, there are a few differences between them, and one major difference is that:
JAX arrays are immutable, and operations on existing arrays return new arrays.
import jax.numpy as jnp
# Create an array
x = jnp.array([1, 2, 3])
# Try to modify the array
x[0] = 10 # Throws an error

This is not the case with NumPy, where arrays are mutable.
import numpy as np
# Create an array
x = np.array([1, 2, 3])
# Modify the array
x[0] = 10

To modify an array in JAX, we instead use the following method, which returns a new array with the changes applied.
# 'y' is the new value that we can set at index 'idx' of the array 'x'
# The result is a new array 'z'
z = x.at[idx].set(y)

# Create an array
x = jnp.array([1, 2, 3])
y = x.at[0].set(10)
print(y) # [10, 2, 3]
print(x) # [1, 2, 3] (unchanged)
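The same .at[] syntax also supports other out-of-place updates besides set. A small sketch (the index values are chosen arbitrarily for illustration):

import jax.numpy as jnp

x = jnp.array([1, 2, 3])

print(x.at[0].add(5))        # [6 2 3]   (add 5 to the element at index 0)
print(x.at[1].multiply(10))  # [1 20 3]  (multiply the element at index 1 by 10)
print(x.at[2].max(100))      # [1 2 100] (elementwise maximum at index 2)

Each call returns a new array and leaves x unchanged, just like set.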
3. JIT Compilation
Just-In-Time (JIT) compilation is performed in JAX using XLA.
XLA compiles the traced JAX operations into optimized machine code for faster execution.
from jax import jit
# Non JIT-compiled function
def square(x):
    return x * x

# JIT-compiled function
@jit
def jit_square(x):
    return x * x

The square function shown above is executed eagerly by the Python interpreter, one operation at a time, which makes it comparatively slow.
On the other hand, the jit_square function that uses the @jit decorator can run significantly faster: it is compiled once as a whole, and XLA can fuse and optimize its operations before running them on the target hardware.
When the function is first called, the JIT engine:
Traces the function and builds an optimized computation graph
Compiles the graph to optimized, low-level XLA code
Caches the compiled result
Subsequent calls with the same input shapes and dtypes then reuse the cached, compiled version, skipping tracing and compilation entirely.
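You can observe this trace-once, reuse-afterwards behaviour directly. Here is a minimal sketch; the print statement is there only to reveal when tracing happens:

import jax.numpy as jnp
from jax import jit

@jit
def jit_square(x):
    print("tracing jit_square")   # Python side effect: runs only during tracing
    return x * x

x = jnp.arange(1_000_000, dtype=jnp.float32)

jit_square(x).block_until_ready()   # 1st call: prints the message, compiles, caches
jit_square(x).block_until_ready()   # 2nd call: silent, the cached XLA code is reused

Calling the function with an array of a different shape or dtype would trigger a fresh trace and compilation.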