JAXing Up Your Machine Learning
Why should you learn a new ML framework? Speed.
Ever felt like your deep learning code is moving at a snail's pace? You're not alone.
While frameworks like TensorFlow and PyTorch have revolutionized AI development, they often come with a certain rigidity and verbosity. Enter JAX, Google's high-performance numerical computing library that's shaking things up.
At least, according to the experts!
JAX is a Python library with a NumPy-like API, designed for high-performance numerical computation and machine learning research. Its core strengths are:
A unified NumPy-like interface for computations that run on CPU, GPU, or TPU, in local or distributed settings.
Built-in Just-In-Time (JIT) compilation via OpenXLA, an open-source machine learning compiler ecosystem.
Efficient evaluation of gradients through its automatic differentiation transformations.
Automatic vectorization, so functions can be mapped efficiently over arrays representing batches of inputs.
This unique approach leads to incredibly concise, flexible, and blazing-fast code.
If you're ready to supercharge your machine learning workflows, write less code, and gain unprecedented control over your models, then JAX is definitely worth a look.
Or… if you just want to learn a new framework because it's fun, which to be honest was my case, this blog might be for you!
Let's dive in :)
Installation
Following the official installation guide, JAX can be installed in two ways. If you're on Linux, Windows, or macOS (CPU only):
```bash
pip install jax
```
If you have a fancy, expensive NVIDIA GPU (on Linux):
```bash
pip install -U "jax[cuda12]"
```
Wait a minute, is this NumPy?
Yeah, so basically JAX is very similar to NumPy in terms of syntax. It's usually imported under the jnp alias:
```python
import jax.numpy as jnp
```
Then we can use JAX much like NumPy code. Almost every familiar NumPy function has a jnp counterpart: jnp.array, jnp.arange, jnp.zeros, jnp.mean, jnp.dot, and so on.
This immediately lowers the barrier to entry, as you're already familiar with the core operations.
Let's see this in action with a quick example. Imagine you want to perform a few basic array operations.
Creating arrays
```python
import jax.numpy as jnp

a = jnp.array([1.0, 2.0, 3.0])
b = jnp.arange(3.0, 6.0)  # From 3.0 up to (but not including) 6.0
# a -> [1. 2. 3.]
# b -> [3. 4. 5.]
```
Basic operations
```python
# Element-wise addition
c = a + b
print("a + b:", c)

# Dot product
d = jnp.dot(a, b)
print("Dot product of a and b:", d)

# Reshaping: create 6 zeros, then reshape them into a 2x3 array
e = jnp.zeros(6).reshape((2, 3))
print("Zeros array, reshaped:", e)

# Transposing
f = e.T
print("Transposed array 'f':\n", f)
```
As you can see, NumPy and JAX are very similar. However, there's a crucial difference between them:
JAX arrays, unlike NumPy arrays, are immutable.
Let's first understand what I mean by immutability. Take for instance the following NumPy code:
```python
import numpy as np

np_arr = np.array([1.0, 2.0, 3.0])
print("Original array 'np_arr':\n", np_arr)
np_arr[0] = 99  # Modifying an element in-place
print("Modified array 'np_arr':\n", np_arr)
```
Output:
```
Original array 'np_arr':
[1. 2. 3.]
Modified array 'np_arr':
[99. 2. 3.]
```
Now, let's see what happens if we try the same with JAX:
```python
import jax.numpy as jnp

jax_arr = jnp.array([1.0, 2.0, 3.0])
print("JAX array 'jax_arr':\n", jax_arr)

try:
    jax_arr[0] = 99
except TypeError as e:
    print(f"Attempting in-place modification raised an error: {e}")
```
Output:
```
JAX array 'jax_arr':
[1. 2. 3.]
Attempting in-place modification raised an error: JAX arrays are immutable and do not support in-place item assignment. Instead of x[idx] = y, use x = x.at[idx].set(y) or another .at[] method: https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html
```
As you can see, attempting to modify jax_arr[0] directly results in a TypeError.
Instead, we use methods like jax_arr.at[<index>].set(<value>). This method returns a new array with the specified changes, leaving the original jax_arr completely untouched.
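Here is a small sketch of a few .at[] updates, using set and add, which are two of the available methods. Each call returns a new array, and x itself never changes:

```python
import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 3.0])

y = x.at[0].set(99.0)   # new array with element 0 replaced
z = x.at[1].add(10.0)   # new array with 10 added to element 1
w = x.at[1:].set(0.0)   # slices work too: zero out everything after index 0

print("x:", x)  # x is still [1. 2. 3.]
print("y:", y)
print("z:", z)
print("w:", w)
```

This functional style doesn't necessarily mean an extra copy. Inside jax.jit-compiled code, the compiler can usually perform such an update in place when the original array isn't used again.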
JAX's Superpowers: grad and jit
Now that we understand JAX's NumPy-like interface and its immutable nature, it's time to unleash the true power.
JAX isn't just a fast array library; it's a library for function transformations. Its two most fundamental and frequently used transformations are:
jax.grad → automatic differentiation
jax.jit → Just-In-Time compilation
jax.grad
If you're doing anything with machine learning, you're constantly dealing with gradients.
Training neural networks, for instance, relies heavily on calculating gradients to update model parameters. Manually deriving gradients, especially for complex functions, is tedious, error-prone, and often impractical. This is where jax.grad shines.
jax.grad takes a Python function that returns a scalar and returns a new function that computes the gradient of the original function. By default, the gradient is taken with respect to the first argument (more on that in Example 3).
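The function's output must be a scalar, but its input can be a whole array. In that case the gradient comes back with the same shape as the input. Here is a minimal sketch; sum_of_cubes is just an illustrative name:

```python
import jax
import jax.numpy as jnp

def sum_of_cubes(v):
    # Scalar-valued function of a vector: f(v) = v_0^3 + v_1^3 + ...
    return jnp.sum(v ** 3)

grad_sum_of_cubes = jax.grad(sum_of_cubes)

v = jnp.array([1.0, 2.0, 3.0])
g = grad_sum_of_cubes(v)  # elementwise derivative: 3 * v**2
print(g.shape, g)         # same shape as v
```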
Example 1:
Let's start with the simple function f(x) = x², whose derivative is 2x:
```python
import jax
import jax.numpy as jnp

def square(x):
    return x ** 2

# Get the gradient function for 'square'
gradient_square = jax.grad(square)

# Calculate the gradient at x = 3.0
# The derivative of x^2 is 2x, so at x = 3 it should be 6.
print(f"Original function 'square(3.0)': {square(3.0)}")
print(f"Gradient of square(x) at x=3.0: {gradient_square(3.0)}")
```
Output:
```
Original function 'square(3.0)': 9.0
Gradient of square(x) at x=3.0: 6.0
```
Example 2:
Now for a slightly more complex function, f(x) = sin(x²), whose derivative is 2x·cos(x²):
```python
def sin_square(x):
    return jnp.sin(x ** 2)

# Get its gradient function
gradient_sin_square = jax.grad(sin_square)

# Calculate the gradient at x = 1.0
print(f"\nOriginal function 'sin_square(1.0)': {sin_square(1.0)}")
print(f"Gradient of sin_square(x) at x=1.0: {gradient_sin_square(1.0)}")
# Manual check: 2x * cos(x^2) evaluated at x = 1.0
print(f"Expected: {2 * 1.0 * jnp.cos(1.0 ** 2)}")
```
Output:
```
Original function 'sin_square(1.0)': 0.8414709568023682
Gradient of sin_square(x) at x=1.0: 1.0806045532226562
Expected: 1.0806045532226562
```
Example 3:
One final example, for a function with multiple arguments:
```python
def sum_of_squares(x, y):
    return x**2 + y**2

# By default, jax.grad computes the gradient with respect to the first argument.
grad_x_sum_of_squares = jax.grad(sum_of_squares)
print(f"\nGradient of sum_of_squares w.r.t. x at (x=2, y=3): {grad_x_sum_of_squares(2.0, 3.0)}")  # Should be 2*2 = 4

# To differentiate with respect to several arguments, pass their positions via 'argnums'.
grad_xy_sum_of_squares = jax.grad(sum_of_squares, argnums=(0, 1))
grad_x, grad_y = grad_xy_sum_of_squares(2.0, 3.0)
print(f"Gradient of sum_of_squares w.r.t. x and y at (x=2, y=3): x_grad={grad_x}, y_grad={grad_y}")  # Should be 4, 6
```
Output:
```
Gradient of sum_of_squares w.r.t. x at (x=2, y=3): 4.0
Gradient of sum_of_squares w.r.t. x and y at (x=2, y=3): x_grad=4.0, y_grad=6.0
```
The beauty of jax.grad is that it works seamlessly with complex compositions of jnp operations and even Python control flow (like if/else statements and loops), as long as the functions are "pure" (no side effects, relying only on their inputs).
We'll look at pure functions in more depth in the Pure Functions section below.
This makes it an indispensable tool for optimization, neural network training, and any scenario requiring derivatives.
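Here is a small sketch of the control-flow claim, using two toy functions (piecewise and repeated_square are illustrative names). jax.grad differentiates through a plain Python if/else and a plain Python for loop without any special syntax:

```python
import jax

def piecewise(x):
    # Python if/else on the input value: fine under jax.grad
    if x > 0:
        return x ** 2
    return -x

def repeated_square(x):
    # Python for loop: squares x three times, i.e. computes x ** 8
    for _ in range(3):
        x = x * x
    return x

dpiece = jax.grad(piecewise)
print(dpiece(3.0))   # derivative of x**2 at 3.0: 6.0
print(dpiece(-2.0))  # derivative of -x: -1.0

d_repeated = jax.grad(repeated_square)(1.5)
print(d_repeated)    # 8 * 1.5**7 = 136.6875
```

One caveat: if you wrap piecewise in jax.jit, branching on a traced value like `if x > 0` raises an error. In that case, jax.lax.cond is the jit-friendly way to express the branch.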
jax.jit
While JAX's operations are already fast, truly unlocking their potential on GPUs or TPUs (and even CPUs for certain workloads) requires compilation.
This is where jax.jit comes in. It can be used as a decorator or as a plain function: it takes a Python function and compiles it into highly optimized machine code using XLA (Accelerated Linear Algebra).
The "Just-In-Time" part means the compilation happens the first time the function is called with specific input shapes and types. Subsequent calls with the same shapes and types will then run the compiled, much faster version.
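You can watch this caching happen with a small sketch that relies on a trace-time side effect. The list append below only runs when JAX traces the function: on the first call, and again whenever a new input shape (or dtype) shows up. trace_log and double are illustrative names, and this is a debugging trick rather than something to do in real code:

```python
import jax
import jax.numpy as jnp

trace_log = []  # illustrative: records each time JAX traces 'double'

@jax.jit
def double(x):
    # Python side effect: only executes while JAX traces the function
    trace_log.append(x.shape)
    return x * 2

double(jnp.ones(3))  # first call with shape (3,): traced and compiled
double(jnp.ones(3))  # same shape and dtype: cached version reused, no trace
double(jnp.ones(4))  # new shape (4,): traced and compiled again
print(trace_log)     # one entry per trace
```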
Example:
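We'll compute the sum of a computationally intensive function over the integers 1 to N:
$$S = \sum_{x=1}^{N} \left( \frac{x \sin x}{\cosh x} + \ln(x + 1) \right)$$
where N = 1,000,000.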
Let's write code that compares a regular (non-JIT) execution, the first JIT call, and a subsequent JIT call:
```python
import time

import jax
import jax.numpy as jnp

# A computationally intensive function
def sum_large_array(x):
    return jnp.sum(x * jnp.sin(x) / jnp.cosh(x) + jnp.log(x + 1))

# Create a large JAX array
large_array = jnp.arange(1, 1_000_001, dtype=jnp.float32)  # From 1 to 1 million

# JAX dispatches work asynchronously, so we call block_until_ready()
# to time the actual computation rather than just its dispatch.

# Test without JIT
start_time = time.time()
result_non_jit = sum_large_array(large_array).block_until_ready()
end_time = time.time()
print(f"Non-JIT execution time: {end_time - start_time:.4f} seconds")

# Test with JIT
# Apply jax.jit as a decorator or directly wrap the function
@jax.jit
def sum_large_array_jitted(x):
    return jnp.sum(x * jnp.sin(x) / jnp.cosh(x) + jnp.log(x + 1))

# First call (compilation happens here)
start_time = time.time()
result_jit_first_call = sum_large_array_jitted(large_array).block_until_ready()
end_time = time.time()
print(f"JIT first call (with compilation) time: {end_time - start_time:.4f} seconds")

# Subsequent calls (using compiled code)
start_time = time.time()
result_jit_subsequent_call = sum_large_array_jitted(large_array).block_until_ready()
end_time = time.time()
print(f"JIT subsequent call time: {end_time - start_time:.4f} seconds")

# Verify results are the same
print(f"Results match: {jnp.allclose(result_non_jit, result_jit_first_call)}")
```
Output:
```
Non-JIT execution time: 0.2244 seconds
JIT first call (with compilation) time: 0.0381 seconds
JIT subsequent call time: 0.0003 seconds
Results match: True
```
As you can see, we get a dramatic speedup after the first JIT call (exact timings will depend on your hardware). The initial call includes the compilation overhead, but subsequent calls reuse the optimized compiled code, which can be orders of magnitude faster, especially on accelerators like GPUs.
The synergy between jax.grad and jax.jit is truly powerful.
You can define a function, get its gradient, and then JIT-compile the gradient function for maximum performance, all with just a few lines of code. This is the core of how JAX enables high-performance machine learning research and production.
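Here is a minimal sketch of that workflow. It assumes a toy linear model with a mean-squared-error loss and a hand-picked learning rate of 0.1; loss, grad_loss, and the data are illustrative, not from any library:

```python
import jax
import jax.numpy as jnp

def loss(w, x, y):
    # Mean squared error of a toy linear model: predictions = x @ w
    pred = x @ w
    return jnp.mean((pred - y) ** 2)

# Differentiate w.r.t. w (argument 0), then JIT-compile the gradient function
grad_loss = jax.jit(jax.grad(loss))

x = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
y = jnp.array([1.0, 2.0, 3.0])  # generated by w = [1, 2] with no noise
w = jnp.zeros(2)

# Plain gradient descent: the first call compiles, later calls reuse the compiled code
for _ in range(200):
    w = w - 0.1 * grad_loss(w, x, y)

print(w)  # should approach [1. 2.]
```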
Pure Functions
Before we dive deeper into JAX's other transformations, it's essential to understand the concept of pure functions.
In JAX, everything is about transforming functions. For these transformations (like grad and jit) to work reliably and efficiently, the functions you pass to them should ideally be pure.
So, what makes a function "pure"? A pure function has two main characteristics:
Deterministic Output:
Given the same input arguments, a pure function will always produce the exact same output.
It does not depend on any hidden state, global variables, or external factors that might change between calls.
No Side Effects:
A pure function does not cause any observable changes outside its own scope. This means it doesn't:
Modify its input arguments (this is where JAX's array immutability comes in!).
Modify global variables.
Perform I/O operations (like printing to console, reading from files, network requests).
Mutate data structures in place.
Let's illustrate the difference.
Example of an impure function:
```python
global_counter = 0

def impure_add_and_increment(x):
    global global_counter
    global_counter += 1  # Side effect: modifies a global variable
    # This print statement shows the internal state *during* the function call
    print(f"  (Inside function) global_counter after increment: {global_counter}")
    return x + global_counter

print("--- Impure Function Example ---")
print(f"Initial global_counter: {global_counter}")

# First call to the impure function
result_1 = impure_add_and_increment(5)
print(f"First call impure_add_and_increment(5) returned: {result_1}")
print(f"global_counter after first call: {global_counter}")  # Showing the side effect

# Second call to the impure function
result_2 = impure_add_and_increment(5)
print(f"Second call impure_add_and_increment(5) returned: {result_2}")
print(f"global_counter after second call: {global_counter}\n")  # Showing the side effect again

print(f"Final value of global_counter: {global_counter}")  # Final check of the modified global state
```
Output:
```
--- Impure Function Example ---
Initial global_counter: 0
  (Inside function) global_counter after increment: 1
First call impure_add_and_increment(5) returned: 6
global_counter after first call: 1
  (Inside function) global_counter after increment: 2
Second call impure_add_and_increment(5) returned: 7
global_counter after second call: 2

Final value of global_counter: 2
```
So the function above is considered impure because it violates both conditions of a pure function:
Not deterministic: Its output depends on global_counter, a mutable global variable outside the function's parameters. Calling impure_add_and_increment(5) twice yields 6 and then 7 because global_counter changed between calls, demonstrating non-determinism.
Has side effects: It directly modifies global_counter (a variable outside its local scope) and performs a print operation (I/O). These are observable changes to the program's state or environment that persist after the function finishes.
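In JAX, impurity isn't just a style issue; it can give surprising results. Here is a small sketch: a jitted function that reads a global variable bakes in the value it saw when it was traced. Later changes to that global are silently ignored while the cached compiled version is reused. scale and scaled are illustrative names:

```python
import jax
import jax.numpy as jnp

scale = 2.0  # global state read by the function below

@jax.jit
def scaled(x):
    # Impure: the result depends on the global 'scale', not only on x
    return x * scale

first = scaled(jnp.array(1.0))   # traced now: 'scale' = 2.0 is baked into the compiled code
scale = 10.0                     # change the global...
second = scaled(jnp.array(1.0))  # ...but the cached compiled version is reused
print(first, second)             # both are 2.0
```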
Example of a pure function
```python
def pure_add(x, y):
    return x + y  # Deterministic, no side effects

def pure_transform_list(my_list):
    # This creates a NEW list; it doesn't modify my_list in-place
    return [item * 2 for item in my_list]

print("\n--- Pure Function Example ---")
print(f"pure_add(5, 3) returned: {pure_add(5, 3)}")  # Always 8
print(f"pure_add(5, 3) returned: {pure_add(5, 3)}")  # Still 8
print("Notice how pure_add always gives the same result for the same inputs and causes no external changes.")

original_list = [1, 2, 3]
new_list = pure_transform_list(original_list)
print("\nOriginal list (unchanged after pure transformation):", original_list)
print("New list (transformed by pure function):", new_list)
```
Output:
```
--- Pure Function Example ---
pure_add(5, 3) returned: 8
pure_add(5, 3) returned: 8
Notice how pure_add always gives the same result for the same inputs and causes no external changes.
Original list (unchanged after pure transformation): [1, 2, 3]
New list (transformed by pure function): [2, 4, 6]
```
As you can see from the output above, the functions pure_add and pure_transform_list are pure. But why is that?
Deterministic output:
pure_add(x, y) always returns x + y. Given (5, 3), it returns 8 every single time, regardless of any external state or previous calls.
pure_transform_list(my_list) always produces a new list where each element is doubled. For [1, 2, 3], it always returns [2, 4, 6].
No side effects:
pure_add only uses its input arguments to compute a result and returns that result. It doesn't modify any global variables, perform I/O, or change its inputs.
pure_transform_list crucially creates a new list ([item * 2 for item in my_list]) instead of modifying my_list in place. This ensures that original_list remains unchanged, demonstrating no side effects on its inputs or external state.
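Coming back to the impure counter from earlier, here is one way to make it pure. The counter becomes an explicit input, and its new value is returned instead of mutated, so the function is also safe to wrap in jax.jit. pure_add_and_increment is an illustrative name:

```python
import jax

@jax.jit
def pure_add_and_increment(x, counter):
    # Pure: everything the function needs comes in as arguments,
    # and the "updated" state goes out as a return value
    new_counter = counter + 1
    return x + new_counter, new_counter

counter = 0
result_1, counter = pure_add_and_increment(5, counter)
result_2, counter = pure_add_and_increment(5, counter)
print(result_1, result_2, counter)  # 6 7 2
```

The state still changes from call to call, but only because we pass a different counter in. Calling it with the same (x, counter) always gives the same result.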
More JAX Transformations: vmap and pmap
Beyond grad for differentiation and jit for compilation, JAX provides two more essential transformations that dramatically simplify writing efficient and scalable code:
jax.vmap for vectorizing functions, and jax.pmap for parallelizing computations across multiple devices.
jax.vmap
Often, you write a function that operates on a single data point (e.g., transforming a single vector or applying an operation to one image). When you need to apply that same function to a batch of data points, you'd typically write a loop or manually reshape your data and use broadcasting. This can be cumbersome and error-prone.
jax.vmap (vectorized map) solves this elegantly. It takes a function that operates on a single example and transforms it into a new function that operates on a batch of examples.
It automatically handles the batching and unbatching, effectively "vectorizing" the code without you having to manually manage array dimensions.
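One knob worth knowing is in_axes, which tells vmap which axis of each argument to map over. None means "don't batch this argument, reuse it for every example". Here is a minimal sketch, assuming a toy affine function and a weight matrix shared across the batch (affine, w, and xs are illustrative names):

```python
import jax
import jax.numpy as jnp

def affine(x, w):
    # Works on ONE example: x has shape (3,), w has shape (3, 2)
    return x @ w

w = jnp.ones((3, 2))
xs = jnp.arange(12.0).reshape(4, 3)  # a batch of 4 examples

# in_axes=(0, None): map over axis 0 of xs, pass the same w to every example
batched_affine = jax.vmap(affine, in_axes=(0, None))
out = batched_affine(xs, w)
print(out.shape)  # (4, 2)
```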
Example 1:
Imagine we have the following function:
```python
import jax
import jax.numpy as jnp

# A function that operates on a single pair of vectors
def elementwise_multiply_add(x, y):
    # Imagine this is a complex operation on single data points
    return x * 2 + y / 3

# Define individual inputs
x_single = jnp.array([1.0, 2.0, 3.0])
y_single = jnp.array([4.0, 5.0, 6.0])
print(f"Result for single inputs: {elementwise_multiply_add(x_single, y_single)}")
```
Output:
```
Result for single inputs: [3.3333335 5.6666665 8. ]
```
Now, imagine that instead of individual inputs we have a batch of data:
```python
# x_batch: batch of 2 vectors, each of size 3
x_batch = jnp.array([[1.0, 2.0, 3.0],
                     [4.0, 5.0, 6.0]])
# y_batch: batch of 2 vectors, each of size 3
y_batch = jnp.array([[10.0, 11.0, 12.0],
                     [13.0, 14.0, 15.0]])
```
We could handle this with a simple for loop, which we can call the manual way of doing this computation:
```python
results = []
for i in range(x_batch.shape[0]):
    results.append(elementwise_multiply_add(x_batch[i], y_batch[i]))
manual_batch_result = jnp.array(results)
print(f"Manual batch result: {manual_batch_result}")
```
Output:
```
Manual batch result: [[ 5.333333 7.666667 10. ]
 [12.333334 14.666666 17. ]]
```
This is where jax.vmap comes in handy. Instead of looping manually, we can simply write:
```python
batched_elementwise_multiply_add = jax.vmap(elementwise_multiply_add)
vmap_result = batched_elementwise_multiply_add(x_batch, y_batch)
print(f"vmap-batched result: {vmap_result}")
```
Output:
```
vmap-batched result: [[ 5.333333 7.666667 10. ]
 [12.333334 14.666666 17. ]]
```
Example 2:
jax.vmap can also be combined with jax.grad and jax.jit:
```python
@jax.jit
@jax.vmap
def batched_loss_gradient(predictions, targets):
    # This function operates on a single prediction-target pair
    loss_fn = lambda p, t: jnp.mean((p - t)**2)  # MSE loss
    # Gradient of the loss w.r.t. predictions
    return jax.grad(loss_fn)(predictions, targets)
```
Applying this function to a batch of data yields:
```python
# Batch data
preds_batch = jnp.array([[1.0, 2.0], [3.0, 4.0]])
targets_batch = jnp.array([[1.1, 2.1], [3.3, 4.5]])
batched_grads = batched_loss_gradient(preds_batch, targets_batch)
print(f"\nBatched gradients using vmap and grad: \n{batched_grads}")
```
Output:
```
Batched gradients using vmap and grad:
[[-0.10000002 -0.0999999 ]
 [-0.29999995 -0.5 ]]
```
jax.pmap
When you have access to multiple accelerators (like multiple GPUs in a single machine or even across multiple machines in a cluster), jax.pmap (parallel map) allows you to automatically parallelize computations across these devices.
pmap is similar to vmap in that it maps a function over a batch of inputs, but it specifically distributes these computations across available devices.
It's designed for data parallelism, where each device processes a slice of the input data independently, and then the results are often aggregated (e.g., for model averaging or gradient synchronization).
To use pmap, your data needs to be sharded (split) across devices. pmap maps over the first axis of its inputs and sends each slice along that axis to a different device, so the size of that axis can't exceed the number of available devices.
Example:
```python
import jax
import jax.numpy as jnp

# Check available devices
print(f"\nAvailable devices: {jax.devices()}")
num_devices = len(jax.devices())

if num_devices < 2:
    print("\nSkipping pmap example: Requires at least 2 devices (e.g., multiple GPUs or simulated CPU devices).")
else:
    # A simple function that computes a mean
    def device_mean(x):
        return jnp.mean(x)

    # One row of 4 values per device, so the leading axis matches the device count
    data = jnp.arange(num_devices * 4, dtype=jnp.float32).reshape(num_devices, 4)
    print(f"\nData for pmap (sharded across devices):\n{data}")

    # Use pmap to run device_mean on each slice of data on each device
    pmapped_mean = jax.pmap(device_mean, axis_name='devices')
    # Each device gets one row of 'data' (data[0] on device 0, data[1] on device 1, ...)
    results_per_device = pmapped_mean(data)
    print(f"Mean calculated on each device:\n{results_per_device}")

    # Often, you'll want to combine results from all devices.
    def sum_across_devices(x):
        # x is the local row on each device: sum it locally,
        # then psum adds the local sums together across all devices
        return jax.lax.psum(jnp.sum(x), axis_name='devices')

    # pmap this new function, using the same axis_name
    pmapped_sum_across = jax.pmap(sum_across_devices, axis_name='devices')
    # Every device ends up holding the same total, so we read it from device 0
    total_sum = pmapped_sum_across(data)
    print(f"Total sum across all devices (collective operation): {total_sum[0]}")
    print(f"Verified total sum: {jnp.sum(data)}")
```
Sadly, the example above only runs if you have at least two devices (for example, a machine with multiple GPUs). If you don't, just skip it for now!
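If you don't have multiple GPUs, one option for experimenting is to have XLA split your CPU into several simulated devices. You do this with the XLA_FLAGS environment variable and the --xla_force_host_platform_device_count flag. Here is a minimal sketch, assuming 4 simulated devices. The flag has to be set before JAX is imported (it has no effect afterwards), and the sketch selects the CPU devices explicitly so it also works on a machine whose default backend is a GPU:

```python
import os

# Must be set before JAX is imported, otherwise it is ignored
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=4"

import jax
import jax.numpy as jnp

cpu_devices = jax.devices("cpu")
print(f"Simulated CPU devices: {cpu_devices}")

# One row per simulated device: [0, 1], [2, 3], [4, 5], [6, 7]
data = jnp.arange(8.0).reshape(4, 2)

# Run the per-row sum on the simulated CPU devices explicitly
per_device_sums = jax.pmap(lambda x: jnp.sum(x), devices=cpu_devices)(data)
print(per_device_sums)  # one sum per device
```

Simulated devices all share the same CPU, so don't expect a speedup. This setup is useful for checking that your pmap code and collectives like psum behave correctly before running them on real hardware.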
✨ The End
If you’ve read this far…
First of all, thank you! 🙏
I hope this post helped clarify a few things and made your Deep Learning journey a bit easier.
More deep dives on JAX are on the way, so stay curious and keep building!
“But the answer isn’t always to do more. Problems can’t always be solved by running faster. Sometimes the simplest solution is the best.”
— Barry Allen “The Flash”
Until next time,
Igor L.R. Azevedo


