Into AI

10 JAX Concepts, Explained Simply

Here’s everything you need to get started with JAX, Google’s high-performance numerical computing library.

Dr. Ashish Bamania
Nov 18, 2025

JAX is a rapidly growing high-performance numerical computing library for Python, developed primarily by Google, with contributions from NVIDIA and the wider open-source community.

There’s an evolving ecosystem of machine learning and deep learning tools around JAX. Notable ones include:

  • Flax: A library for building neural networks

  • Optax: A gradient processing and optimization library

  • Equinox: A library with PyTorch-like syntax for building neural networks

  • Haiku: A library that provides simple, composable abstractions for building neural networks

  • Jraph: A library for building graph neural networks

  • RLax: A library for building RL agents

  • Chex: A library of utilities for writing, testing, and debugging reliable JAX code

  • Orbax: A library for checkpointing and model persistence, used to save and load the training state of ML models

I am sure that many of you have used NumPy and PyTorch for machine learning, but JAX might be unfamiliar to you.

Let’s change this today and help you get started.



Before we start, I want to introduce you to my book, ‘LLMs In 100 Images’.

It is a collection of 100 easy-to-follow visuals that explain the most important concepts you need to master to understand LLMs today.

Grab your copy today at a special 20% discount using this link.


1. JAX Works With Pure Functions

JAX expects the functions it transforms to be pure. This means they must have no side effects (such as modifying global variables) and must always produce the same output for the same input.

This concept comes from a programming paradigm called Functional Programming.

This matters because JAX's transformations (such as jit, grad, and vmap) assume that a function's output depends only on its inputs. Impure functions may raise errors or, worse, silently produce incorrect results.

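# Pure function for adding two numbers
def pure_addition(a, b):
  return a + b

# Impure function for adding two numbers
# (Not safe with JAX transformations: the global side effect is not reliably preserved)
counter = 0

def impure_addition(a, b):
  global counter
  counter += 1  # Side effect: modifies state outside the function
  return a + b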
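To see why purity matters, here is a minimal sketch that wraps the impure function with jax.jit. The name jit_impure_addition is hypothetical and used only for illustration. Because jit runs the Python body only while tracing the function, the global counter is incremented once rather than on every call:

```python
import jax
import jax.numpy as jnp

counter = 0

def impure_addition(a, b):
  global counter
  counter += 1  # Side effect: only executed while JAX traces the function
  return a + b

# Hypothetical name for the JIT-compiled version
jit_impure_addition = jax.jit(impure_addition)

for _ in range(3):
  result = jit_impure_addition(jnp.array(1.0), jnp.array(2.0))

print(result)   # 3.0
print(counter)  # 1, not 3: later calls reuse the cached compiled code
```

Without jit, the counter would reach 3. The same function therefore behaves differently depending on whether it is transformed, which is exactly the kind of silent bug that the purity rule helps you avoid.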

2. JAX’s NumPy is Similar to NumPy (But Not Always)

JAX provides a NumPy-like interface for computations that can automatically and efficiently run on CPU, GPU, or TPU, in local or distributed settings.

For this, JAX uses the XLA (Accelerated Linear Algebra) compiler to translate JAX code into optimised machine code for different hardware.

NumPy, by contrast, runs only on the CPU.
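The following minimal sketch shows how to check which hardware JAX is using. The printed values depend on your installation and machine. With a CPU-only install, you will typically see a single CPU device:

```python
import jax
import jax.numpy as jnp

# Name of the backend JAX picked, e.g. "cpu", "gpu", or "tpu"
print(jax.default_backend())

# All devices visible to JAX on this machine
print(jax.devices())

# Arrays are placed on the default device automatically;
# jax.device_put moves data to a specific device explicitly
x = jnp.arange(5.0)
x_on_first = jax.device_put(x, jax.devices()[0])
print(x_on_first)
```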

The syntax of JAX NumPy is very similar to NumPy.

# An example with NumPy
import numpy as np

print(np.sqrt(4))

# Similar example with JAX NumPy
import jax.numpy as jnp

print(jnp.sqrt(4))

Some other examples that look similar in both JAX NumPy and NumPy are shown below.

import numpy as np
import jax.numpy as jnp

# Create array
np_a = np.array([1.0, 2.0, 3.0])
jnp_a = jnp.array([1.0, 2.0, 3.0])

# Element-wise arithmetic operations
print(np_a + 2)
print(jnp_a + 2)

# Broadcasting
np_b = np.array([[1, 2, 3]])
jnp_b = jnp.array([[1, 2, 3]])

print(np_b + np.arange(3))
print(jnp_b + jnp.arange(3))

# Sum
print(np.sum(np_a))
print(jnp.sum(jnp_a))

# Mean
print(np.mean(np_a)) 
print(jnp.mean(jnp_a))

# Linear algebra (Dot product)
print(np.dot(np_a, np_a)) 
print(jnp.dot(jnp_a, jnp_a))
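Even identical-looking code can differ in small details. For example, arrays created from the same Python floats get different default precisions. NumPy uses float64, while JAX uses float32 unless 64-bit mode is switched on with jax.config.update("jax_enable_x64", True). The sketch below assumes the default configuration:

```python
import numpy as np
import jax.numpy as jnp

np_a = np.array([1.0, 2.0, 3.0])
jnp_a = jnp.array([1.0, 2.0, 3.0])

print(np_a.dtype)   # float64
print(jnp_a.dtype)  # float32 (JAX default when 64-bit mode is off)

# Converting a JAX array back to a NumPy array
back_to_numpy = np.asarray(jnp_a)
print(type(back_to_numpy))  # <class 'numpy.ndarray'>
```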

However, there are a few differences between them, and one major difference is that:

JAX arrays are immutable, and operations on existing arrays return new arrays.

import jax.numpy as jnp

# Create an array
x = jnp.array([1, 2, 3])

# Try to modify the array
x[0] = 10  # Raises a TypeError because JAX arrays are immutable

This is not the case with NumPy, where arrays are mutable.

import numpy as np

# Create an array
x = np.array([1, 2, 3])

# Modify the array
x[0] = 10

To "modify" an array in JAX, we use the indexed update syntax x.at[idx].set(value), which returns a new array with the change applied.

import jax.numpy as jnp

# General form: x.at[idx].set(value) returns a new array in which
# position 'idx' holds 'value'; the original array 'x' is not modified

# Create an array
x = jnp.array([1, 2, 3])

y = x.at[0].set(10)

print(y)  # [10  2  3]
print(x)  # [1 2 3] (unchanged)
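The .at interface supports more than .set. This short sketch uses .add, .multiply, and .get, which belong to the same indexed-update API. Each call returns a new array and leaves x untouched. According to the JAX documentation, inside a jit-compiled function the compiler can often avoid the copy these updates imply:

```python
import jax.numpy as jnp

x = jnp.array([1, 2, 3])

print(x.at[1].add(5))         # [1 7 3]
print(x.at[0:2].multiply(2))  # [2 4 3]
print(x.at[-1].get())         # 3
print(x)                      # [1 2 3] (still unchanged)
```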



3. JIT Compilation

JAX performs Just-In-Time (JIT) compilation using XLA.

JAX first traces a Python function into an intermediate representation. XLA then compiles this representation into optimized machine code for faster execution.
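You can inspect what JAX hands over to XLA with jax.make_jaxpr, which traces a function and returns its intermediate representation (a "jaxpr"). In this minimal sketch, the printout should show a single multiplication (mul) on a float32 array of shape 3. The exact formatting varies between JAX versions:

```python
import jax
import jax.numpy as jnp

def square(x):
  return x * x

# Trace 'square' with an example input and print the resulting jaxpr
jaxpr = jax.make_jaxpr(square)(jnp.ones(3))
print(jaxpr)
```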

from jax import jit

# Non JIT-compiled function
def square(x):
  return x * x 

# JIT-compiled function
@jit 
def jit_square(x):
  return x * x

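The square function shown above runs in eager mode: Python dispatches each JAX operation to the device one at a time.

On the other hand, the jit_square function, decorated with @jit, is compiled as a whole by XLA. The compiler can fuse operations and remove per-operation Python overhead, so after the first call the compiled version is usually much faster, especially for functions that contain many operations.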
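To observe the speed-up yourself, here is a rough timing sketch. A single x * x leaves little for the compiler to fuse, so the sketch uses a hypothetical function f with several element-wise operations. The numbers depend entirely on your hardware and are not reproduced here. Two details matter when timing JAX code: the first jitted call includes compilation, and JAX dispatches work asynchronously, so block_until_ready() is needed before stopping the clock.

```python
import time

import jax
import jax.numpy as jnp

def f(x):
  # Several element-wise ops that XLA can fuse when the function is jitted
  return jnp.sin(x) * jnp.cos(x) + x ** 2

f_jit = jax.jit(f)

x = jnp.linspace(0.0, 1.0, 1_000_000)

# Warm-up call: includes tracing and compilation, so exclude it from timing
f_jit(x).block_until_ready()

def time_per_call(fn, n=20):
  start = time.perf_counter()
  for _ in range(n):
    # Wait for the asynchronous computation to finish before timing the next call
    fn(x).block_until_ready()
  return (time.perf_counter() - start) / n

print(f"eager: {time_per_call(f):.6f} s per call")
print(f"jit:   {time_per_call(f_jit):.6f} s per call")
```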

When the function is first called, the JIT engine:

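  • Traces the function with abstract placeholder values to build an intermediate representation of the computation (a jaxpr)

  • Lowers this representation to XLA, which optimizes it and compiles it into low-level code for the target device

  • Caches the compiled result

  • Reuses the cached version on subsequent calls with the same input shapes and dtypes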
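The caching behavior can be observed directly. In this minimal sketch, the global trace_count is a hypothetical helper that counts how often JAX traces the function. Its increment is a Python side effect, so it only runs during tracing. Calling the function again with the same input shape and dtype reuses the cached executable, while a new shape triggers a fresh trace and compilation:

```python
import jax
import jax.numpy as jnp

trace_count = 0  # Hypothetical counter, used only to observe tracing

@jax.jit
def jit_square(x):
  global trace_count
  trace_count += 1  # Runs only while JAX traces the function
  return x * x

a = jnp.ones(3)

jit_square(a)       # 1st call: trace, compile, cache
jit_square(a + 1)   # Same shape and dtype: cached executable is reused
print(trace_count)  # 1

jit_square(jnp.ones(4))  # New shape: traced and compiled again
print(trace_count)       # 2
```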

© 2025 Dr. Ashish Bamania