In this course, we will use a few functionalities of JAX. These functionalities are:
jax.numpy (as jnp)
jax.jacrev and jax.jacfwd
jax.jit
partial (from functools)
defining an if-else condition in a function (jax.lax.cond)
jax.numpy (as jnp)
JAX provides a NumPy-like interface called jax.numpy. It mirrors most of the NumPy API, but its functions can also be differentiated, JIT-compiled, and vectorized by JAX.
```python
import jax.numpy as jnp

# Example: compute strain energy for a 1D spring
k = 10.0
u = 0.2
energy = 0.5 * k * u**2
print("Spring energy =", energy)

# Using jnp
u_jax = jnp.array(0.2)
energy_jax = 0.5 * k * u_jax**2
print("Spring energy (jax) =", energy_jax)
```
Spring energy = 0.20000000000000004
Spring energy (jax) = 0.20000002
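The two printed energies differ in the last digits because plain Python and NumPy use 64-bit floats, while JAX creates 32-bit floats by default. The short check below shows this. It assumes JAX's default configuration, where 64-bit mode is not enabled. If you need double precision, JAX's `jax_enable_x64` configuration flag switches the default to 64-bit. It should be set at startup, before any arrays are created.

```python
import numpy as np
import jax.numpy as jnp

# Assumes the default JAX configuration (64-bit mode not enabled)
u_np = np.array(0.2)
u_jax = jnp.array(0.2)
print(u_np.dtype, u_jax.dtype)  # NumPy: float64, JAX: float32

k = 10.0
energy_np = 0.5 * k * u_np**2
energy_jax = 0.5 * k * u_jax**2

# The two results agree to single-precision accuracy
print(abs(float(energy_jax) - float(energy_np)))
```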
Updating and Accessing Values in jax.numpy
JAX arrays are immutable: you cannot directly change an entry.
Instead, you use the .at operator, which returns a new array with modifications.
```python
import jax.numpy as jnp

x = jnp.array([10, 20, 30, 40])

# Change the value at index 2 (third entry) to 99
x_new = x.at[2].set(99)

print("Original:", x)      # [10 20 30 40]
print("Updated :", x_new)  # [10 20 99 40]
```
Original: [10 20 30 40]
Updated : [10 20 99 40]
Incrementing values with .add()
```python
y = jnp.array([1.0, 2.0, 3.0, 4.0])

# Add 5 to element at index 1
y_new = y.at[1].add(5.0)

print("Original:", y)      # [1. 2. 3. 4.]
print("Updated :", y_new)  # [1. 7. 3. 4.]
```
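`.at[...].add()` also accepts an array of indices. When an index repeats, the contributions are accumulated rather than overwritten, which is the behaviour FEM assembly needs. The sketch below uses a hypothetical mesh of 3 nodes and 2 two-node elements. The DOF map and element force values are made up for illustration.

```python
import jax.numpy as jnp

# Hypothetical 1D mesh: element 0 connects DOFs (0, 1), element 1 connects DOFs (1, 2)
dofs = jnp.array([0, 1, 1, 2])
f_elem = jnp.array([1.0, -1.0, 2.0, -2.0])  # made-up element force contributions

f_global = jnp.zeros(3)
# DOF 1 appears twice, so it receives -1.0 + 2.0 = 1.0
f_global = f_global.at[dofs].add(f_elem)
print("Assembled force:", f_global)

# Reading a value uses ordinary NumPy-style indexing
print("Force at DOF 1:", f_global[1])
```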
jax.jacrev and jax.jacfwd
jax.jacrev → reverse-mode differentiation (efficient for many inputs and few outputs, e.g. a scalar energy of many displacements).
jax.jacfwd → forward-mode differentiation (efficient for few inputs and many outputs).
In mechanics:
Internal force = derivative of energy w.r.t. displacement.
Stiffness matrix = derivative of force w.r.t. displacement.
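For a linear spring system with a symmetric stiffness matrix \(\mathbf{K}\), these two statements reduce to familiar expressions. For the 2D spring in the next example, \(\mathbf{K} = \text{diag}(10, 5)\) and \(\boldsymbol{u} = (0.1, 0.2)\), so the internal force should come out as \(\mathbf{K}\boldsymbol{u} = (1.0, 1.0)\).

```latex
E(\boldsymbol{u}) = \tfrac{1}{2}\,\boldsymbol{u}^{T}\mathbf{K}\,\boldsymbol{u}, \qquad
\boldsymbol{f}_{\text{int}} = \frac{\partial E}{\partial \boldsymbol{u}} = \mathbf{K}\boldsymbol{u}, \qquad
\frac{\partial \boldsymbol{f}_{\text{int}}}{\partial \boldsymbol{u}} = \mathbf{K}
```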
Each of these functions returns another function. The new function takes the same arguments as the original one but returns the derivative (Jacobian) of the original function with respect to its first argument. Use the argnums parameter to differentiate with respect to a different argument.
```python
import jax
import jax.numpy as jnp

# 2D linear spring system
def energy(u):
    k = jnp.array([[10.0, 0.0], [0.0, 5.0]])
    return 0.5 * u @ k @ u

# Internal force = ∂E/∂u
compute_internal_force = jax.jacrev(energy)
internal_force = compute_internal_force(jnp.array([0.1, 0.2]))
print("Internal force =", internal_force)

# Stiffness = ∂f/∂u
compute_stiffness = jax.jacfwd(compute_internal_force)
stiffness = compute_stiffness(jnp.array([0.1, 0.2]))
print("Stiffness matrix =\n", stiffness)
```
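The stiffness in the example above is the second derivative of the energy, computed by forward-over-reverse composition. JAX also provides `jax.hessian`, which is documented as `jacfwd(jacrev(f))`. The check below shows that both routes give the same matrix for the 2D spring.

```python
import jax
import jax.numpy as jnp

def energy(u):
    k = jnp.array([[10.0, 0.0], [0.0, 5.0]])
    return 0.5 * u @ k @ u

u = jnp.array([0.1, 0.2])

# Stiffness as derivative of the internal force (forward over reverse)
K_composed = jax.jacfwd(jax.jacrev(energy))(u)

# The same matrix of second derivatives via jax.hessian
K_hessian = jax.hessian(energy)(u)

print("Stiffness (composed) =\n", K_composed)
print("Stiffness (hessian)  =\n", K_hessian)
```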
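jax.jit
jax.jit compiles a function with XLA so that it runs faster. The first call traces and compiles the function. Later calls with inputs of the same shape and dtype reuse the compiled version. This pays off when the same computation is repeated many times (e.g., in FEM assembly or Newton iterations).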
```python
@jax.jit
def compute_energy(u):
    k = 10.0
    return 0.5 * k * u**2

print(compute_energy(0.5))
```
1.25
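The statement that compiled code is reused for inputs of the same shape and dtype can be observed directly. In the sketch below, a Python side effect (appending to a list) runs only while JAX traces the function, not on cached calls. Side effects like this are normally avoided inside jitted functions. The `trace_count` list is used here only to make tracing visible.

```python
import jax
import jax.numpy as jnp

trace_count = []  # illustration only: records how often JAX traces the function

@jax.jit
def compute_energy(u):
    trace_count.append(1)  # runs at trace time, skipped when the compiled version is reused
    k = 10.0
    return 0.5 * k * u**2

compute_energy(0.5)                     # first call: trace and compile
compute_energy(0.7)                     # same shape and dtype: compiled code is reused
compute_energy(jnp.array([0.1, 0.2]))   # new shape (2,): traced and compiled again

print("Number of traces:", len(trace_count))
```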
jax.jvp: Jacobian-Vector Product
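For matrix-free solvers, we will see that we do not need to materialize the stiffness matrix; we only need its action on a vector. This is called the Jacobian-vector product (JVP). For the internal force \(\boldsymbol{f}(\boldsymbol{u})\), whose Jacobian is the stiffness matrix \(\mathbf{K} = \partial \boldsymbol{f}/\partial \boldsymbol{u}\), the JVP with a vector \(\boldsymbol{v}\) is
\[
\mathbf{K}\boldsymbol{v} = \frac{\partial \boldsymbol{f}}{\partial \boldsymbol{u}}\bigg|_{\boldsymbol{u}} \boldsymbol{v}
\]
It is computed with jax.jvp(f, (u,), (v,)), which returns two values:
the value of the differentiated function at the point \(\boldsymbol{u}\), i.e. \(\boldsymbol{f}(\boldsymbol{u})\)
the action of \(\mathbf{K}\) on \(\boldsymbol{v}\), i.e. \(\mathbf{K}\boldsymbol{v}\)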
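Here is a minimal sketch of both return values, reusing the 2D spring energy with \(\mathbf{K} = \text{diag}(10, 5)\). The sketch then plugs the JVP into a matrix-free conjugate-gradient solve. The tangent `v` and load vector `b` are made-up values. The helper name `stiffness_action` is hypothetical. `jax.scipy.sparse.linalg.cg` accepts a function as its linear operator, and CG applies here because \(\mathbf{K}\) is symmetric positive definite.

```python
import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg

K = jnp.array([[10.0, 0.0], [0.0, 5.0]])  # used to define the energy and to check results

def energy(u):
    return 0.5 * u @ K @ u

internal_force = jax.grad(energy)  # f(u) = dE/du

u = jnp.array([0.1, 0.2])
v = jnp.array([1.0, -1.0])  # made-up tangent vector

# jax.jvp returns f(u) and K(u) @ v without ever forming K(u)
f_u, Kv = jax.jvp(internal_force, (u,), (v,))
print("f(u) =", f_u)
print("K v  =", Kv)

# Hypothetical matrix-free operator: v -> K(u) v
def stiffness_action(v):
    return jax.jvp(internal_force, (u,), (v,))[1]

b = jnp.array([2.0, 1.0])  # made-up load vector
x, _ = cg(stiffness_action, b)  # solves K x = b using only stiffness_action
print("x =", x)
```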
functools.partial: fix arguments of functions
We often want to “freeze” some parameters of a function (e.g., material constants) so that JAX only sees displacement as the variable.
```python
import jax
from functools import partial

def spring_energy(u, k):
    return 0.5 * k * u**2

# Fix spring stiffness k=10
energy_fixed_k = partial(spring_energy, k=10.0)
force = jax.grad(energy_fixed_k)(0.3)
print("Force with k=10 =", force)
```
Force with k=10 = 3.0
Putting it all together
```python
import jax
import jax.numpy as jnp
from functools import partial

@jax.jit
def bar_energy(u, k, L):
    """Energy of a bar: E = 0.5*k*(u/L)^2 * L"""
    strain = u / L
    return 0.5 * k * strain**2 * L

# Fix k and L
bar_energy_fixed = partial(bar_energy, k=100.0, L=1.0)

# Internal force
compute_internal_force = jax.jacrev(bar_energy_fixed)
force = compute_internal_force(0.1)

# Stiffness
compute_stiffness = jax.jacfwd(compute_internal_force)
stiffness = compute_stiffness(0.1)

print("Force =", force)
print("Stiffness =", stiffness)
```
Force = 10.0
Stiffness = 100.0
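The same building blocks drive a Newton iteration, one of the settings where jax.jit pays off. The sketch below assumes a hypothetical hardening spring with energy \(E = \tfrac{1}{2}ku^2 + \tfrac{1}{4}cu^4\) and made-up values k = 100, c = 1000 and external load 20. It repeatedly updates \(u \leftarrow u - r(u)/K(u)\), where \(r = f_{\text{int}}(u) - f_{\text{ext}}\). This is an illustrative solver, not the course's own implementation.

```python
import jax
from functools import partial

def spring_energy(u, k, c):
    """Hypothetical hardening spring: E = 0.5*k*u^2 + 0.25*c*u^4."""
    return 0.5 * k * u**2 + 0.25 * c * u**4

# Freeze the (made-up) material constants
energy_fixed = partial(spring_energy, k=100.0, c=1000.0)

compute_internal_force = jax.jit(jax.jacrev(energy_fixed))             # f_int(u) = dE/du
compute_stiffness = jax.jit(jax.jacfwd(jax.jacrev(energy_fixed)))      # K(u) = df_int/du

f_ext = 20.0  # hypothetical external load
u = 0.0
for iteration in range(20):
    residual = compute_internal_force(u) - f_ext
    if abs(residual) < 1e-4:
        break
    u = u - residual / compute_stiffness(u)  # Newton update

print("u =", u, "after", iteration, "iterations")
```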
Vectorize a function
Tip
@autovmap: automatically vectorize a function, given the expected number of dimensions of each argument.
During the course, you will notice us using a function decorator @autovmap. This utility function from a tiny package named jax_autovmap eases the process of vectorizing a function. The term vectorization is used to describe the process of converting a function that operates on a single input to a function that operates on an array of inputs.
```python
from jax_autovmap import autovmap
from jax import Array
```
autovmap works as a Python decorator that wraps a function and vectorizes it. It takes, as keyword arguments, the number of dimensions each argument has in a single (unbatched) call. The dimensions of an input variable are defined as follows:
0: scalar (for example, material properties)
1: vector (for example, a displacement field or any vector with one index)
2: tensor (for example, a strain tensor, stress tensor or any matrix with two indices)
For example, let us assume we want to compute stress from strain, given as
\[
\sigma = 2\mu\,\epsilon + \lambda\,\text{tr}(\epsilon)\,\mathbb{I}
\]
where \(\sigma\) is the stress tensor, \(\epsilon\) is the strain tensor, \(\mu\) is the shear modulus, \(\lambda\) is Lamé's first parameter, and \(\mathbb{I}\) is the identity tensor. The term \(\text{tr}(\epsilon)\) is the trace of the strain tensor.
Below is the function that computes stress from strain and given material properties.
```python
@autovmap(eps=2, mu=0, _lambda=0)
def compute_stress(eps: Array, mu: float, _lambda: float) -> Array:
    """Compute the stress from the strain."""
    return 2 * mu * eps + _lambda * jnp.trace(eps) * jnp.eye(2)
```
We decorate the function with autovmap and declare the dimension of each of its three inputs. The strain is a second-order tensor, so we pass eps=2. The material properties are scalars, so we pass mu=0 and _lambda=0.
Now let's execute the function with an array of 100 strain tensors.
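Since jax_autovmap is a small third-party package, the sketch below reproduces this call with JAX's built-in `jax.vmap`. It assumes that `autovmap(eps=2, mu=0, _lambda=0)` behaves like `in_axes=(0, None, None)` for a single batch axis: map over the leading axis of the strains and broadcast the scalar material properties. The random strains and the values mu = 1.0, lambda = 2.0 are made up.

```python
import jax
import jax.numpy as jnp

def compute_stress_single(eps, mu, _lambda):
    """Stress for ONE 2x2 strain tensor (no batching)."""
    return 2 * mu * eps + _lambda * jnp.trace(eps) * jnp.eye(2)

# Map over axis 0 of eps; mu and _lambda are shared by every strain
compute_stress_batched = jax.vmap(compute_stress_single, in_axes=(0, None, None))

key = jax.random.PRNGKey(0)
strains = 1e-3 * jax.random.normal(key, (100, 2, 2))
strains = 0.5 * (strains + jnp.swapaxes(strains, -1, -2))  # make each strain symmetric

stresses = compute_stress_batched(strains, 1.0, 2.0)
print(stresses.shape)
```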
The result is an array of stresses with shape (100, 2, 2), one 2×2 stress tensor per input strain. Notice that for \(\mu\) and \(\lambda\) we provided only a single value, not an array: autovmap broadcasts the scalar values across the whole batch of strains.
If \(\mu\) and \(\lambda\) differed for each strain, we could instead pass them as arrays holding one value per strain.
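With plain `jax.vmap`, per-strain material properties correspond to mapping over all three arguments with `in_axes=(0, 0, 0)`. This is a sketch of the idea, not of autovmap's internals. The identical strains \(10^{-3}\,\mathbb{I}\) and the linearly varying \(\mu\) and \(\lambda\) are hypothetical. They make the result easy to check: each stress equals \((2\mu + 2\lambda)\times 10^{-3}\,\mathbb{I}\).

```python
import jax
import jax.numpy as jnp

def compute_stress_single(eps, mu, _lambda):
    """Stress for ONE 2x2 strain tensor (no batching)."""
    return 2 * mu * eps + _lambda * jnp.trace(eps) * jnp.eye(2)

# Map over axis 0 of the strains AND of both material parameters
compute_stress_batched = jax.vmap(compute_stress_single, in_axes=(0, 0, 0))

n = 100
strains = jnp.tile(1e-3 * jnp.eye(2), (n, 1, 1))  # hypothetical: same strain everywhere
mu = jnp.linspace(1.0, 2.0, n)                    # hypothetical shear modulus per strain
lam = jnp.linspace(0.5, 1.0, n)                   # hypothetical Lamé parameter per strain

stresses = compute_stress_batched(strains, mu, lam)
print(stresses.shape)
```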
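Defining an if-else condition in a function
Inside a jax.jit-compiled function, you cannot use a normal Python if/else whose condition depends on an input value. jit traces the function with abstract values, so the condition has no concrete True/False at trace time. (Under jax.grad alone, Python control flow works, but the function then cannot be wrapped in jax.jit.)
Instead, use jax.lax.cond(predicate, true_fun, false_fun, operand), which traces both branches and selects one when the function runs.
predicate: a scalar boolean (for element-wise selection over arrays, use jnp.where instead)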
true_fun: function to execute if predicate == True
false_fun: function to execute if predicate == False
operand: input passed to either function
Example: Simple If-Else
```python
import jax
import jax.numpy as jnp
from jax import lax

def choose_branch(x):
    def true_fun(x):
        return x * 2

    def false_fun(x):
        return x - 2

    return lax.cond(x > 0, true_fun, false_fun, x)

print(choose_branch(3.0))   # 6.0 (x > 0 branch)
print(choose_branch(-1.0))  # -3.0 (x <= 0 branch)
```
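Because `lax.cond` keeps the branch choice inside JAX, the function above can be jit-compiled and differentiated. The derivative is 2 on the `x > 0` branch and 1 otherwise. `lax.cond` needs a scalar predicate. For arrays, `jnp.where` evaluates both expressions and selects element by element. The array `x` below is an arbitrary example input.

```python
import jax
import jax.numpy as jnp
from jax import lax

def choose_branch(x):
    return lax.cond(x > 0, lambda x: x * 2, lambda x: x - 2, x)

choose_branch_jit = jax.jit(choose_branch)  # compiles fine: no Python if on traced values
slope = jax.grad(choose_branch)             # derivative of the selected branch

print(choose_branch_jit(3.0), slope(3.0), slope(-1.0))

# Element-wise version for arrays: both expressions are computed, then selected per entry
def choose_branch_elementwise(x):
    return jnp.where(x > 0, x * 2, x - 2)

x = jnp.array([3.0, -1.0, 0.5])
print(choose_branch_elementwise(x))
```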