Appendix B — Basics of JAX

In this course, we will use a few features of JAX, introduced in the sections below.

jax.numpy (as jnp)

JAX provides a NumPy-like interface called jax.numpy. It works almost the same as NumPy, but additionally supports automatic differentiation.

import jax.numpy as jnp

# Example: compute strain energy for a 1D spring
k = 10.0
u = 0.2
energy = 0.5 * k * u**2
print("Spring energy =", energy)

# Using jnp
u_jax = jnp.array(0.2)
energy_jax = 0.5 * k * u_jax**2
print("Spring energy (jax) =", energy_jax)
Spring energy = 0.20000000000000004
Spring energy (jax) = 0.20000002
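
The small difference between the two printed values comes from precision: JAX uses 32-bit floats by default, whereas Python floats are 64-bit.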

Updating and Accessing Values in jax.numpy

JAX arrays are immutable: you cannot directly change an entry.
Instead, you use the .at operator, which returns a new array with modifications.

import jax.numpy as jnp

x = jnp.array([10, 20, 30, 40])

# Change the value at index 2 (third entry) to 99
x_new = x.at[2].set(99)

print("Original:", x)     # [10 20 30 40]
print("Updated :", x_new) # [10 20 99 40]
Original: [10 20 30 40]
Updated : [10 20 99 40]

Incrementing values with .add()

y = jnp.array([1.0, 2.0, 3.0, 4.0])

# Add 5 to element at index 1
y_new = y.at[1].add(5.0)

print("Original:", y)     # [1. 2. 3. 4.]
print("Updated :", y_new) # [1. 7. 3. 4.]
Original: [1. 2. 3. 4.]
Updated : [1. 7. 3. 4.]

Updating multiple values

w = jnp.array([0, 1, 2, 3, 4, 5])

indices = jnp.array([1, 3, 5])

# Update multiple indices [1, 3, 5] to -1
w_new = w.at[indices].set(-1)

print("Original:", w)     # [0 1 2 3 4 5]
print("Updated :", w_new) # [ 0 -1  2 -1  4 -1]
Original: [0 1 2 3 4 5]
Updated : [ 0 -1  2 -1  4 -1]

jax.jacrev and jax.jacfwd: Differentiation

jax.jacrev → reverse-mode differentiation (good for many inputs, few outputs).

jax.jacfwd → forward-mode differentiation (good for few inputs, many outputs).

In mechanics:

Internal force = derivative of energy w.r.t. displacement.

Stiffness matrix = derivative of force w.r.t. displacement.

Applying one of these functions to another function returns a new function: it takes the same arguments as the original, but returns the derivative of the original with respect to its first argument.

import jax
import jax.numpy as jnp

# 2D linear spring system
def energy(u):
    k = jnp.array([[10.0, 0.0],
                   [0.0, 5.0]])
    return 0.5 * u @ k @ u

# Internal force = ∂E/∂u
compute_internal_force = jax.jacrev(energy)

internal_force = compute_internal_force(jnp.array([0.1, 0.2]))
print("Internal force =", internal_force)

# Stiffness = ∂f/∂u
compute_stiffness = jax.jacfwd(compute_internal_force)
stiffness = compute_stiffness(jnp.array([0.1, 0.2]))
print("Stiffness matrix =\n", stiffness)
Internal force = [1. 1.]
Stiffness matrix =
 [[10.  0.]
 [ 0.  5.]]

jax.jit: Just-in-Time compilation

jax.jit compiles a function so it runs faster. It is useful when the same computation is repeated (e.g., in FEM assembly or Newton iterations).

@jax.jit
def compute_energy(u):
    k = 10.0
    return 0.5 * k * u**2

print(compute_energy(0.5))
1.25
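
Note that jax.jit works by tracing the function on the first call (once per input shape and dtype) and caching the compiled result. Python side effects such as print therefore run only during tracing. A minimal sketch illustrating this (the function name is ours, chosen for illustration):

@jax.jit
def traced_energy(u):
    print("tracing...")  # Python side effect: runs only while JAX traces the function
    return 0.5 * 10.0 * u**2

print(traced_energy(0.5))  # first call: prints "tracing...", compiles, returns 1.25
print(traced_energy(0.7))  # same shape and dtype: reuses the compiled code, no print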

jax.jvp: Jacobian-Vector Product

For matrix-free solvers, we will see that we do not need to materialize the stiffness matrix; rather, we only need the action of the matrix on a vector. This is called a Jacobian-vector product, or JVP. The JVP is defined as

\[ \mathbf{K}\boldsymbol{v} = \frac{\partial \boldsymbol{f}_\text{int}}{\partial \boldsymbol{u}}\boldsymbol{v} \]

To compute the JVP, we will use jax.jvp, which takes the following arguments:

  • the function to differentiate, i.e. \(\boldsymbol{f}_\text{int}\)
  • the point at which we differentiate the function, i.e. \(\boldsymbol{u}\) (passed as a tuple)
  • the vector on which the Jacobian acts, i.e. \(\boldsymbol{v}\) (passed as a tuple)

compute_internal_force = jax.jacrev(energy)

u = jnp.array([1.1, 0.2])    # point at which we differentiate
v = jnp.array([0.01, 0.02])  # direction on which the Jacobian acts

f_jvp, delta_f_jvp = jax.jvp(compute_internal_force, (u,), (v,))

print(f_jvp)        # f_int(u) = K @ u
print(delta_f_jvp)  # K @ v

The two values that jax.jvp returns are:

  • the value of the function at the evaluation point, i.e. \(\boldsymbol{f}_\text{int}(\boldsymbol{u})\)
  • the action of \(\mathbf{K}\) on \(\boldsymbol{v}\), i.e. \(\mathbf{K}\boldsymbol{v}\)
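
As a quick sanity check, the matrix-free JVP should agree with explicitly materializing \(\mathbf{K}\) and multiplying. A minimal sketch reusing the energy function and the u, v arrays from above:

# Materialize K with jacfwd and compare against the matrix-free result
compute_stiffness = jax.jacfwd(compute_internal_force)
K = compute_stiffness(u)

print(K @ v)        # explicit matrix-vector product
print(delta_f_jvp)  # matrix-free JVP: should agree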

functools.partial: fix arguments of functions

We often want to “freeze” some parameters of a function (e.g., material constants) so that JAX only sees displacement as the variable.

from functools import partial

def spring_energy(u, k):
    return 0.5 * k * u**2

# Fix spring stiffness k=10
energy_fixed_k = partial(spring_energy, k=10.0)

force = jax.grad(energy_fixed_k)(0.3)
print("Force with k=10 =", force)
Force with k=10 = 3.0
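
functools.partial also composes naturally with the other JAX transformations. For instance, here is a minimal sketch (reusing spring_energy from above) that freezes k and then jit-compiles the resulting single-argument function:

import jax
from functools import partial

# Freeze k, then compile the resulting function of u alone
fast_energy = jax.jit(partial(spring_energy, k=10.0))
print(fast_energy(0.3))  # 0.5 * 10 * 0.3**2 ≈ 0.45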

Putting it all together

import jax
import jax.numpy as jnp
from functools import partial

@jax.jit
def bar_energy(u, k, L):
    """Energy of a bar: E = 0.5*k*(u/L)^2 * L"""
    strain = u / L
    return 0.5 * k * strain**2 * L

# Fix k and L
bar_energy_fixed = partial(bar_energy, k=100.0, L=1.0)

# Internal force
compute_internal_force = jax.jacrev(bar_energy_fixed)
force = compute_internal_force(0.1)
# Stiffness
compute_stiffness = jax.jacfwd(compute_internal_force)
stiffness = compute_stiffness(0.1)

print("Force =", force)
print("Stiffness =", stiffness)
Force = 10.0
Stiffness = 100.0

Vectorize a function

Tip

@autovmap: Automatically vectorize a function, given the expected shape of each argument.

During the course, you will notice us using the function decorator @autovmap. This utility from a tiny package named jax_autovmap eases the process of vectorizing a function. Vectorization here means converting a function that operates on a single input into one that operates on an array of inputs.

from jax_autovmap import autovmap
from jax import Array

The autovmap function works as a Python decorator that wraps a function and vectorizes it. It takes as input the dimension of each variable passed to the function. The dimensions of an input variable are defined as follows:

  • 0: scalar (for example, material properties)
  • 1: vector (for example, a displacement field or any vector with one index)
  • 2: tensor (for example, a strain tensor, stress tensor or any matrix with two indices)

For example, lets us assume we want to compute stress from strain given as

\[ \sigma = 2\mu \epsilon + \lambda \text{tr}(\epsilon) \mathbb{I} \]

where \(\sigma\) is the stress tensor, \(\epsilon\) is the strain tensor, \(\mu\) is the shear modulus, \(\lambda\) is the Lame parameter, and \(\mathbb{I}\) is the identity tensor. The term \(\text{tr}(\epsilon)\) is the trace of the strain tensor.

Below is the function that computes stress from strain and given material properties.

@autovmap(eps=2, mu=0, _lambda=0)
def compute_stress(eps: Array, mu: float, _lambda: float) -> Array:
    """Compute the stress from the strain."""
    return 2 * mu * eps + _lambda * jnp.trace(eps) * jnp.eye(2)

We wrap the above function using the autovmap decorator and pass the dimension of all three input variables: the strain is a second-order tensor, so we pass 2; the material properties are scalars, so we pass 0.

Now let's execute the function with an array of strains. Below we create an array of 100 strain tensors and call the function on this array.

strains = []
for i in range(100):
    strains.append(jnp.eye(2)*i)

strains = jnp.array(strains)

stresses = compute_stress(strains, 1.0, 0.0)

print(stresses.shape)
(100, 2, 2)

Upon execution, it returns stresses as an array of shape (100, 2, 2), because the function is vectorized over the leading axis of the strain array. Notice that for \(\mu\) and \(\lambda\) we only provided a single value, not an array; autovmap automatically broadcasts scalar values over the batch dimension.

If we had different values of \(\mu\) and \(\lambda\) for each strain, we could pass them as arrays. For example, below we call the function with per-strain material parameters.

stresses = compute_stress(
    strains, mu=jnp.linspace(0.0, 1.0, 100), _lambda=jnp.linspace(10.0, 100.0, 100)
)

print(stresses.shape)
(100, 2, 2)
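
For reference, here is a minimal sketch of roughly what the decorator does under the hood, written with plain jax.vmap (compute_stress_single is our name for the unwrapped function; the in_axes argument plays the role of the dimension specification):

import jax
import jax.numpy as jnp

def compute_stress_single(eps, mu, _lambda):
    return 2 * mu * eps + _lambda * jnp.trace(eps) * jnp.eye(2)

# Vectorize over the leading axis of eps; keep mu and _lambda as fixed scalars
compute_stress_batched = jax.vmap(compute_stress_single, in_axes=(0, None, None))

print(compute_stress_batched(strains, 1.0, 0.0).shape)  # (100, 2, 2)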

Conditional Execution with jax.lax.cond

In JAX, you cannot use a normal Python if/else on traced values inside @jit or inside functions you want to differentiate,
because JAX needs to trace the computation graph.

Instead, use jax.lax.cond(predicate, true_fun, false_fun, operand).

  • predicate: a scalar boolean value
  • true_fun: function to execute if predicate == True
  • false_fun: function to execute if predicate == False
  • operand: input passed to either function

Example: Simple If-Else

import jax
import jax.numpy as jnp
from jax import lax

def choose_branch(x):
    def true_fun(x):
        return x * 2
    def false_fun(x):
        return x - 2
    
    return lax.cond(
        x > 0,
        true_fun,
        false_fun,
        x
    )

print(choose_branch(3.0))   # 6.0  (x > 0 branch)
print(choose_branch(-1.0))  # -3.0 (x <= 0 branch)
6.0
-3.0
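
Because lax.cond is traceable, choose_branch also works under @jit, where a plain Python if on a traced value would raise an error:

jitted_choose = jax.jit(choose_branch)

print(jitted_choose(3.0))   # 6.0
print(jitted_choose(-1.0))  # -3.0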