In-class activity: Node-to-Surface Contact

In this exercise, we will numerically solve the problem of contact between two elastic bodies. Unlike the last exercise, where we studied contact between a deformable body and a rigid body, here we need to detect which parts of the two bodies come into contact. Therefore, to correctly account for the contact between the bodies, we will first implement a contact detection algorithm that captures which part of the upper surface comes into contact with the lower body, and then use the gap function to compute the contact energy with the penalty method.

An animation of the contact between two elastic bodies is shown below.

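For this in-class activity we therefore have two objectives:

  1. Implement a node-to-surface contact detection algorithm that finds which part of the upper surface comes into contact with the lower body.
  2. Use the resulting gap function to compute the contact energy with the penalty method.

As in the last exercise, we use the normal gap function between the surfaces of the two bodies to determine when the two bodies come into contact: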

\[g_n = (\boldsymbol{r} - \boldsymbol{\rho})\cdot{} \boldsymbol{n}\]

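In this activity, we compute this gap function with the node-to-surface approach (see Fundamentals). Here \(\boldsymbol{r}\) is a node on the contactor surface, \(\boldsymbol\rho\) is its projection onto the target surface, and \(\boldsymbol{n}\) is the outward unit normal of the target surface.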

Importing essential libraries
import jax

jax.config.update("jax_enable_x64", True)
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)
jax.config.update("jax_platforms", "cpu")
import jax.numpy as jnp
from jax import Array


from tatva import Mesh, Operator, element
from tatva.plotting import (
    STYLE_PATH,
    colors,
    plot_element_values,
    plot_nodal_values,
)

from typing import Callable, Optional, Tuple
from jax_autovmap import autovmap
from functools import partial

import numpy as np
import matplotlib.pyplot as plt

Node-to-surface contact detection

Node-to-surface contact detection consists of the following steps, performed for each node on the contactor surface:

  1. Construct the outward normal \(\boldsymbol{n}\) of every element on the opposing target surface.
  2. Calculate the projection point \(\boldsymbol\rho\) of every element on the opposing target surface.
  3. Calculate the distance between the contactor node and the projection point for every element on the opposing target surface.
  4. Filter the results to find the best valid point, the corresponding target element, and the normal at the contact point.
  5. Calculate the normal gap.
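For a straight two-node target segment with nodes \(\boldsymbol{x}_1\) and \(\boldsymbol{x}_2\), these steps reduce to a few closed-form expressions. The functions in the following steps implement exactly these expressions. The sign of \(\boldsymbol{n}\) is chosen so that it points away from a reference point inside the target body. \(t\) is the local coordinate of the projection along the segment:

\[ \boldsymbol{n} = \pm\frac{1}{\lVert\boldsymbol{x}_2-\boldsymbol{x}_1\rVert}\begin{bmatrix} -(x_{2,y}-x_{1,y}) \\ x_{2,x}-x_{1,x} \end{bmatrix}, \qquad \boldsymbol{\rho} = \boldsymbol{r} - \big[(\boldsymbol{r}-\boldsymbol{x}_1)\cdot\boldsymbol{n}\big]\,\boldsymbol{n} \]

\[ t = \frac{(\boldsymbol{\rho}-\boldsymbol{x}_1)\cdot(\boldsymbol{x}_2-\boldsymbol{x}_1)}{\lVert\boldsymbol{x}_2-\boldsymbol{x}_1\rVert^2}, \qquad g_n = (\boldsymbol{r}-\boldsymbol{\rho})\cdot\boldsymbol{n} = (\boldsymbol{r}-\boldsymbol{x}_1)\cdot\boldsymbol{n} \]

A contactor-node/segment pair is admissible only if \(0 \le t \le 1\), i.e., the projection lies on the segment itself. The code uses a small tolerance of 0.01 for this check. A value \(g_n < 0\) indicates penetration.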

Step 1: Find the outward normal

Implement a function that returns the outward normal vector for one element of the target surface. Use @autovmap to automatically vectorize the function for an array of line segments.

Hint

The function implementation will look something like this:


@autovmap(surface_points=2, reference_point=1)
def find_normal_to_surface(surface_points : Array, reference_point : Array) -> Array:
    """
    Find the normal to the surface at a point.

    Args:
        surface_points: Coordinates of the target surface element, shape (2, 2)
        reference_point: Coordinates of a known reference point which is inside the surface/body, shape (2,)
    
    Returns:
        normal: The outward normal to the surface, shape (2,)
    """

    # TODO: Implement the function

    return ...

Use jnp.linalg.norm to normalize the normal vector, jnp.vdot to compute the dot product, and jnp.where to flip the sign of the normal vector if it points towards the reference point (i.e., into the body).

Code: Find the outward normal to a surface
@autovmap(surface_points=2, reference_point=1)
def find_normal_to_surface(surface_points: Array, reference_point: Array) -> Array:
    """
    Find the outward normal to a surface element.

    Args:
        surface_points: The coordinates of the two nodes of the surface element, shape (2, 2).
        reference_point: The coordinates of a reference point inside the body, shape (2,).
    Returns:
        normal: The outward unit normal to the surface element, shape (2,).
    """
    # Rotate the element tangent by 90 degrees to get a normal, then make it unit length
    tangent_vector = surface_points[1] - surface_points[0]
    normal = jnp.array([-tangent_vector[1], tangent_vector[0]])
    normal = normal / jnp.linalg.norm(normal)

    # A positive projection of (reference point - first node) on the normal means the
    # normal points into the body, so flip it to make it point outward
    reference_vector = reference_point - surface_points[0]
    normal_sign = jnp.vdot(reference_vector, normal)

    normal = jnp.where(normal_sign > 0, -normal, normal)
    return normal
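Here is a small sketch of the same normal construction. It uses plain jax.vmap in place of @autovmap. outward_normal is a hypothetical stand-alone helper, and the square body is made up for illustration. The right edge is deliberately given with its nodes in reverse order. This shows that the reference-point check returns the outward normal regardless of how the element is oriented:

```python
import jax
import jax.numpy as jnp


def outward_normal(segment, reference_point):
    """Unit outward normal of a 2-node segment; reference_point lies inside the body."""
    tangent = segment[1] - segment[0]
    normal = jnp.array([-tangent[1], tangent[0]])  # tangent rotated by +90 degrees
    normal = normal / jnp.linalg.norm(normal)
    # If the normal points towards the interior point, it is inward: flip it
    towards_inside = jnp.vdot(reference_point - segment[0], normal) > 0
    return jnp.where(towards_inside, -normal, normal)


# Two edges of a hypothetical square body whose interior contains (2, -1)
segments = jnp.array([
    [[1.0, 0.0], [3.0, 0.0]],    # top edge
    [[3.0, -2.0], [3.0, 0.0]],   # right edge, node order reversed on purpose
])
reference_point = jnp.array([2.0, -1.0])

# Map over the segments only; the reference point is shared
normals = jax.vmap(outward_normal, in_axes=(0, None))(segments, reference_point)
print(normals)  # top edge -> (0, 1), right edge -> (1, 0)
```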

Step 2: Calculate the projection point

Implement a function that returns the projection point \(\boldsymbol\rho\) for one contactor-point/target-surface-element pair.

Hints

The function implementation will be something like this:


@autovmap(point=1, surface_coords=2, surface_normal=1)
def find_projection_point(
    point: Array, surface_coords: Array, surface_normal: Array
) -> Array:
    """
    Find the projection point to a plane from a point.

    Args:
        point: The contactor point to find the orthogonal point to, shape (2,)
        surface_coords: The points on the surface, shape (2, 2)
        surface_normal: The outward normal to the surface element, shape (2,)
    Returns:
        projection_point: The projection point to the surface, shape (2,)
    """

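    return ...

Code: Find the projection point on a surface
@autovmap(point=1, surface_coords=2, surface_normal=1)
def find_projection_point(
    point: Array, surface_coords: Array, surface_normal: Array
) -> Array:
    """
    Find the orthogonal projection of a point onto the line through a surface element.
    """
    # Signed distance of the point from the line, measured along the normal
    numerator = jnp.vdot(point - surface_coords[0], surface_normal)
    normal_norm = jnp.linalg.norm(surface_normal)
    # Move the point back along the normal onto the line
    projection_point = point - numerator * surface_normal / normal_norm**2
    return projection_point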
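The projection lands on the infinite line through the element, not necessarily on the element itself. This is why Step 3 adds an inside check. Below is a minimal sketch using one of the test points from Step 3.5. project_onto_line is a hypothetical helper that mirrors find_projection_point without @autovmap:

```python
import jax.numpy as jnp


def project_onto_line(point, segment, normal):
    """Orthogonal projection of a point onto the infinite line through a segment."""
    signed_dist = jnp.vdot(point - segment[0], normal)
    return point - signed_dist * normal / jnp.vdot(normal, normal)


segment = jnp.array([[1.0, 0.0], [3.0, 0.0]])
normal = jnp.array([0.0, 1.0])
rho = project_onto_line(jnp.array([3.5, 0.5]), segment, normal)
print(rho)  # (3.5, 0.0): on the line, but beyond the segment end (3, 0)
```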

Step 3: Calculate Distance and Check Bounds

In this function, we will perform two critical operations:

  • Check whether the projection point from find_projection_point lies within the surface element.
  • Calculate the signed distance from the contactor point to the surface element along its normal.
Hints

The function implementation will be something like this:

@autovmap(projection_point=1, surface_coords=2, surface_normal=1)
def find_distance_to_point(
    point: Array, 
    projection_point: Array, 
    surface_coords: Array, 
    surface_normal: Array
) -> Tuple[float, float]:
    """
    Find the distance between a point and a surface element.

    Args:
        point: The point to find the distance to, shape (2,)
        projection_point: The projection point on the surface, shape (2,)
        surface_coords: The coordinates of the surface, shape (2, 2)  
        surface_normal: The normal to the surface, shape (2,)
    Returns:
        projected_distance: The signed projected distance to the point (a scalar)
        inside: 1 if the projection point lies inside the surface element, otherwise 0
    """
    # TODO: Check if the projection point is inside the surface using jnp.where(...)
    ...

    # TODO: return projected distance and the inside flag

    return ...

Code: Find the distance to a point from a surface
@autovmap(projection_point=1, surface_coords=2, surface_normal=1)
def find_distance_to_point(
    point: Array, projection_point: Array, surface_coords: Array, surface_normal: Array
) -> Tuple[float, float]:
    """
    Find the distance to a point from a surface.
    The function returns the projected distance and the distance to the point.

    Args:
        point: The point to find the distance to.
        coords: The coordinates of the surface.
        normal: The normal to the surface.
    Returns:
        projected_distance: The projected distance to the point.
        distance: The distance to the point.
    """
    distance_vector =point - projection_point  #orthogonal_point - point
    projected_distance = jnp.vdot(distance_vector, surface_normal)

    ap = projection_point - surface_coords[0]
    ab = surface_coords[1] - surface_coords[0]
    t = jnp.vdot(ap, ab) / jnp.linalg.norm(ab) ** 2

    inside = jnp.where(t < -0.01, 0, jnp.where(t > 1.01, 0, 1))

    return projected_distance, inside
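To see what the tolerance in the inside check does, here is a sketch that evaluates the local coordinate \(t\) for a few projection points on the segment from \((1, 0)\) to \((3, 0)\). segment_parameter and is_inside are hypothetical helpers that restate the two lines of find_distance_to_point above:

```python
import jax
import jax.numpy as jnp


def segment_parameter(projection_point, segment):
    """Local coordinate t: t = 0 at segment[0], t = 1 at segment[1]."""
    ab = segment[1] - segment[0]
    return jnp.vdot(projection_point - segment[0], ab) / jnp.vdot(ab, ab)


def is_inside(t, tol=0.01):
    """1 if t lies in [-tol, 1 + tol], else 0 (same tolerance as find_distance_to_point)."""
    return jnp.where((t >= -tol) & (t <= 1.0 + tol), 1, 0)


segment = jnp.array([[1.0, 0.0], [3.0, 0.0]])
projections = jnp.array([[2.0, 0.0], [3.0, 0.0], [3.01, 0.0], [3.5, 0.0]])
t = jax.vmap(segment_parameter, in_axes=(0, None))(projections, segment)
inside = is_inside(t)
print(t, inside)  # t ~ [0.5, 1.0, 1.005, 1.25], inside = [1, 1, 1, 0]
```

A projection exactly at the end node \((3, 0)\), or marginally past it, is still accepted. The projection at \(x = 3.5\) is rejected.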

Step 3.5: Testing the logic for a single contactor node and a single target segment

Test your implemented functions on the 3 test cases discussed in class. Each test uses a single contactor node and a single target segment.

The target segment is the line segment connecting the points \((1, 0)\) and \((3, 0)\). The reference point, which lies inside the body bounded by this segment, is \((1, -1)\).

We test the above functions for 3 different contactor points:

  • \((2.0, 0.5)\)
  • \((3.5, 0.5)\)
  • \((2.0, -0.2)\)

Use the following code to test your implementation, setting contactor_point to each of the points above in turn.

contactor_point = np.array([2, 0.5])  # Change this for different scenarios
surface_coords = np.array([[1, 0], [3, 0.0]])

reference_point = np.array([1, -1])

normal = find_normal_to_surface(surface_coords, reference_point)

projection_point = find_projection_point(
    contactor_point, surface_coords, normal
)

projected_distance, inside = find_distance_to_point(
    contactor_point, projection_point, surface_coords, normal
)

print("The normal to the surface is: ", normal)
print("The projection point is: ", projection_point)
print("The projected distance is: ", projected_distance)
print("The inside flag is: ", inside)

plt.style.use(STYLE_PATH)
plt.figure(figsize=(5, 3), layout="constrained")
ax = plt.axes()
ax.scatter(
    contactor_point[0],
    contactor_point[1],
    marker="x",
    s=50,
    color=colors.red,
    zorder=10,
    label="Contactor Point",
)
ax.scatter(
    reference_point[0],
    reference_point[1],
    marker="o",
    s=50,
    color="k",
    zorder=10,
    label="Reference Point",
)
ax.fill_between(x=np.linspace(0.9, 3.6), y1=0, y2=-1.5, color=colors.blue, alpha=0.5)
ax.scatter(
    projection_point[0],
    projection_point[1],
    marker="o",
    s=50,
    color=colors.blue,
    zorder=10,
    label="Projection Point",
)
ax.plot(surface_coords[:, 0], surface_coords[:, 1], lw=2, color="k", label="Surface")
ax.plot(
    [contactor_point[0], projection_point[0]],
    [contactor_point[1], projection_point[1]],
    color="gray",
)
ax.set_xlim(0.9, 3.6)
ax.set_ylim(-1.25, 0.8)
ax.legend(frameon=False)
plt.show()
The normal to the surface is:  [-0.  1.]
The projection point is:  [2. 0.]
The projected distance is:  0.5
The inside flag is:  1
Figure 13.1: Testing the implementation of contact detection for a single contactor point and a single target segment.
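The three scenarios can also be checked in a single vectorized call. Below is a compact sketch that restates the projection, distance and inside check for the fixed segment with outward normal \((0, 1)\). check_pair is a hypothetical helper, not part of the exercise. For a unit normal, the projected distance equals \((\boldsymbol{r}-\boldsymbol{x}_1)\cdot\boldsymbol{n}\):

```python
import jax
import jax.numpy as jnp

segment = jnp.array([[1.0, 0.0], [3.0, 0.0]])
normal = jnp.array([0.0, 1.0])  # outward normal for the reference point (1, -1)


def check_pair(point):
    """Signed gap and inside flag of one contactor point against the fixed segment."""
    gap = jnp.vdot(point - segment[0], normal)
    projection = point - gap * normal
    ab = segment[1] - segment[0]
    t = jnp.vdot(projection - segment[0], ab) / jnp.vdot(ab, ab)
    inside = jnp.where((t >= -0.01) & (t <= 1.01), 1, 0)
    return gap, inside


contactor_points = jnp.array([[2.0, 0.5], [3.5, 0.5], [2.0, -0.2]])
gaps, inside = jax.vmap(check_pair)(contactor_points)
print(gaps, inside)  # gaps ~ [0.5, 0.5, -0.2], inside = [1, 0, 1]
```

The first point is separated and valid. The second has a positive gap, but its projection falls outside the segment, so the pair is discarded. The third penetrates the body (negative gap) and is valid.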

Step 4: Find the best valid point

This is the "search" step. For a single contactor node, we have a list of projected distances, one per target segment. We filter this list and keep only the pairs whose projection lies inside the segment and whose distance is negative (penetration). Among these, we pick the one with the largest (least negative) distance, i.e., the closest penetrated segment. Return the coordinates of the corresponding projection point and the normal of that segment.

Hint

The function implementation will look something like this

def find_valid_orthogonal_point(
    projected_distances: Array,
    inside: Array,
    projection_points: Array,
    surface_normals: Array,
) -> Tuple[Array, Array]:
    """
    Find the closest point and corresponding normal on a surface to a point.

    Args:
        projected_distances: The projected distances to the surface, shape (n,)
        inside: The inside flag, shape (n,) 
        projection_points: The projection points, shape (n, 2)
        surface_normals: The normals at the surface points, shape (n, 2)
    Returns:
        closest_point: The closest point on the surface, shape (2,)
        closest_normal: The normal at the closest point, shape (2,)
    """



    return ...

@jax.jit
def find_valid_orthogonal_point(
    projected_distances: Array,
    inside: Array,
    projection_points: Array,
    surface_normals: Array,
) -> Tuple[Array, Array]:
    """
    Find the closest point and corresponding normal on a surface to a point.

    Args:
        projected_distances: The projected distances to the surface, shape (n,)
        inside: The inside flag, shape (n,)
        projection_points: The projection points, shape (n, 2)
        surface_normals: The normals at the surface points, shape (n, 2)
    Returns:
        closest_point: The closest point on the surface, shape (2,)
        closest_normal: The normal at the closest point, shape (2,)
    """
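    # Keep only pairs whose projection lies on the segment (inside == 1) and
    # whose projected distance is negative (penetration); discard the rest
    projected_distances = jnp.where(
        (inside == 1) & (projected_distances < 0),
        projected_distances,
        -jnp.inf,
    )

    # The largest remaining value is the smallest penetration, i.e. the closest
    # penetrated segment. If no pair is valid, all entries are -inf and argmax
    # returns index 0
    index = jnp.argmax(projected_distances)
    closest_point = projection_points[index]
    closest_normal = surface_normals[index]

    return closest_point, closest_normal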
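A minimal sketch of the filtering on made-up data for four target segments. Segment 3 has the smallest penetration, but its projection lies outside the segment. Segment 1 is not penetrated. Segment 2 therefore wins. The last lines show the fallback when no pair is valid:

```python
import jax.numpy as jnp

projected_distances = jnp.array([-0.3, 0.5, -0.1, -0.05])
inside = jnp.array([1, 1, 1, 0])

# Same filter as find_valid_orthogonal_point
filtered = jnp.where(
    (inside == 1) & (projected_distances < 0), projected_distances, -jnp.inf
)
best = jnp.argmax(filtered)
print(filtered, best)  # [-0.3, -inf, -0.1, -inf], best = 2

# With no valid pair, every entry is -inf and argmax returns the first index
no_contact = jnp.argmax(jnp.full(3, -jnp.inf))
print(no_contact)  # 0
```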

Step 5: Putting it all together

Finally, orchestrate all the steps in the main find_gap function. The function is written for a single contactor point, but @autovmap vectorizes it over an entire array of points. It makes use of all the previously defined functions.
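Here we assume that @autovmap(contactor_point=1) maps over any leading axes of contactor_point beyond its base rank of 1, and passes the other arguments through unchanged. This is our reading of how it is used in this notebook, not documented jax_autovmap behaviour. Under that assumption, its effect on a batch of points matches jax.vmap with in_axes=(0, None, None). The sketch below shows this with a hypothetical per-point helper, signed_distance_to_line:

```python
import jax
import jax.numpy as jnp


def signed_distance_to_line(point, line_start, normal):
    """Signed distance of ONE 2D point to the line through line_start with unit normal."""
    return jnp.vdot(point - line_start, normal)


# Map over the leading axis of `point` only; line_start and normal are shared
batched_distance = jax.vmap(signed_distance_to_line, in_axes=(0, None, None))

points = jnp.array([[2.0, 0.5], [3.5, 0.5], [2.0, -0.2]])
line_start = jnp.array([1.0, 0.0])
normal = jnp.array([0.0, 1.0])
distances = batched_distance(points, line_start, normal)
print(distances)  # [0.5, 0.5, -0.2]
```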

Hint

The function implementation will look something like this

@autovmap(contactor_point=1)
def find_gap(contactor_point: Array, surface_points: Array, reference_points: Array) -> float:
    """
    Find the gap between a point and a surface.
    The function returns the gap.

    Args:
        contactor_point: The point to find the gap to.
        surface_points: The coordinates of the nodes of the target surface elements, shape (n, 2, 2).
        reference_points: The coordinates of a reference point inside the target body, shape (2,).

    Returns:
        gap: The closest gap between the point and the surface.
    """


    return ...

@autovmap(contactor_point=1)
def find_gap(contactor_point: Array, surface_points: Array, reference_points: Array) -> float:
    """
    Find the gap between a point and a surface.
    The function returns the gap.

    Args:
        contactor_point: The point to find the gap to.
        surface_points: The coordinates of the nodes of the target surface elements, shape (n, 2, 2).
        reference_points: The coordinates of a reference point inside the target body, shape (2,).

    Returns:
        gap: The closest gap between the point and the surface.
    """

    surface_normals = find_normal_to_surface(surface_points, reference_points)

    projection_points = find_projection_point(
        contactor_point, surface_points, surface_normals
    )
    projected_distances, inside = find_distance_to_point(
        contactor_point, projection_points, surface_points, surface_normals
    )

    closest_point, closest_normal = find_valid_orthogonal_point(
        projected_distances,
        inside,
        projection_points=projection_points,
        surface_normals=surface_normals,
    )

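    gap = jnp.vdot(contactor_point - closest_point, closest_normal)
    # Note: projected_distances holds the unfiltered distances, so this guard only
    # catches degenerate (infinite) distances. When no segment is penetrated,
    # find_valid_orthogonal_point falls back to index 0 and the gap is measured
    # against the first target segment.
    gap = jnp.where(jnp.min(projected_distances) == jnp.inf, 0.0, gap)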
    return gap

Testing the logic for multiple contactor nodes and multiple target segments

Now we test the implemented algorithm for multiple contactor nodes and multiple target segments, using the surfaces of two different bodies.

The upper surface is part of a circular body and the lower surface is part of a rectangular body.

We will test two cases:

Case 1: The upper surface is the contactor surface and the lower surface is the target surface.

Case 2: The lower surface is the contactor surface and the upper surface is the target surface.

Code: Generate the contact surfaces for the test cases
def generate_contact_surfaces(
    num_upper_segments=5,
    num_lower_segments=7,
    radius=10.0,
    width=8.0,
    initial_gap=0.5
):
    """
    Generates nodes and elements for an upper circular surface and a lower flat surface.

    Args:
        num_upper_segments: Number of line elements for the circular surface.
        num_lower_segments: Number of line elements for the flat surface.
        radius: Radius of the circular body.
        width: Width of the flat rectangular body.
        initial_gap: Initial vertical separation between the two bodies.

    Returns:
        A tuple of JAX arrays containing:
        - upper_nodes: Coordinates of nodes on the circular surface.
        - upper_elements: Connectivity for the circular surface elements.
        - lower_nodes: Coordinates of nodes on the flat surface.
        - lower_elements: Connectivity for the flat surface elements.
    """
    # Position the circle's center so its bottom is at y=initial_gap
    center_y = radius + initial_gap
    
    # Calculate angles to span a portion of the circle
    max_angle = np.arcsin((width / 2) / radius)
    angles = np.linspace(-max_angle, max_angle, num_upper_segments + 1)
    
    upper_x = radius * np.sin(angles)
    upper_y = center_y - radius * np.cos(angles)
    upper_nodes = np.vstack([upper_x, upper_y]).T

    # Create line elements [(0, 1), (1, 2), ...]
    upper_elements = np.array([[i, i + 1] for i in range(num_upper_segments)])

    # Generate the lower flat surface
    lower_x = np.linspace(-width / 2, width / 2, num_lower_segments + 1)
    lower_y = np.zeros_like(lower_x) # Position the surface at y=0
    lower_nodes = np.vstack([lower_x, lower_y]).T

    # Create line elements
    lower_elements = np.array([[i, i + 1] for i in range(num_lower_segments)])
    
    return jnp.array(upper_nodes), jnp.array(upper_elements), jnp.array(lower_nodes), jnp.array(lower_elements)

# Generate the geometry
upper_nodes, upper_elements, lower_nodes, lower_elements = generate_contact_surfaces()

plt.style.use(STYLE_PATH)
plt.figure(figsize=(6, 3), layout='constrained')
plt.plot(upper_nodes[:, 0], upper_nodes[:, 1], 'o-', color=colors.red, label='Upper Surface (Circular Body)')
plt.plot(lower_nodes[:, 0], lower_nodes[:, 1], 'o-', color=colors.green, label='Lower Surface (Rectangular Body)')
# Plot upper surface (circular)
plt.fill_between(upper_nodes[:, 0], y1=upper_nodes[:, 1], y2=2.5, color=colors.blue, alpha=0.5)
# Plot lower surface (flat)
plt.fill_between(lower_nodes[:, 0], y1=lower_nodes[:, 1], y2=-2.5, color=colors.blue, alpha=0.5)

plt.xlim(np.min(upper_nodes[:, 0]), np.max(upper_nodes[:, 0]))
plt.ylim(-2.5, 2.5)
plt.xlabel(r'$x$')
plt.ylabel(r'$y$')
plt.legend()
plt.grid(True, linestyle='--', alpha=0.6)
plt.show()
Figure 13.2: Contact surfaces for the test cases for the contact detection algorithm

Below we define two functions to plot the contact pairs. The first function check_contact_pairs is used to check the contact pairs between the contactor and the target surface. The second function plot_contact_pairs is used to plot the contact pairs. These functions are there just for visualization purposes and are not used in the algorithm.

Functions to visualize the contact pairs
@autovmap(point=1)
def check_contact_pairs(
    point,
    surface_points,
    reference_points,
):
    surface_normals = find_normal_to_surface(surface_points, reference_points)

    projection_points = find_projection_point(
        point, surface_points, surface_normals
    )

    projected_distances, inside = find_distance_to_point(
        point, projection_points, surface_points, surface_normals
    )

    closest_point, closest_normal = find_valid_orthogonal_point(
        projected_distances,
        inside,
        projection_points=projection_points,
        surface_normals=surface_normals,
    )

    # Index of the first target segment whose projection lies inside it
    # (used only to highlight a segment in the plot)
    index = jnp.where(inside == 1, size=inside.shape[0])[0][0]

    return closest_point, closest_normal, surface_points[index]

def plot_contact_pairs(contactor_points, target_surface_points, reference_points, ax=None):

    closest_points, closest_normals, surface_points = check_contact_pairs(
        point=contactor_points,
        surface_points=target_surface_points,
        reference_points=reference_points,
    )

    closest_points = closest_points.squeeze()

    # Remember whether we created the figure ourselves, so we only show it in that case
    show_figure = ax is None
    if show_figure:
        plt.style.use(STYLE_PATH)
        plt.figure(figsize=(6, 3))
        ax = plt.axes()
    for i in range(contactor_points.shape[0]):
        ax.scatter(contactor_points[i, 0], contactor_points[i, 1], marker="x", s=50, color=colors.red)
        ax.scatter(closest_points[i, 0], closest_points[i, 1], marker="o")
        ax.text(1.05*contactor_points[i, 0], 1.05*contactor_points[i, 1], f"{i}", color='k')
        ax.text(1.05*closest_points[i, 0], 1.05*closest_points[i, 1], f"{i}", color='k')
        ax.plot(
            [contactor_points[i, 0], closest_points[i, 0]],
            [contactor_points[i, 1], closest_points[i, 1]],
            ls="--",
            color="gray",
        )
        x = [surface_points[i, 0, 0], surface_points[i, 1, 0]]
        y = [surface_points[i, 0, 1], surface_points[i, 1, 1]]
        ax.plot(x, y)
    ax.grid()

    if show_figure:
        plt.show()

Case 1: Upper surface is the contactor surface and the lower surface is the target surface

Let us check whether the algorithm works when the upper surface is the contactor surface and the lower surface is the target surface. The two surfaces are separated, i.e., there is no contact between them, so the gap should be positive everywhere. Since the target surface is flat at \(y=0\) with normal \(\boldsymbol{n} = [0, 1]\), the gap equals the \(y\)-coordinate of the contactor point.

Use the find_gap function that you implemented above to compute the gap between the contactor and the target surface.

upper_nodes, upper_elements, lower_nodes, lower_elements = generate_contact_surfaces()
lower_reference_points = jnp.array([0, np.min(lower_nodes[:, 1]) - 100])
upper_reference_points = jnp.array([0, np.max(upper_nodes[:, 1]) + 5])

gap = find_gap(
    contactor_point=upper_nodes,
    surface_points=lower_nodes[lower_elements],
    reference_points=lower_reference_points,
)
print("gap: ", gap)
print("Difference: ", gap-upper_nodes[:, 1])
gap:  [1.33484861 0.80327753 0.53385011 0.53385011 0.80327753 1.33484861]
Difference:  [0. 0. 0. 0. 0. 0.]
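The printed values can also be checked independently. generate_contact_surfaces places the upper nodes on a circle of radius \(R\) whose lowest point is at height initial_gap \(g_0\). Node \(i\) therefore sits at \(y_i = R + g_0 - R\cos\theta_i\). Below is a short NumPy sketch that reproduces the expected gaps from the default parameters:

```python
import numpy as np

radius, width, initial_gap, num_upper_segments = 10.0, 8.0, 0.5, 5

max_angle = np.arcsin((width / 2) / radius)
angles = np.linspace(-max_angle, max_angle, num_upper_segments + 1)
# Height of each arc node above the flat target at y = 0 (= expected gap, since n = (0, 1))
expected_gap = (radius + initial_gap) - radius * np.cos(angles)
print(expected_gap)
```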

Let us also visualize the contact pairs, i.e., the contactor points and the corresponding closest points on the target surface, using the plot_contact_pairs function.

plt.figure(figsize=(6, 4))
ax = plt.axes()
plot_contact_pairs(
    contactor_points=upper_nodes,
    target_surface_points=lower_nodes[lower_elements],
    reference_points=lower_reference_points,
    ax=ax,
)
ax.set_aspect("equal")
plt.show()
Figure 13.3: Contact pairs between the upper surface and the lower surface. Contactor points (upper surface) are marked with red crosses and the closest points on the target surface (lower surface) are marked with circles. The target surface is marked with the same color as the closest points.

Case 2: Lower surface is the contactor surface and the upper surface is the target surface

Let us now check whether the algorithm works when the lower surface is the contactor surface and the upper (circular) surface is the target surface. This time we move the upper surface down by 2 units so that the two bodies intersect. The gap should therefore be negative for the contactor nodes that lie inside the circular body. Note that the target surface is now curved, so its outward normal changes from segment to segment. The interior nodes of the lower surface (the two end nodes are excluded) serve as contactor nodes.

Use the find_gap function that you implemented above to compute the gap between the contactor and the target surface.

upper_nodes, upper_elements, lower_nodes, lower_elements = generate_contact_surfaces()
lower_reference_points = jnp.array([0, np.min(lower_nodes[:, 1]) - 10])
upper_reference_points = jnp.array([0, np.max(upper_nodes[:, 1]) + 10])
upper_nodes = upper_nodes.at[:, 1].add(-2)

gap = find_gap(
    contactor_point=lower_nodes[1:-1, :],
    surface_points=upper_nodes[upper_elements],
    reference_points=upper_reference_points,
)
print("gap: ", gap)
gap:  [-0.9989151  -1.30013507 -1.46614989 -1.46614989 -1.30013507 -0.9989151 ]
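Here is a quick analytic check of the two middle values. With five upper segments, the innermost arc nodes sit at angles \(\pm\theta_\text{max}/5\). The middle segment of the shifted circular surface is therefore horizontal, with outward normal \((0, -1)\). The two central lower nodes lie directly above it, inside the circular body. Their gap is \((0 - y_\text{mid})\cdot(-1) = y_\text{mid}\). The NumPy sketch below re-derives this value from the parameters of generate_contact_surfaces:

```python
import numpy as np

radius, width, initial_gap, shift = 10.0, 8.0, 0.5, -2.0
num_upper_segments, num_lower_segments = 5, 7

max_angle = np.arcsin((width / 2) / radius)
theta = max_angle / num_upper_segments  # innermost arc nodes at +-theta (odd segment count)
# Height and half-width of the horizontal middle segment after the downward shift
y_middle = radius + initial_gap - radius * np.cos(theta) + shift
x_middle_end = radius * np.sin(theta)

# Interior lower nodes (end nodes excluded), all at y = 0
lower_x = np.linspace(-width / 2, width / 2, num_lower_segments + 1)[1:-1]
within_middle_span = np.abs(lower_x) < x_middle_end

# Outward normal of the middle segment is (0, -1): g = (0 - y_middle) * (-1)
expected_middle_gap = (0.0 - y_middle) * (-1.0)
print(lower_x[within_middle_span], expected_middle_gap)  # [-4/7, 4/7], ~ -1.46615
```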

Let's also visualize the contact pairs using the plot_contact_pairs function.

plt.figure(figsize=(6, 4))
ax = plt.axes()
plot_contact_pairs(
    contactor_points=lower_nodes[1:-1, :],
    target_surface_points=upper_nodes[upper_elements],
    reference_points=upper_reference_points,
    ax=ax,
)
ax.set_aspect("equal")
plt.show()
Figure 13.4: Contact pairs between the lower surface and the upper surface. Contactor points (lower surface) are marked with red crosses and the closest points on the target surface (upper surface) are marked with circles. The target surface is marked with the same color as the closest points.

Defining the contact energy

Now we have an algorithm to compute the gap between a point and a surface, and we can use it to compute the contact energy.

The contact energy is defined as:

\[ \Psi_\text{contact}(u)=\frac{1}{2}k_\text{pen}\int_{\Gamma_c}\langle g(u)\rangle^2 \text{d}\Gamma \] where \(\Gamma_c\) is the contact surface, \(k_\text{pen}\) is the penalty parameter, and \(g(u)\) is the gap function. Here, the Macaulay bracket returns the negative part of its argument, so only penetration (\(g \leq 0\)) contributes to the energy:

\[ \langle \star\rangle = \begin{cases} 0 & \text{if } \star > 0 \\ \star & \text{if } \star \leq 0 \end{cases} \]

For this exercise, we apply the penalty only at the nodes on the contact surface. Therefore, the discrete contact energy is:

\[ \Psi_\text{contact}(u)=\frac{1}{2}k_\text{pen}\sum_{i=1}^{N_c}\langle g(u_i)\rangle^2 a_{i} \]

where \(N_c\) is the number of nodes on the contact surface, \(a_i\) is the tributary area of node \(i\) (a length in 2D), and \(g(u_i)\) is the gap function at node \(i\).

Defining the penalty parameter

The penalty parameter \(k_\text{pen}\) controls the stiffness of the contact. It is a scalar that penalizes violation of the contact constraint. A larger value reduces the penetration, but it also makes the resulting system stiffer and worse conditioned.

Below, we define the penalty parameter \(k_\text{pen}\).

k_pen = 5.0

Computing the contact energy

Now we have all the ingredients to define the contact energy.

Recall the discrete contact energy:

\[ \Psi_\text{contact}(u)=\frac{1}{2}k_\text{pen}\sum_{i=1}^{N_c}\langle g(u_i)\rangle^2 a_{i} \]

Below we define the function to compute the contact energy. This function is the same as the one defined in the previous exercise, except that it also takes the reference points as an argument.

@jax.jit
def macalauy_bracket(x):
    # Macaulay bracket as defined above: x where x <= 0 (penetration), 0 otherwise
    return jnp.where(x > 0, 0, x)


@jax.jit
def compute_contact_energy(
    u: Array,
    coords: Array,
    contactor_nodes: Array,
    nodes_area: Array,
    target_surface_elements: Array,
    reference_points: Array,
) -> Array:
    """Compute the contact energy for a given displacement field.
    Args:
        u: Displacement field, shape (n_nodes, 2).
        coords: Coordinates of the nodes, shape (n_nodes, 2).
        contactor_nodes: Indices of the nodes on the contactor surface.
        nodes_area: Tributary area of each node, indexed by node number.
        target_surface_elements: Connectivity (node indices) of the target surface elements.
        reference_points: Coordinates of a reference point inside the target body.
    Returns:
        Contact energy.
    """
    new_coords = coords + u

    points = new_coords[contactor_nodes]
    contactor_nodes_area = nodes_area[contactor_nodes]
    surface_points = new_coords[target_surface_elements]

    # Contact energy of a single contactor node. k_pen is read from the enclosing
    # scope and baked in as a constant when jax.jit traces the function, so
    # changing k_pen later has no effect unless the function is re-traced
    def _contact_energy_node(point, contact_area):
        gap = find_gap(point, surface_points, reference_points)
        gap = macalauy_bracket(gap)
        return 0.5 * k_pen * (gap**2) * contact_area

    # Vectorize over the contactor nodes
    contact_energy_node = jax.vmap(_contact_energy_node, in_axes=(0, 0))

    return jnp.sum(contact_energy_node(points, contactor_nodes_area))
The One-Pass Algorithm (The Biased Approach)

In the above function to compute the contact energy, we have to make a choice: which body is the “contactor” and which is the “target”? It turns out this seemingly arbitrary choice can introduce a significant bias into the simulation, leading to unphysical results.

To understand how the bias is introduced, consider a simple example: contact between a triangular surface and a flat surface. For simplicity, we use two line elements on each surface.

Let us first consider the case in which the triangular surface is the contactor and the flat surface is the target. Below we define some trial nodes and elements to represent the triangular and flat surfaces, together with a trial displacement field. Only the apex of the triangle is displaced, by \(-0.2\) in the \(y\)-direction, so that it penetrates the flat surface.

triangle_nodes = jnp.array([[0, 0.4], [0.5, 0.0], [1, 0.4]])
flat_nodes = jnp.array([[0, 0], [0.6, 0.0], [1, 0]])

trial_nodes = jnp.concatenate([triangle_nodes, flat_nodes], axis=0)

triangle_surface_elements = jnp.array([[0, 1], [1, 2]])
triangle_contactor_nodes = jnp.array([1])

flat_surface_elements = jnp.array([[3, 4], [4, 5]])
flat_contactor_nodes = jnp.array([4])

reference_points_triangle_surface = jnp.array([0.5, 2.])
reference_points_flat_surface = jnp.array([0.5, -10])

trial_u = jnp.zeros(trial_nodes.shape[0]*2)
trial_u = trial_u.at[2 * triangle_contactor_nodes + 1].set(-0.2)

Below we define a function to compute the contact energy \(\Psi_\text{contact}\) when the triangular surface is the contactor and the flat surface is the target. We make use of the compute_contact_energy function defined above. We compute the contact force that will be exerted on the nodes of both the surfaces due to the interpenetration.

\[ f_\text{contact} = -\frac{\partial \Psi_\text{contact}}{\partial u} \]

We plot these contact forces on the nodes of both surfaces. As expected, the direction of the contact force is determined by the direction of the normal to the target surface.
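As a cross-check of this statement, here is a minimal single-pair sketch. It does not use the document's compute_contact_energy. gap and energy are hypothetical helpers that skip the search step and hard-code the flat segment from \((0, 0)\) to \((0.6, 0)\), which is the segment the displaced apex \((0.5, -0.2)\) projects onto. Differentiating the penalty energy with jax.grad gives the nodal forces:

```python
import jax
import jax.numpy as jnp

K_PEN = 5.0  # same value as k_pen in this section


def gap(r, x1, x2, reference_point):
    """Normal gap of node r w.r.t. segment x1-x2, with n pointing away from reference_point."""
    tangent = x2 - x1
    n = jnp.array([-tangent[1], tangent[0]]) / jnp.linalg.norm(tangent)
    n = jnp.where(jnp.vdot(reference_point - x1, n) > 0, -n, n)
    return jnp.vdot(r - x1, n)


def energy(nodes, reference_point):
    """Penalty energy with unit nodal area; nodes = [contactor r, target x1, target x2]."""
    g = jnp.minimum(gap(nodes[0], nodes[1], nodes[2], reference_point), 0.0)
    return 0.5 * K_PEN * g**2


# Displaced triangle apex against the flat segment from (0, 0) to (0.6, 0)
nodes = jnp.array([[0.5, -0.2], [0.0, 0.0], [0.6, 0.0]])
forces = -jax.grad(energy)(nodes, jnp.array([0.5, -10.0]))
print(forces)  # approximately [[0, 1], [0, -0.1667], [0, -0.8333]]
```

The contactor node is pushed out along the target normal \((0, 1)\). The two target nodes receive the opposite reaction, split according to where the apex projects on the segment. Their forces sum to \(-1\).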

Code: Compute the contact energy and the contact force when the triangular surface is the contactor and the flat surface is the target
def energy_when_triangle_is_contactor(u_flat: Array) -> float:
    u = u_flat.reshape(-1, 2)
    return compute_contact_energy(
        u=u,
        coords=trial_nodes,
        contactor_nodes=triangle_contactor_nodes,
        nodes_area=jnp.ones(len(u_flat)),
        target_surface_elements=flat_surface_elements,
        reference_points=reference_points_flat_surface,
    )


trial_compute_contact_force = jax.jacrev(energy_when_triangle_is_contactor)

trial_fint = -trial_compute_contact_force(trial_u.reshape(-1, 2))

trial_new_nodes = trial_nodes + trial_u.reshape(-1, 2)
plt.style.use(STYLE_PATH)
plt.figure(figsize=(4, 3), layout="constrained")
ax = plt.axes()
ax.plot(
    trial_new_nodes[:3, 0],
    trial_new_nodes[:3, 1],
    "o-",
    markersize=6,
    color=colors.red,
    label="Triangle",
)
ax.plot(
    trial_new_nodes[3:, 0],
    trial_new_nodes[3:, 1],
    "o-",
    markersize=6,
    color=colors.green,
    label="Flat",
)


for point, force in zip(
    trial_new_nodes[jnp.unique(triangle_surface_elements)],
    trial_fint[jnp.unique(triangle_surface_elements)],
):
    normal = force / jnp.linalg.norm(force)
    ax.quiver(
        point[0],
        point[1],
        normal[0],
        normal[1],
        angles="xy",
        scale_units="xy",
        scale=10,
        color="gray",
        ls="dashed",
        lw=1,
        zorder=100,
    )

for point, force in zip(
    trial_new_nodes[jnp.unique(flat_surface_elements)],
    trial_fint[jnp.unique(flat_surface_elements)],
):
    normal = force / jnp.linalg.norm(force)
    ax.quiver(
        point[0],
        point[1],
        normal[0],
        normal[1],
        angles="xy",
        scale_units="xy",
        scale=10,
        color="k",
        ls="dashed",
        lw=1,
        zorder=100,
    )


ax.fill_between(
    trial_new_nodes[:3, 0],
    y1=trial_new_nodes[:3, 1],
    y2=0.5,
    color=colors.blue,
    alpha=0.25,
)

ax.fill_between(
    trial_new_nodes[3:, 0],
    y1=trial_new_nodes[3:, 1],
    y2=-0.3,
    color=colors.blue,
    alpha=0.25,
)
ax.text(0.5, 0.35, "Contactor", color="k", fontsize=12)
ax.text(0.1, -0.1, "Target", color="k", fontsize=12)

ax.set_aspect("equal")
ax.axis("off")
plt.show()
Figure 13.5: Contact forces on the nodes of the triangular and flat surfaces when the triangular surface is the contactor and the flat surface is the target.

Next, we consider the case when the flat surface is the contactor and the triangular surface is the target. We repeat the same steps as above to compute the contact energy and the contact force. Plotting the contact forces, we observe that they point in a different direction than in the case where the triangular surface is the contactor. This is because the direction of the contact force is determined by the normal to the target surface: earlier, the normal pointed outward from the flat surface; now it points outward from the triangular surface.

This dependence of the contact forces on which surface is chosen as the target is exactly the cause of the bias.
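To see the bias in isolation, here is a minimal NumPy sketch of node-to-segment detection on a configuration resembling the trial example: a V-shaped contactor whose tip dips 0.1 below a flat segment. The coordinates, the helper names outward_normal and node_to_segment_gaps, and the rule of orienting each normal away from a reference point inside the target body are illustrative assumptions, not the compute_contact_energy used above.

```python
import numpy as np


def outward_normal(x1, x2, ref):
    """Unit normal of segment x1-x2, oriented away from a reference point inside the body."""
    t = (x2 - x1) / np.linalg.norm(x2 - x1)
    n = np.array([-t[1], t[0]])
    if np.dot(0.5 * (x1 + x2) - ref, n) < 0.0:
        n = -n
    return n


def node_to_segment_gaps(contactor, target, segments, ref):
    """Normal gap g_n = (r - rho) . n of each contactor node to its closest target segment.

    Nodes whose projection falls outside every segment get +inf (no contact).
    """
    gaps = np.full(len(contactor), np.inf)
    for k, r in enumerate(contactor):
        best = np.inf
        for i, j in segments:
            x1, x2 = target[i], target[j]
            d = x2 - x1
            xi = np.dot(r - x1, d) / np.dot(d, d)  # local coordinate of the projection
            if not 0.0 <= xi <= 1.0:
                continue  # projection point lies outside this segment
            rho = x1 + xi * d
            dist = np.linalg.norm(r - rho)
            if dist < best:
                best = dist
                gaps[k] = np.dot(r - rho, outward_normal(x1, x2, ref))
    return gaps


triangle = np.array([[0.3, 0.5], [0.5, -0.1], [0.7, 0.5]])  # upper body, tip at y = -0.1
flat = np.array([[0.0, 0.0], [1.0, 0.0]])                    # lower body surface at y = 0

# pass with the triangle as contactor (reference point far below the flat body)
g_tri = node_to_segment_gaps(triangle, flat, [(0, 1)], ref=np.array([0.5, -10.0]))
# pass with the flat surface as contactor (reference point far above the triangle)
g_flat = node_to_segment_gaps(flat, triangle, [(0, 1), (1, 2)], ref=np.array([0.5, 10.0]))

print(g_tri)   # tip node has g_n = -0.1: penetration detected
print(g_flat)  # both gaps positive (about 0.44): the same penetration is missed
```

Each flat node does project onto one edge of the triangle, but it lies on the outer side of that edge, so a one-pass check with the flat surface as contactor generates no contact force at all.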

Code: Compute the contact energy and the contact force when the flat surface is the contactor and the triangular surface is the target
def energy_when_flat_is_contactor(u_flat: Array) -> float:
    u = u_flat.reshape(-1, 2)
    return compute_contact_energy(
        u=u,
        coords=trial_nodes,
        contactor_nodes=flat_contactor_nodes,
        nodes_area=jnp.ones(len(u_flat)),
        target_surface_elements=triangle_surface_elements,
        reference_points=reference_points_triangle_surface,
    )


trial_compute_contact_force = jax.jacrev(energy_when_flat_is_contactor)

trial_fint = -trial_compute_contact_force(trial_u.reshape(-1, 2))


trial_new_nodes = trial_nodes + trial_u.reshape(-1, 2)
plt.style.use(STYLE_PATH)
plt.figure(figsize=(4, 3), layout='constrained')

ax = plt.axes()
ax.plot(
    trial_new_nodes[:3, 0],
    trial_new_nodes[:3, 1],
    "o-",
    markersize=6,
    color=colors.red,
    label="Triangle",
)
ax.plot(
    trial_new_nodes[3:, 0],
    trial_new_nodes[3:, 1],
    "o-",
    markersize=6,
    color=colors.green,
    label="Flat",
)

for point, force in zip(
    trial_new_nodes[jnp.unique(triangle_surface_elements)],
    trial_fint[jnp.unique(triangle_surface_elements)],
):
    normal = force / jnp.linalg.norm(force)
    ax.quiver(
        point[0],
        point[1],
        normal[0],
        normal[1],
        angles="xy",
        scale_units="xy",
        scale=10,
        color="gray",
        ls="dashed",
        lw=1,
        zorder=100
    )

for point, force in zip(
    trial_new_nodes[flat_contactor_nodes],
    trial_fint[flat_contactor_nodes],
):
    normal = force / jnp.linalg.norm(force)
    ax.quiver(
        point[0],
        point[1],
        normal[0],
        normal[1],
        angles="xy",
        scale_units="xy",
        scale=10,
        color="k",
        ls="dashed",
        lw=1,
        zorder=100
    )

ax.fill_between(
    trial_new_nodes[:3, 0],
    y1=trial_new_nodes[:3, 1],
    y2=0.5,
    color=colors.blue,
    alpha=0.25,
)

ax.fill_between(
    trial_new_nodes[3:, 0],
    y1=trial_new_nodes[3:, 1],
    y2=-0.3,
    color=colors.blue,
    alpha=0.25,
)
ax.set_aspect("equal")
ax.axis('off')
ax.text(0.5, 0.35, "Target", color='k', fontsize=12)
ax.text(0.1, -0.1, "Contactor", color='k', fontsize=12)
plt.show()
Figure 13.6: Contact forces on the nodes of the triangular and flat surfaces when the flat surface is the contactor and the triangular surface is the target.
The One-Pass Algorithm (The Biased Approach)

The above approach, in which one surface is designated the contactor and the other the target, is called the one-pass algorithm. As Figures 13.5 and 13.6 show, the result depends on this arbitrary choice: the approach is biased and can lead to unphysical results. Therefore, when we are dealing with contact between deformable bodies, we need a different strategy.

The Two-Pass Algorithm (The Symmetrized Approach)

To eliminate this bias and create a robust algorithm, we use a two-pass or symmetrized approach. The logic is simple but powerful: we perform the contact check twice, swapping the roles of the bodies each time.

  1. Pass 1: Treat Body A as the contactor and Body B as the target. Calculate the contact penalty energy based on any penetrations found (\(\Psi_{A \to B}\)).
  2. Pass 2: Swap the roles. Now treat Body B as the contactor and Body A as the target. Calculate a separate contact penalty energy (\(\Psi_{B \to A}\)).
  3. Combine: The total contact energy is the average of the two passes: \[\Psi_{\text{contact}} = \frac{1}{2} (\Psi_{A \to B} + \Psi_{B \to A})\]

By performing two passes, we ensure that no matter which body's nodes penetrate the other's segments, a resisting penalty force is always generated. This symmetrization eliminates the bias and leads to a more accurate and physically realistic simulation, especially when the two bodies have similar mesh densities or complex geometries.
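The combine step itself is only a few lines. The sketch below assumes a penalty energy of the form \(\Psi = \frac{1}{2} k \sum_i a_i \langle -g_i \rangle^2\) per pass, with \(\langle x \rangle = \max(x, 0)\); this may differ in detail from compute_contact_energy. The gap arrays are hypothetical values for a case in which only the pass A→B detects a penetration.

```python
import numpy as np


def penalty_energy(gaps, areas, k_pen):
    """One-pass penalty energy 0.5 * k * sum_i a_i * <-g_i>^2 (only penetrating nodes contribute)."""
    penetration = np.maximum(-gaps, 0.0)
    return 0.5 * k_pen * np.sum(areas * penetration**2)


def two_pass_energy(gaps_ab, areas_a, gaps_ba, areas_b, k_pen):
    """Average of the two one-pass energies: A as contactor, then B as contactor."""
    psi_ab = penalty_energy(gaps_ab, areas_a, k_pen)
    psi_ba = penalty_energy(gaps_ba, areas_b, k_pen)
    return 0.5 * (psi_ab + psi_ba)


# hypothetical gaps: pass A->B sees a tip penetration of 0.1, pass B->A sees none
gaps_ab, areas_a = np.array([0.5, -0.1, 0.5]), np.ones(3)
gaps_ba, areas_b = np.array([0.44, 0.44]), np.ones(2)
k_pen = 1.0e3

print(penalty_energy(gaps_ab, areas_a, k_pen))  # pass A->B alone
print(penalty_energy(gaps_ba, areas_b, k_pen))  # pass B->A alone: no energy
print(two_pass_energy(gaps_ab, areas_a, gaps_ba, areas_b, k_pen))
print(two_pass_energy(gaps_ba, areas_b, gaps_ab, areas_a, k_pen))  # same value for either ordering
```

When only one pass detects a penetration, the averaging halves its contribution, so the effective penalty stiffness is halved as well; this is worth keeping in mind when choosing the penalty parameter.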

Below, we define the functions to compute the contact energy and the contact force using the two-pass algorithm.

Code: Compute the contact energy using the two-pass algorithm
def energy_using_two_pass(u_flat: Array) -> float:
    u = u_flat.reshape(-1, 2)
    energy_A_to_B = compute_contact_energy(
        u=u,
        coords=trial_nodes,
        contactor_nodes=triangle_contactor_nodes,
        nodes_area=jnp.ones(len(u_flat)),
        target_surface_elements=flat_surface_elements,
        reference_points=reference_points_flat_surface,
    )

    energy_B_to_A = compute_contact_energy(
        u=u,
        coords=trial_nodes,
        contactor_nodes=flat_contactor_nodes,
        nodes_area=jnp.ones(len(u_flat)),
        target_surface_elements=triangle_surface_elements,
        reference_points=reference_points_triangle_surface,
    )

    return 0.5 * (energy_A_to_B + energy_B_to_A)



trial_compute_contact_force = jax.jacrev(energy_using_two_pass)

trial_fint = -trial_compute_contact_force(trial_u.reshape(-1, 2))

trial_new_nodes = trial_nodes + trial_u.reshape(-1, 2)
plt.style.use(STYLE_PATH)
plt.figure(figsize=(4, 3), layout='constrained')

ax = plt.axes()
ax.plot(
    trial_new_nodes[:3, 0],
    trial_new_nodes[:3, 1],
    "o-",
    markersize=6,
    color=colors.red,
    label="Triangle",
)
ax.plot(
    trial_new_nodes[3:, 0],
    trial_new_nodes[3:, 1],
    "o-",
    markersize=6,
    color=colors.green,
    label="Flat",
)

for point, force in zip(
    trial_new_nodes[jnp.unique(triangle_surface_elements)],
    trial_fint[jnp.unique(triangle_surface_elements)],
):
    normal = force / jnp.linalg.norm(force)
    ax.quiver(
        point[0],
        point[1],
        normal[0],
        normal[1],
        angles="xy",
        scale_units="xy",
        scale=10,
        color="gray",
        ls="dashed",
        lw=1,
        zorder=100
    )

for point, force in zip(
    trial_new_nodes[flat_contactor_nodes],
    trial_fint[flat_contactor_nodes],
):
    normal = force / jnp.linalg.norm(force)
    ax.quiver(
        point[0],
        point[1],
        normal[0],
        normal[1],
        angles="xy",
        scale_units="xy",
        scale=10,
        color="k",
        ls="dashed",
        lw=1,
        zorder=100
    )


ax.fill_between(
    trial_new_nodes[:3, 0],
    y1=trial_new_nodes[:3, 1],
    y2=0.5,
    color=colors.blue,
    alpha=0.25,
)

ax.fill_between(
    trial_new_nodes[3:, 0],
    y1=trial_new_nodes[3:, 1],
    y2=-0.3,
    color=colors.blue,
    alpha=0.25,
)
ax.set_aspect("equal")
ax.axis('off')
ax.text(0.5, 0.35, "Contactor/Target", color='k', fontsize=12)
ax.text(0.05, -0.1, "Target/Contactor", color='k', fontsize=12)
plt.show()
Figure 13.7: Contact forces on the nodes of the triangular and flat surfaces using the two-pass algorithm.
Important

The grading of the assignment will be based on the correctness of the algorithm. The part below will not be considered for grading. However, we encourage you to also go through it, as it uses the algorithm implemented above to solve the problem of contact between two deformable bodies.

Please go through the part below to understand the problem statement and the solution.

Model setup

Next, we use the contact detection algorithm and the two-pass algorithm to resolve the contact between two deformable bodies. Below, we generate a triangular mesh of two rectangular blocks, an upper and a lower one, separated by an initial gap. We also define functions to get the elements and nodes on the potential contact surfaces.

Code: Functions to generate the mesh
def generate_block_mesh(
    nx: int,
    ny: int,
    lxs: Tuple[float, float],
    lys: Tuple[float, float],
    curve_func: Optional[Callable[[Array, float], bool]] = None,
    tol: float = 1e-6,
) -> Tuple[Array, Array, Optional[Array]]:
    """
    Generates a 2D triangular mesh for a rectangle and optionally extracts
    1D line elements along a specified curve.

    Args:
        nx: Number of elements along the x-direction.
        ny: Number of elements along the y-direction.
        lxs: Tuple of the x-coordinates of the left and right edges of the rectangle.
        lys: Tuple of the y-coordinates of the bottom and top edges of the rectangle.
        curve_func: An optional callable that takes a coordinate array [x, y] and
                    a tolerance, returning True if the point is on the curve.
        tol: Tolerance for floating-point comparisons.

    Returns:
        A tuple containing:
        - coords (jnp.ndarray): Nodal coordinates, shape (num_nodes, 2).
        - elements_2d (jnp.ndarray): 2D triangular element connectivity.
        - elements_1d (jnp.ndarray | None): 1D line element connectivity, or None.
    """

    x = jnp.linspace(lxs[0], lxs[1], nx + 1)
    y = jnp.linspace(lys[0], lys[1], ny + 1)
    xv, yv = jnp.meshgrid(x, y, indexing="ij")
    coords = jnp.stack([xv.ravel(), yv.ravel()], axis=-1)

    def node_id(i, j):
        return i * (ny + 1) + j

    elements_2d = []
    for i in range(nx):
        for j in range(ny):
            n0 = node_id(i, j)
            n1 = node_id(i + 1, j)
            n2 = node_id(i, j + 1)
            n3 = node_id(i + 1, j + 1)
            elements_2d.append([n0, n1, n3])
            elements_2d.append([n0, n3, n2])
    elements_2d = jnp.array(elements_2d)

    # --- 2. Extract 1D elements if a curve function is provided ---
    if curve_func is None:
        return coords, elements_2d, None

    # Efficiently find all nodes on the curve using jax.vmap
    on_curve_mask = jax.vmap(lambda c: curve_func(c, tol))(coords)

    elements_1d = []
    # Iterate through all 2D elements to find edges on the curve
    for tri in elements_2d:
        # Define the three edges of the triangle
        edges = [(tri[0], tri[1]), (tri[1], tri[2]), (tri[2], tri[0])]
        for n_a, n_b in edges:
            # If both nodes of an edge are on the curve, keep the edge
            if on_curve_mask[n_a] and on_curve_mask[n_b]:
                # Sort to store canonical representation, e.g., (1, 2) not (2, 1)
                elements_1d.append(tuple(sorted((n_a, n_b))))

    if not elements_1d:
        return coords, elements_2d, jnp.array([], dtype=int)

    return coords, elements_2d, jnp.unique(jnp.array(elements_1d), axis=0)


def generate_two_blocks_mesh(
    upper_ns: Tuple[int, int],
    lower_ns: Tuple[int, int],
    upper_lengths: Tuple[float, float],
    lower_lengths: Tuple[float, float],
    gap: float,
    upper_curve_func: Callable[[Array, float], bool],
    lower_curve_func: Callable[[Array, float], bool],
    tol: float = 1e-6,
) -> Tuple[Mesh, Array, Array, Array, Array]:
    """
    Generates a combined 2D triangular mesh of two rectangular blocks separated
    by a vertical gap, and extracts the 1D line elements on each block's
    potential contact surface.

    Args:
        upper_ns: Number of elements (nx, ny) of the upper block.
        lower_ns: Number of elements (nx, ny) of the lower block.
        upper_lengths: Width and height (lx, ly) of the upper block, which spans
                       x in [-lx/2, lx/2] and y in [0, ly].
        lower_lengths: Width and height (lx, ly) of the lower block, which spans
                       x in [-lx/2, lx/2] and y in [-(ly + gap), -gap].
        gap: Initial vertical distance between the two blocks.
        upper_curve_func: Callable taking a coordinate array [x, y] and a tolerance,
                          returning True if the point is on the upper contact surface.
        lower_curve_func: Same as upper_curve_func, for the lower contact surface.
        tol: Tolerance for floating-point comparisons.

    Returns:
        A tuple containing:
        - mesh (Mesh): Combined mesh of both blocks.
        - upper_elements_2d (jnp.ndarray): Triangles of the upper block.
        - lower_elements_2d (jnp.ndarray): Triangles of the lower block (global node ids).
        - upper_elements_1d (jnp.ndarray): Line elements on the upper contact surface.
        - lower_elements_1d (jnp.ndarray): Line elements on the lower contact surface (global node ids).
    """

    upper_nx, upper_ny = upper_ns
    lower_nx, lower_ny = lower_ns
    upper_lx, upper_ly = upper_lengths
    lower_lx, lower_ly = lower_lengths
    upper_coords, upper_elements_2d, upper_elements_1d = generate_block_mesh(
        nx=upper_nx, ny=upper_ny, lxs=(-upper_lx/2, upper_lx/2), lys=(0, upper_ly), curve_func=upper_curve_func, tol=tol
    )

    lower_coords, lower_elements_2d, lower_elements_1d = generate_block_mesh(
        nx=lower_nx, ny=lower_ny, lxs=(-lower_lx/2, lower_lx/2), lys=(-(lower_ly+gap), -gap), curve_func=lower_curve_func, tol=tol
    )

    coords = jnp.vstack((upper_coords, lower_coords))
    elements = jnp.vstack(
        (upper_elements_2d, lower_elements_2d + upper_coords.shape[0])
    )

    mesh = Mesh(coords, elements)

    lower_elements_1d = lower_elements_1d + upper_coords.shape[0]

    return (
        mesh,
        upper_elements_2d,
        lower_elements_2d + upper_coords.shape[0],
        upper_elements_1d,
        lower_elements_1d,
    )
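upper_ns = (6, 6)  # number of elements (nx, ny) of the upper block
lower_ns = (7, 7)  # number of elements (nx, ny) of the lower block
upper_lengths = (4, 4)  # width and height (lx, ly) of the upper block
lower_lengths = (8, 8)  # width and height (lx, ly) of the lower block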
gap = 0.4 # initial gap between bodies


# function to identify the potential contact elements on the upper body
def upper_block_line(coord: Array, tol: float) -> bool:
    return jnp.logical_and(jnp.isclose(coord[1], 0.0, atol=tol), coord[0] >= -upper_lengths[0])


# function to identify the potential contact elements on the lower body
def lower_block_line(coord: Array, tol: float) -> bool:
    return jnp.logical_and(jnp.isclose(coord[1], -gap, atol=tol), coord[0] >= -lower_lengths[0])


(
    mesh,
    upper_elements,
    lower_elements,
    upper_contact_elements,
    lower_contact_elements,
) = generate_two_blocks_mesh(
    upper_ns=upper_ns,
    lower_ns=lower_ns,
    upper_lengths=upper_lengths,
    lower_lengths=lower_lengths,
    gap=gap,
    upper_curve_func=upper_block_line,
    lower_curve_func=lower_block_line,
)

# get the nodes from the contact elements
upper_contact_nodes = jnp.unique(upper_contact_elements.flatten())
lower_contact_nodes = jnp.unique(lower_contact_elements.flatten())

n_nodes = mesh.coords.shape[0]
n_dofs_per_node = 2
n_dofs = n_dofs_per_node * n_nodes
Code: Plot the mesh
plt.style.use(STYLE_PATH)
plt.figure(figsize=(6, 3))
ax = plt.axes()
ax.tripcolor(
    mesh.coords[:, 0],
    mesh.coords[:, 1],
    upper_elements,
    edgecolors="black",
    linewidths=0.2,
    shading="flat",
    cmap="managua_r",
    facecolors=np.ones(upper_elements.shape[0]),
)

ax.tripcolor(
    mesh.coords[:, 0],
    mesh.coords[:, 1],
    lower_elements,
    edgecolors="black",
    linewidths=0.2,
    shading="flat",
    cmap="managua_r",
    facecolors=np.zeros(lower_elements.shape[0]),
)


ax.scatter(
    mesh.coords[upper_contact_nodes, 0], mesh.coords[upper_contact_nodes, 1], color=colors.red, s=5
)

ax.scatter(
    mesh.coords[lower_contact_nodes, 0], mesh.coords[lower_contact_nodes, 1], color=colors.green, s=5
)



ax.set_aspect("equal")
ax.margins(0, 0)
plt.show()
Figure 13.8: Mesh with two deformable bodies. The potential contact nodes are highlighted in red for the upper body and in green for the lower body.

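Choose a reference point for the upper body and another reference point for the lower body. Both points lie well inside their respective bodies, far from the potential contact surfaces (ten block heights from y = 0), and are passed to compute_contact_energy as reference_points when the corresponding surface acts as the target.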

lower_reference_points = jnp.array([0, -lower_lengths[1]*10])
upper_reference_points = jnp.array([0, upper_lengths[1]*10])

Defining material parameters

from typing import NamedTuple


class Material(NamedTuple):
    """Material properties for the linear elastic operator."""

    mu: float  # Shear modulus
    lmbda: float  # First Lamé parameter


E = 1.0
nu = 0.3
mu = E / (2 * (1 + nu))
lmbda = E * nu / ((1 + nu) * (1 - 2 * nu))

mat = Material(mu=mu, lmbda=lmbda)

Defining elastic energy of the system

\[ \Psi_\text{e}(\boldsymbol{u}) = \int_\Omega \psi_\text{e} \, \text{d}\Omega \]

where the strain energy density is

\[ \psi_\text{e}(\boldsymbol{x}) = \frac{1}{2}\, \boldsymbol{\sigma}(\boldsymbol{x}) : \boldsymbol{\varepsilon}(\boldsymbol{x}) \]

and \(\boldsymbol{\sigma}\) is the stress tensor and \(\boldsymbol{\varepsilon}\) is the strain tensor:

\[ \boldsymbol{\sigma} = \lambda \text{tr}(\boldsymbol{\varepsilon}) \mathbf{I} + 2\mu \boldsymbol{\varepsilon} \]

and

\[ \boldsymbol{\varepsilon} = \frac{1}{2} \big(\nabla \boldsymbol{u} + (\nabla \boldsymbol{u})^T\big) \]
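Substituting the constitutive law shows that the energy per unit volume, \(\tfrac{1}{2}\boldsymbol{\sigma} : \boldsymbol{\varepsilon}\), depends on the strain alone (using \(\mathbf{I} : \boldsymbol{\varepsilon} = \text{tr}(\boldsymbol{\varepsilon})\)); this is the quantity that strain_energy below evaluates at each quadrature point:

\[ \frac{1}{2}\,\boldsymbol{\sigma} : \boldsymbol{\varepsilon} = \frac{1}{2}\big(\lambda\, \text{tr}(\boldsymbol{\varepsilon})\, \mathbf{I} + 2\mu\, \boldsymbol{\varepsilon}\big) : \boldsymbol{\varepsilon} = \frac{\lambda}{2} \big(\text{tr}(\boldsymbol{\varepsilon})\big)^2 + \mu\, \boldsymbol{\varepsilon} : \boldsymbol{\varepsilon} \]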

tri = element.Tri3()
op = Operator(mesh, tri)
mat = Material(mu=mu, lmbda=lmbda)

@autovmap(grad_u=2)
def compute_strain(grad_u: Array) -> Array:
    """Compute the strain tensor from the gradient of the displacement."""
    return 0.5 * (grad_u + grad_u.T)


@autovmap(eps=2)
def compute_stress(eps: Array, mu: float, lmbda: float) -> Array:
    """Compute the stress tensor from the strain tensor."""
    I = jnp.eye(2)
    return 2 * mu * eps + lmbda * jnp.trace(eps) * I


@autovmap(grad_u=2)
def strain_energy(grad_u: Array, mu: float, lmbda: float) -> Array:
    """Compute the strain energy density."""
    eps = compute_strain(grad_u)
    sig = compute_stress(eps, mu, lmbda)
    return 0.5 * jnp.einsum("ij,ij->", sig, eps)


@jax.jit
def total_strain_energy(u_flat: Array) -> float:
    """Compute the total energy of the system."""
    u = u_flat.reshape(-1, n_dofs_per_node)
    grad_u = op.grad(u)
    return op.integrate(strain_energy(grad_u, mat.mu, mat.lmbda))

Defining contact energy

We will now define the contact energy using the two-pass algorithm. Use the function compute_contact_energy to compute the contact energy for \(\Psi_{A \to B}\) and \(\Psi_{B \to A}\) and average the two to get the total contact energy.

Below we define the function to compute the area associated with each contact node for the upper contact surface as well as the lower contact surface. Use these two vectors to compute the total contact energy.

Computing the area associated with each contact node

For a correct resolution of the contact problem, we need the area associated with each contact node. To compute it, we define a line mesh whose elements are the edges of the contact surface and integrate a field \(q\) over it:

\[ a_\text{total} = \int_{\Gamma_{c}} q(\boldsymbol{x}) \, \text{d}\Gamma, \qquad q(\boldsymbol{x}) = \sum_i N_i(\boldsymbol{x})\, q_i \]

where \(N_i\) are the shape functions of the line elements and all nodal values \(q_i\) are set to 1. The area of node \(i\) is the derivative of the total area with respect to its nodal value:

\[ a_i = \frac{\partial a_\text{total}}{\partial q_i} = \int_{\Gamma_c} N_i(\boldsymbol{x}) \, \text{d}\Gamma \]
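For two-node line elements, \(q\) is linear on each segment, so the midpoint rule integrates it exactly and each node receives half the length of every segment it touches. The JAX sketch below checks this on a hypothetical straight polyline with segment lengths 1, 2 and 1; it mirrors the idea of total_upper_area and total_lower_area, but evaluates the integral by hand instead of through a tatva Operator.

```python
import jax
import jax.numpy as jnp

# hypothetical contact surface: four nodes on the x-axis, three Line2 segments
coords = jnp.array([[0.0, 0.0], [1.0, 0.0], [3.0, 0.0], [4.0, 0.0]])
segments = jnp.array([[0, 1], [1, 2], [2, 3]])


def total_area(q):
    """Integral of the linearly interpolated nodal field q over the polyline."""
    x1 = coords[segments[:, 0]]
    x2 = coords[segments[:, 1]]
    lengths = jnp.linalg.norm(x2 - x1, axis=1)
    # midpoint rule: exact for a field that is linear on each segment
    q_mid = 0.5 * (q[segments[:, 0]] + q[segments[:, 1]])
    return jnp.sum(lengths * q_mid)


q = jnp.ones(coords.shape[0])
nodal_area = jax.grad(total_area)(q)
print(nodal_area)        # half of each adjacent segment length: 0.5, 1.5, 1.5, 0.5
print(nodal_area.sum())  # equals the total length of the surface, 4.0
```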

Note

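Below we compute the area associated with every node of the mesh and store it in the arrays upper_area_vector and lower_area_vector, which are passed as nodes_area to compute_contact_energy.

Each array therefore has size n_nodes, i.e., the number of nodes in the mesh; nodes that do not lie on the corresponding contact surface get zero area.

Later, we use these arrays to extract the area associated with the contact nodes.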

line2 = element.Line2()
lower_line_mesh = Mesh(mesh.coords, lower_contact_elements)
lower_line_op = Operator(lower_line_mesh, line2)

upper_line_mesh = Mesh(mesh.coords, upper_contact_elements)
upper_line_op = Operator(upper_line_mesh, line2)


@jax.jit
def total_lower_area(u_flat: Array) -> float:
    u = u_flat.reshape(-1, 1)
    u_quad = lower_line_op.eval(u)
    return jnp.sum(lower_line_op.integrate(u_quad))

@jax.jit
def total_upper_area(u_flat: Array) -> float:
    u = u_flat.reshape(-1, 1)
    u_quad = upper_line_op.eval(u)
    return jnp.sum(upper_line_op.integrate(u_quad))


ones = jnp.ones(n_nodes)
upper_area_vector = jax.jacrev(total_upper_area)(ones)
lower_area_vector = jax.jacrev(total_lower_area)(ones)

Use the two-pass algorithm to compute the total contact energy.

\[\Psi_{\text{contact}} = \frac{1}{2} (\Psi_{A \to B} + \Psi_{B \to A})\]

contact_energy_upper = partial(
    compute_contact_energy,
    coords=mesh.coords,
    contactor_nodes=upper_contact_nodes,
    nodes_area=upper_area_vector,
    target_surface_elements=lower_contact_elements,
    reference_points=lower_reference_points,
)

contact_energy_lower = partial(
    compute_contact_energy,
    coords=mesh.coords,
    contactor_nodes=lower_contact_nodes,
    nodes_area=lower_area_vector,
    target_surface_elements=upper_contact_elements,
    reference_points=upper_reference_points,
)

Compute the total energy

Now let us define the total energy function.

\[ \Psi(\boldsymbol{u}) = \Psi_\text{e}(\boldsymbol{u}) + \dfrac{1}{2}\big(\Psi_{A \to B} + \Psi_{B \to A} \big) \]

Hint

The total energy function will look like this:


def total_energy(
    u_flat: Array,
    coords: Array,
    ...
) -> float:
    """Compute the total energy for a given displacement field.
    Args:
        u_flat: Flattened displacement field, shape (n_dofs,).
        coords: Initial coordinates of the nodes, shape (n_nodes, 2).
        ... Other arguments need to be passed to the compute_contact_energy function for each body.
    Returns:
        Total energy, a scalar
    """
    # TODO: reshape the displacement field
    ...

    # TODO compute the total contact energy using two-pass algorithm
    contact_energy = ...

    # TODO compute the elastic energy
    strain_energy = ...

    # TODO Sum the contact and strain energies to get the total energy of the system.
    return ...
@jax.jit
def _total_energy(
    u_flat: Array,
    coords: Array,
) -> Array:
    """Compute the total energy for a given displacement field.
    Args:
        u_flat: Flattened displacement field, shape (n_dofs,).
        coords: Coordinates of the nodes (not used directly; the contact
            energy functions already have mesh.coords bound via partial).
    Returns:
        Total energy, a scalar.
    """
    u = u_flat.reshape(-1, n_dofs_per_node)

    upper_contact_energy = contact_energy_upper(u)
    lower_contact_energy = contact_energy_lower(u)

    contact_energy = 0.5 * (upper_contact_energy + lower_contact_energy)

    strain_energy = total_strain_energy(u)
    return strain_energy + contact_energy


total_energy = partial(
    _total_energy,
    coords=mesh.coords,
)

Applying Dirichlet boundary conditions

We push the upper deformable body into the lower deformable body by applying a displacement in the \(y\) direction on the top edge of the upper body. To avoid rigid body motion, we fix the displacement in the \(x\) direction on the top edge of the upper body. The lower body is fixed on the bottom edge in both the \(x\) and \(y\) directions.

The boundary conditions thus applied are:

\[ \begin{aligned} u_x &= 0 \quad \text{on the top edge of the upper body} \\ u_y &= -1.0 \quad \text{on the top edge of the upper body} \\ u_x &= 0 \quad \text{on the bottom edge of the lower body} \\ u_y &= 0 \quad \text{on the bottom edge of the lower body} \\ \end{aligned} \]


y_max = jnp.max(mesh.coords[:, 1])
y_min = jnp.min(mesh.coords[:, 1])
x_min = jnp.min(mesh.coords[:, 0])
x_max = jnp.max(mesh.coords[:, 0])
height = y_max - y_min


upper_nodes = jnp.where(jnp.isclose(mesh.coords[:, 1], y_max))[0]
lower_nodes = jnp.where(jnp.isclose(mesh.coords[:, 1], y_min))[0]
left_nodes = jnp.where(jnp.isclose(mesh.coords[:, 0], x_min))[0]
right_nodes = jnp.where(jnp.isclose(mesh.coords[:, 0], x_max))[0]

upper_block_nodes = jnp.unique(upper_elements.flatten())

lower_block_nodes = jnp.unique(lower_elements.flatten())

fixed_dofs = jnp.concatenate(
    [
        2 * upper_nodes,
        2 * upper_nodes + 1,
        2 * lower_nodes,
        2 * lower_nodes + 1,
    ]
)


applied_disp = 1.0

prescribed_values = jnp.zeros(n_dofs)
prescribed_values = prescribed_values.at[2 * upper_nodes + 1].set(-applied_disp)  # u_y on the top edge

free_dofs = jnp.setdiff1d(jnp.arange(n_dofs), fixed_dofs)

Sparsity Pattern for two deformable bodies

Since the contact pairs (a contactor node and its closest target surface element) can change in every iteration, we need a more flexible solver than the one used in the last exercise. Below, we present the sparsity pattern of the tangent stiffness matrix for the above problem at three instants during loading. The off-diagonal entries are due to the contact coupling between the two bodies. Initially, there is no contact and the sparsity pattern is the same as for the linear elastic problem. Once contact starts, the contact terms enter the stiffness matrix, and the pattern keeps changing as the set of nodes in contact and their target elements change.

Figure 13.9: The sparsity pattern due to contact constraints at different times during loading. Left: initial configuration. Middle: at 80% of the applied displacement. Right: at 100% of the applied displacement.
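The changing pattern follows from the structure of the gap. For an active pair, \(g_n\) depends only on the displacements of the contactor node and of the two end nodes of its target segment, so the penalty term couples exactly those DOFs. The NumPy sketch below freezes the projection parameter \(\xi\) and the normal \(\boldsymbol{n}\) (a first-order simplification; the full tangent also differentiates \(\xi\) and \(\boldsymbol{n}\), but those terms involve the same three nodes) and builds the rank-one penalty stiffness for one hypothetical pair. The function pair_stiffness is illustrative, not part of tatva.

```python
import numpy as np


def pair_stiffness(n_nodes, node, segment, xi, normal, k_pen):
    """Penalty stiffness k * dg dg^T of one node-to-segment pair, with xi and n held fixed.

    g_n = n . (r - (1 - xi) x_i - xi x_j), so dg/du is nonzero only on the
    DOFs of the contactor node and of the segment end nodes i and j.
    """
    i, j = segment
    dg = np.zeros(2 * n_nodes)
    dg[2 * node:2 * node + 2] += normal
    dg[2 * i:2 * i + 2] -= (1.0 - xi) * normal
    dg[2 * j:2 * j + 2] -= xi * normal
    return k_pen * np.outer(dg, dg)


# hypothetical pair: node 1 of body A against segment (4, 5) of body B
K_c = pair_stiffness(n_nodes=6, node=1, segment=(4, 5), xi=0.5,
                     normal=np.array([0.6, 0.8]), k_pen=100.0)
coupled_dofs = np.flatnonzero(np.any(K_c != 0.0, axis=1))
print(coupled_dofs)  # DOFs 2, 3, 8, 9, 10, 11: x and y of nodes 1, 4 and 5
```

When the contactor node switches to another target segment, a different set of DOFs is coupled, which is why the pattern in Figure 13.9 changes during loading.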

Instead of rebuilding the sparsity pattern again and again, we can use a matrix-free solver (see Matrix-free solvers) for the linear system. In a matrix-free solver, the stiffness matrix is never assembled explicitly: all we need is its action on a vector, which we obtain as a Jacobian-vector product of the gradient of the total energy. Because this tangent is symmetric and, once the Dirichlet conditions are applied, typically positive definite, the conjugate gradient method can then be used to solve the linear system.
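The key identity is that the Jacobian of the gradient of the energy is the tangent stiffness, so a forward-mode derivative of the gradient in a direction \(\boldsymbol{v}\) returns \(\mathbf{K}\boldsymbol{v}\) without ever forming \(\mathbf{K}\). A small JAX sketch on a hypothetical quadratic energy \(\frac{1}{2}\boldsymbol{u}^T\mathbf{K}\boldsymbol{u} - \boldsymbol{f}^T\boldsymbol{u}\), where the product can be checked against the explicit matrix:

```python
import jax
import jax.numpy as jnp

# hypothetical symmetric positive definite stiffness and load
K = jnp.array([[4.0, -1.0, 0.0],
               [-1.0, 4.0, -1.0],
               [0.0, -1.0, 3.0]])
f = jnp.array([1.0, 2.0, 0.5])


def energy(u):
    return 0.5 * u @ K @ u - f @ u


gradient = jax.grad(energy)  # K u - f


def tangent_times(v, u):
    """Matrix-free product K(u) v via a Jacobian-vector product of the gradient."""
    return jax.jvp(gradient, (u,), (v,))[1]


u0 = jnp.zeros(3)
v = jnp.array([1.0, 0.0, -2.0])
print(tangent_times(v, u0))  # same as the explicit product below
print(K @ v)
```

For the contact problem the energy is not quadratic, so the product depends on the current state \(\boldsymbol{u}\); this is why compute_tangent below takes u_prev as an argument.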

Note

You might be wondering why we were building the stiffness matrix in the previous exercises. The reason is that all of those exercises formed a KKT system of equations. Recall that the KKT system of equations has the form:

\[ \begin{bmatrix} \mathbf{K} & \mathbf{B}^T \\ \mathbf{B} & \mathbf{0} \end{bmatrix} \begin{bmatrix} \boldsymbol{u} \\ \boldsymbol{\lambda} \end{bmatrix} = \begin{bmatrix} \boldsymbol{f} \\ \boldsymbol{0} \end{bmatrix} \]

Such systems are not positive definite (they are indefinite because of the \(\mathbf{0}\) block), and hence we used a direct solver for the linear system. The conjugate gradient method, on the other hand, requires a symmetric positive definite matrix, which is why we never used matrix-free solvers there.

When a constrained problem is solved with the penalty method, no KKT system is formed. For example, in the case of contact, the system has the form:

\[ \begin{bmatrix} \mathbf{K}_\text{A} & \mathbf{K}_\text{AB} \\ \mathbf{K}_\text{AB}^T & \mathbf{K}_\text{B} \end{bmatrix} \begin{bmatrix} \boldsymbol{u}_\text{A} \\ \boldsymbol{u}_\text{B} \end{bmatrix} = \begin{bmatrix} \boldsymbol{f}_\text{A} \\ \boldsymbol{f}_\text{B} \end{bmatrix} \]

where the penalty contact terms contribute the coupling block \(\mathbf{K}_\text{AB}\) (as well as additions to \(\mathbf{K}_\text{A}\) and \(\mathbf{K}_\text{B}\)). Notice that there is no \(\mathbf{0}\) block in this form.
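The difference is easy to check numerically. For a hypothetical two-DOF spring system with the tying constraint \(u_1 = u_2\), i.e. \(\mathbf{B} = [1, -1]\), the KKT matrix has a negative eigenvalue, whereas the penalty formulation, which adds \(k\,\mathbf{B}^T\mathbf{B}\) to \(\mathbf{K}\) instead of appending a multiplier row, stays symmetric positive definite:

```python
import numpy as np

K = np.array([[2.0, -1.0],
              [-1.0, 2.0]])  # SPD stiffness
B = np.array([[1.0, -1.0]])  # constraint u1 - u2 = 0
k_pen = 100.0

# KKT (Lagrange multiplier) matrix: zero block in the lower right
kkt = np.block([[K, B.T],
                [B, np.zeros((1, 1))]])

# penalty matrix: the constraint enters as k * B^T B
penalty = K + k_pen * B.T @ B

print(np.linalg.eigvalsh(kkt))      # one eigenvalue is negative: indefinite
print(np.linalg.eigvalsh(penalty))  # eigenvalues 1 and 203: positive definite
```

This is why CG, which relies on positive definiteness, suits the penalty formulation but not the KKT system; a symmetric indefinite system calls for a direct solver or a method such as MINRES.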

Below, we define the functions to compute the gradient and the Jacobian-vector product. Within the Jacobian-vector product we apply the lifting approach: we first zero the entries of the increment du at the fixed DOFs and then zero the entries of the resulting product at the fixed DOFs.

# gradient of the total energy, i.e. the internal force vector
gradient = jax.jacrev(total_energy)


# tangent stiffness applied to du, computed as a Jacobian-vector product of the gradient
@jax.jit
def compute_tangent(du, u_prev):
    du_projected = du.at[fixed_dofs].set(0.0)
    tangent = jax.jvp(gradient, (u_prev,), (du_projected,))[1]
    tangent = tangent.at[fixed_dofs].set(0)
    return tangent
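Why zeroing the fixed entries on both sides is enough: if the right-hand side is zero at the fixed DOFs and CG starts from zero, every search direction stays zero there, so CG effectively solves the reduced system on the free DOFs. A NumPy sketch with a hypothetical 3×3 SPD matrix and DOF 0 fixed; projected_operator and the small cg routine are illustrative stand-ins for compute_tangent and conjugate_gradient.

```python
import numpy as np

K = np.array([[4.0, -1.0, 0.0],
              [-1.0, 4.0, -1.0],
              [0.0, -1.0, 3.0]])
fixed = np.array([0])
free = np.setdiff1d(np.arange(3), fixed)


def projected_operator(v):
    """Apply K with the fixed DOFs zeroed in the input and in the output."""
    v = v.copy()
    v[fixed] = 0.0
    Kv = K @ v
    Kv[fixed] = 0.0
    return Kv


def cg(apply_A, b, tol=1e-12, max_iter=50):
    """Plain conjugate gradient for a symmetric positive (semi)definite operator."""
    x = np.zeros_like(b)
    r = b - apply_A(x)
    p = r.copy()
    rs = r @ r
    for _ in range(max_iter):
        if np.sqrt(rs) < tol:
            break
        Ap = apply_A(p)
        alpha = rs / (p @ Ap)
        x += alpha * p
        r -= alpha * Ap
        rs_new = r @ r
        p = r + (rs_new / rs) * p
        rs = rs_new
    return x


residual = np.array([5.0, 1.0, 2.0])
residual[fixed] = 0.0  # lifting: no correction at the fixed DOFs

du = cg(projected_operator, residual)
du_reduced = np.linalg.solve(K[np.ix_(free, free)], residual[free])
print(du)          # du[0] stays exactly 0
print(du_reduced)  # matches du[1:]
```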

Below we define the conjugate gradient method and the Newton-Raphson method to solve the nonlinear contact problem.

Code: Functions to implement the conjugate gradient method and Newton-Raphson method
def conjugate_gradient(A, b, atol=1e-8, max_iter=100):
    """Solve A x = b for a symmetric positive definite operator A (matrix-free).

    A is a function returning the product A @ p; the loop stops when the
    residual norm drops below atol or after max_iter iterations.
    """

    def body_fun(state):
        p, r, rsold, x, iiter = state
        Ap = A(p)
        alpha = rsold / jnp.vdot(p, Ap)  # step length along p
        x = x + alpha * p  # update the solution
        r = r - alpha * Ap  # update the residual b - A x
        rsnew = jnp.vdot(r, r)
        p = r + (rsnew / rsold) * p  # new A-conjugate search direction
        return (p, r, rsnew, x, iiter + 1)

    def cond_fun(state):
        _, _, rsold, _, iiter = state
        return jnp.logical_and(jnp.sqrt(rsold) > atol, iiter < max_iter)

    x = jnp.zeros_like(b)
    r = b - A(x)
    p = r
    rsold = jnp.vdot(r, r)

    _, _, _, x, iiter = jax.lax.while_loop(cond_fun, body_fun, (p, r, rsold, x, 0))
    return x, iiter


def newton_krylov_solver(
    u,
    fext,
    gradient,
    compute_tangent,
    fixed_dofs,
    tol=1e-8,
    max_iter=10,
):
    """Newton-Raphson iterations with each linear system solved matrix-free by CG."""
    fint = gradient(u)
    residual = fext - fint
    residual = residual.at[fixed_dofs].set(0)  # lifting: no residual at fixed DOFs
    norm_res = jnp.linalg.norm(residual)

    iiter = 0
    while norm_res > tol and iiter < max_iter:
        # tangent operator K(u) v, evaluated matrix-free at the current state
        A = partial(compute_tangent, u_prev=u)
        du, cg_iiter = conjugate_gradient(A=A, b=residual, atol=1e-8, max_iter=100)

        u = u + du
        fint = gradient(u)
        residual = fext - fint
        residual = residual.at[fixed_dofs].set(0)
        norm_res = jnp.linalg.norm(residual)
        print(f"  Residual: {norm_res:.2e} (CG iterations: {int(cg_iiter)})")
        iiter += 1

    return u, norm_res

Implementing the Incremental Loading Loop

Now that we have the mesh, the energy functions, and the newton_krylov_solver, it's time to set up the main simulation loop. We apply the displacement in small increments and solve for equilibrium at each step. This is important because the penalty contact term makes the problem nonlinear (the contact forces switch on only once the gap closes), and Newton's method converges reliably only from a sufficiently good initial guess.
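The same structure, incremental loading with a Newton solve per step that is warm-started from the previous step, can be seen on a hypothetical one-DOF model: a node attached by a unit spring to a support that is moved down to d = -1 in 20 steps, with a rigid wall at y = -0.5 enforced by a penalty term \(\frac{1}{2} k \langle -g \rangle^2\), \(g = u + 0.5\). All values are illustrative.

```python
k_spring, k_pen = 1.0, 100.0  # hypothetical spring and penalty stiffnesses
wall = -0.5                   # position of the rigid wall
total_disp, n_steps = -1.0, 20


def residual(u, d):
    """dPi/du for Pi = 0.5 k_s (u - d)^2 + 0.5 k_p <-(u - wall)>^2."""
    gap = u - wall
    return k_spring * (u - d) + k_pen * min(gap, 0.0)


def tangent(u):
    """d(residual)/du: the penalty stiffness is active only while the gap is negative."""
    return k_spring + (k_pen if u - wall < 0.0 else 0.0)


u = 0.0
history = []
for step in range(n_steps):
    d = total_disp * (step + 1) / n_steps  # prescribed support position
    for _ in range(20):                    # Newton, warm-started from the last step
        r = residual(u, d)
        if abs(r) < 1e-10:
            break
        u -= r / tangent(u)
    history.append(u)

print(history[4])   # -0.25: no contact yet
print(history[-1])  # -51/101, about -0.505: slight penetration, as the penalty allows
```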

Follow these steps to build the main part of your script.

Hint

Initialization

Before the loop, you need to initialize all the variables for the simulation.

  • Create the initial solution vector u_prev, which contains both displacements. Initialize it as a vector of zeros with the correct total size (n_dofs).
  • Initialize the external force vector fext as a vector of zeros.
  • Set the number of load increments, n_steps (e.g., 20). We apply the total prescribed displacement (applied_disp) in n_steps equal increments.
  • Calculate the displacement to apply at each step.

The Main Loading Loop

Create a for loop that iterates through each load step.

for step in range(n_steps):
    print(f"Step {step+1}/{n_steps}")
    # ... Your code for each step will go inside this loop ...

Inside the Loop: Apply Displacement and Prepare Solver

At the beginning of each loop iteration, you need to update the boundary conditions and prepare the functions for the solver.

  • Update the u_prev vector with the prescribed displacement for the current step.

Inside the Loop: Solve for Equilibrium

  • Call the newton_krylov_solver function with the initial guess and other arguments. Depending on what inputs your gradient and compute_tangent functions require, you may need to pass additional arguments to the newton_krylov_solver function. You will also have to pass the fixed_dofs to the newton_krylov_solver function to apply the lifting method.
  • Store the new solution in u_new.
  • Update the state for the next iteration by setting u_prev = u_new. This provides the solution from the previous step as the initial guess for the next step, which helps the solver converge faster.

After the Loop: Final Solution

Once the loop is complete, extract the final displacement field from the final u_prev vector for visualization.

u_prev = jnp.zeros(n_dofs)
fext = jnp.zeros(n_dofs)

n_steps = 20

applied_displacement = prescribed_values / n_steps  # displacement increment

u_per_step = []
for step in range(n_steps):
    print(f"Step {step+1}/{n_steps}")
    u_prev = u_prev.at[fixed_dofs].add(applied_displacement[fixed_dofs])

    u_new, rnorm = newton_krylov_solver(
        u_prev,
        fext,
        gradient,
        compute_tangent,
        fixed_dofs,
    )

    u_prev = u_new
    u_per_step.append(u_prev.reshape(n_nodes, n_dofs_per_node))

u_solution = u_prev.reshape(n_nodes, n_dofs_per_node)
Plotting the stress field
# squeeze to remove the quad point dimension (only 1 quad point)
grad_u = op.grad(u_solution).squeeze()
strains = compute_strain(grad_u)
stresses = compute_stress(strains, mat.mu, mat.lmbda)

plt.style.use(STYLE_PATH)
fig, axs = plt.subplots(1, 2,figsize=(6, 4), layout="constrained")

plot_element_values(
    u=u_per_step[-1],
    mesh=mesh,
    values=stresses[:, 1, 1].flatten(),
    ax=axs[0],
    label=r"$\sigma_{yy}$",
)

axs[0].axhline(y=0, color="black", linewidth=0.5)
axs[0].set_xlabel("$x$")
axs[0].set_ylabel("$y$")
axs[0].set_aspect("equal")
axs[0].margins(0.0, 0.0)

plot_nodal_values(
    u=u_per_step[-1],
    mesh=mesh,
    nodal_values=u_solution[:,  1].flatten(),
    ax=axs[1],
    label=r"$u_y$",
    shading="flat",
    edgecolors="black",
)

axs[1].axhline(y=0, color="black", linewidth=0.5)
axs[1].set_xlabel("$x$")
axs[1].set_ylabel("$y$")
axs[1].set_aspect("equal")
axs[1].margins(0.0, 0.0)


plt.show()
Figure 13.10: Left: stress field \(\sigma_{yy}\) after the contact is resolved. Right: vertical displacement \(u_y\). The undeformed configuration is shown in the background.