import jax

jax.config.update("jax_enable_x64", True)  # use double-precision
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)
jax.config.update("jax_platforms", "cpu")

from functools import partial
from typing import Callable, NamedTuple, Optional, Tuple

import equinox as eqx
import jax.numpy as jnp
import numpy as np
from jax import Array
from jax_autovmap import autovmap
from tatva import Mesh, Operator, element, sparse


class Material(NamedTuple):
    """Isotropic linear-elastic material described by its Lamé parameters."""

    mu: float  # shear modulus (second Lamé parameter)
    lmbda: float  # first Lamé parameter


class CohesiveMaterial(NamedTuple):
    """Parameters of the exponential cohesive law used on the interface."""

    Gamma: float  # fracture energy
    sigma_c: float  # critical (peak) traction
    penalty: float  # stiffness of the quadratic branch used for tiny openings


def generate_mesh_with_line_elements(
    nx: int,
    ny: int,
    lxs: Tuple[float, float],
    lys: Tuple[float, float],
    curve_func: Optional[Callable[[Array, float], bool]] = None,
    tol: float = 1e-6,
) -> Tuple[Array, Array, Optional[Array]]:
    """
    Generates a 2D triangular mesh for a rectangle and optionally extracts
    1D line elements along a specified curve.

    Node (i, j) (i-th x-station, j-th y-station) has global id
    ``i * (ny + 1) + j``. Each cell is split into the two triangles
    ``[n0, n1, n2]`` and ``[n1, n3, n2]``, emitted cell by cell.

    Args:
        nx: Number of elements along the x-direction.
        ny: Number of elements along the y-direction.
        lxs: Tuple of the x-coordinates of the left and right edges of the rectangle.
        lys: Tuple of the y-coordinates of the bottom and top edges of the rectangle.
        curve_func: An optional callable that takes a coordinate array [x, y]
            and a tolerance, returning True if the point is on the curve.
        tol: Tolerance for floating-point comparisons.

    Returns:
        A tuple containing:
        - coords (jnp.ndarray): Nodal coordinates, shape (num_nodes, 2).
        - elements_2d (jnp.ndarray): 2D triangular element connectivity,
          shape (2 * nx * ny, 3).
        - elements_1d (jnp.ndarray | None): 1D line element connectivity,
          shape (num_edges, 2) with the smaller node id first and rows sorted
          lexicographically; shape (0, 2) if no edge lies on the curve;
          None if ``curve_func`` is None.
    """
    x = jnp.linspace(lxs[0], lxs[1], nx + 1)
    y = jnp.linspace(lys[0], lys[1], ny + 1)
    xv, yv = jnp.meshgrid(x, y, indexing="ij")
    coords = jnp.stack([xv.ravel(), yv.ravel()], axis=-1)

    # Corner node ids of every cell, in the same i-major / j-minor order as
    # a nested `for i: for j:` loop.
    ii, jj = np.meshgrid(np.arange(nx), np.arange(ny), indexing="ij")
    n0 = (ii * (ny + 1) + jj).ravel()  # (i,   j)
    n1 = n0 + (ny + 1)  # (i+1, j)
    n2 = n0 + 1  # (i,   j+1)
    n3 = n1 + 1  # (i+1, j+1)
    first_tris = np.stack([n0, n1, n2], axis=-1)
    second_tris = np.stack([n1, n3, n2], axis=-1)
    # Interleave so each cell contributes its two triangles consecutively.
    elements_2d = jnp.asarray(
        np.stack([first_tris, second_tris], axis=1).reshape(-1, 3)
    )

    # Extract 1D elements if a curve function is provided.
    if curve_func is None:
        return coords, elements_2d, None

    # Evaluate the curve predicate for all nodes at once.
    on_curve_mask = np.asarray(
        jax.vmap(lambda c: curve_func(c, tol))(coords), dtype=bool
    )

    # All three edges of every triangle: (0, 1), (1, 2), (2, 0).
    tris = np.asarray(elements_2d)
    edges = np.concatenate(
        [tris[:, [0, 1]], tris[:, [1, 2]], tris[:, [2, 0]]], axis=0
    )
    # An edge lies on the curve when both of its end nodes do.
    on_curve = on_curve_mask[edges[:, 0]] & on_curve_mask[edges[:, 1]]
    # Canonical representation (smaller id first) so shared edges deduplicate.
    curve_edges = np.sort(edges[on_curve], axis=1)

    if curve_edges.shape[0] == 0:
        return coords, elements_2d, jnp.zeros((0, 2), dtype=elements_2d.dtype)

    return coords, elements_2d, jnp.unique(jnp.asarray(curve_edges), axis=0)


def extract_interface_mesh(mesh, line_elements):
    """Build a standalone 1D mesh from line elements of ``mesh``.

    Args:
        mesh: Parent mesh whose ``coords`` the line elements index into.
        line_elements: (n_edges, 2) array of global node ids of ``mesh``.

    Returns:
        A ``Mesh`` whose nodes are the unique nodes of ``line_elements``
        (ordered by ascending global id) and whose connectivity is
        ``line_elements`` renumbered to those local node indices.
    """
    line_element_nodes = jnp.unique(line_elements.flatten())
    interface_coords = mesh.coords[line_element_nodes]
    # Global id -> local id: line_element_nodes is sorted, so a binary search
    # returns each node's position in it.
    interface_elements = jnp.searchsorted(line_element_nodes, line_elements).astype(
        line_elements.dtype
    )
    return Mesh(interface_coords, interface_elements)


def get_cohesive_nodes(mesh, bottom_interface_elements, top_interface_elements):
    """Return the node ids of both interface sides, each sorted by x-coordinate.

    Entries at the same position in the two returned arrays are meant to be
    paired across the interface. NOTE(review): this pairing assumes both sides
    have nodes at the same x-coordinates; confirm for non-matching meshes.
    """
    bottom_cohesive_nodes = jnp.unique(bottom_interface_elements.flatten())
    top_cohesive_nodes = jnp.unique(top_interface_elements.flatten())
    bottom_cohesive_coords = mesh.coords[bottom_cohesive_nodes]
    top_cohesive_coords = mesh.coords[top_cohesive_nodes]
    bottom_cohesive_nodes = bottom_cohesive_nodes[
        jnp.argsort(bottom_cohesive_coords[:, 0])
    ]
    top_cohesive_nodes = top_cohesive_nodes[jnp.argsort(top_cohesive_coords[:, 0])]
    return bottom_cohesive_nodes, top_cohesive_nodes


@autovmap(grad_u=2)
def compute_strain(grad_u):
    """Small-strain tensor: symmetric part of the displacement gradient."""
    return 0.5 * (grad_u + grad_u.T)


@autovmap(eps=2, mu=0, lmbda=0)
def compute_stress(eps, mu, lmbda):
    """Isotropic linear-elastic stress 2*mu*eps + lmbda*tr(eps)*I for 2x2 tensors."""
    I = jnp.eye(2)
    return 2 * mu * eps + lmbda * jnp.trace(eps) * I


@autovmap(grad_u=2, mu=0, lmbda=0)
def strain_energy(grad_u, mu, lmbda):
    """Strain-energy density 0.5 * sigma : eps of the displacement gradient."""
    eps = compute_strain(grad_u)
    sigma = compute_stress(eps, mu, lmbda)
    return 0.5 * jnp.einsum("ij,ij->", sigma, eps)


@jax.jit
def safe_sqrt(x):
    """Square root that returns 0 for non-positive inputs.

    Clamping *inside* the sqrt means the reverse-mode gradient at x <= 0 is
    selected as 0 by ``where`` instead of propagating sqrt's infinite slope.
    """
    return jnp.sqrt(jnp.where(x > 0.0, x, 0.0))


@autovmap(jump=1)
def compute_opening(jump: Array) -> float:
    """
    Compute the opening of the cohesive element.

    Args:
        jump: The jump in the displacement field (two components).

    Returns:
        The Euclidean norm of the jump (always >= 0).
    """
    opening = safe_sqrt(jump[0] ** 2 + jump[1] ** 2)
    return opening


@autovmap(jump=1)
def exponential_cohesive_energy(
    jump: Array,
    Gamma: float,
    sigma_c: float,
    penalty: float,
    delta_threshold: float = 1e-8,
) -> float:
    """
    Compute the cohesive energy for a given jump.

    With opening delta = |jump|, for delta > delta_threshold the exponential law
    Gamma * (1 - (1 + delta/delta_c) * exp(-delta/delta_c)) is used, where
    delta_c = Gamma * exp(-1) / sigma_c is chosen so that the peak traction
    dE/d(delta) equals sigma_c (reached at delta = delta_c). At or below the
    threshold the quadratic energy 0.5 * penalty * delta**2 is used instead.

    Args:
        jump: The jump in the displacement field.
        Gamma: Fracture energy of the material.
        sigma_c: The critical strength of the material.
        penalty: Stiffness of the quadratic branch used for delta <= delta_threshold.
        delta_threshold: Opening below which the quadratic branch is used.

    Returns:
        The cohesive energy.
    """
    delta = compute_opening(jump)
    delta_c = (Gamma * jnp.exp(-1)) / sigma_c

    def true_fun(delta):
        return Gamma * (1 - (1 + (delta / delta_c)) * (jnp.exp(-delta / delta_c)))

    def false_fun(delta):
        return 0.5 * penalty * delta**2

    return jax.lax.cond(delta > delta_threshold, true_fun, false_fun, delta)


@eqx.filter_jit
def conjugate_gradient(A, b, atol=1e-8, max_iter=100):
    """Solve ``A x = b`` with unpreconditioned conjugate gradients from x = 0.

    Args:
        A: Callable returning the matrix-vector product ``A @ p``; assumed
            symmetric positive (semi-)definite — not checked.
        b: Right-hand side vector.
        atol: Stop once the 2-norm of the CG residual is <= atol.
        max_iter: Maximum number of CG iterations.

    Returns:
        ``(x, n_iter)``: the approximate solution and the iterations performed.
    """

    def cond_fun(state):
        _, _, rs, _, it = state
        return jnp.logical_and(jnp.sqrt(rs) > atol, it < max_iter)

    def body_fun(state):
        p, r, rs, x, it = state
        Ap = A(p)
        alpha = rs / jnp.vdot(p, Ap)
        x = x + alpha * p
        r = r - alpha * Ap
        rs_new = jnp.vdot(r, r)
        p = r + (rs_new / rs) * p
        return p, r, rs_new, x, it + 1

    x0 = jnp.zeros_like(b)
    r0 = b - A(x0)
    init_state = (r0, r0, jnp.vdot(r0, r0), x0, 0)
    _, _, _, x, n_iter = jax.lax.while_loop(cond_fun, body_fun, init_state)
    return x, n_iter


def solve_quasistatic(applied_strain_percent: float):
    """Run a displacement-controlled quasi-static simulation of a cohesive crack.

    Two triangulated blocks (upper: y in [0, Ly], lower: y in [-Ly, 0]) with
    duplicated nodes on y = 0 are bonded by an exponential cohesive law where
    x > crack_length. For x <= crack_length the two faces share no energy term
    (pre-crack). Top and bottom edges are fully clamped and pulled apart in
    ``n_steps`` equal increments, the x-displacement of the left edge is
    fixed, and 50 additional steps are solved at the final load. Each step is
    solved with Newton's method using CG on Hessian-vector products.

    Args:
        applied_strain_percent: Scaling factor for the imposed separation; the
            total separation between top and bottom edges is
            ``applied_strain_percent * prestrain * height`` with
            ``prestrain = 0.1``.

    Returns:
        ``(u_solution, energies, force_on_top, displacement_on_top)``:
        ``u_solution`` is the final nodal displacement, shape (n_nodes, 2);
        ``energies`` is a dict with lists ``"elastic"`` and ``"cohesive"``;
        ``force_on_top`` is the summed y-component of the internal force on the
        top edge and ``displacement_on_top`` the mean y-displacement of the top
        edge. Every list starts with the undeformed state, followed by one
        entry per step.
    """
    prestrain = 0.1
    nu = 0.35
    E = 106e3  # N/m^2
    lmbda = nu * E / ((1 + nu) * (1 - 2 * nu))
    mu = E / (2 * (1 + nu))

    Gamma = 15  # J/m^2
    sigma_c = 20e3  # N/m^2

    sigma_inf = prestrain * E / (1 - nu**2)
    # Characteristic length used below to size the domain and the pre-crack.
    L_G = 2 * mu * Gamma / (jnp.pi * (1 - nu) * sigma_inf**2)

    Nx = 100  # number of elements in x
    Ny = 20  # number of elements in y (per block)
    Lx = 20 * L_G  # length in x
    Ly = 4 * L_G  # height of each block in y
    crack_length = 1.0 * L_G

    def cohesive_line(coord: Array, tol: float) -> bool:
        """True for nodes on y = 0 beyond the pre-crack tip (x > crack_length)."""
        return jnp.logical_and(
            jnp.isclose(coord[1], 0.0, atol=tol), coord[0] > crack_length
        )

    upper_coords, upper_elements_2d, top_interface_elements = (
        generate_mesh_with_line_elements(
            nx=Nx, ny=Ny, lxs=(0, Lx), lys=(0, Ly), curve_func=cohesive_line
        )
    )
    lower_coords, lower_elements_2d, bottom_interface_elements = (
        generate_mesh_with_line_elements(
            nx=Nx, ny=Ny, lxs=(0, Lx), lys=(-Ly, 0), curve_func=cohesive_line
        )
    )

    # Stack both blocks into one mesh; lower-block node ids are shifted past
    # the upper block's nodes, so nodes on y = 0 exist once per block.
    node_offset = upper_coords.shape[0]
    coords = jnp.vstack((upper_coords, lower_coords))
    elements = jnp.vstack((upper_elements_2d, lower_elements_2d + node_offset))
    mesh = Mesh(coords, elements)
    bottom_interface_elements = bottom_interface_elements + node_offset

    n_nodes = mesh.coords.shape[0]
    n_dofs_per_node = 2
    n_dofs = n_dofs_per_node * n_nodes

    mat = Material(mu=mu, lmbda=lmbda)
    op = Operator(mesh, element.Tri3())

    @jax.jit
    def total_strain_energy(u_flat):
        """Elastic energy of the bulk for a flat (or nodal) displacement vector."""
        u = u_flat.reshape(-1, n_dofs_per_node)
        energy_density = strain_energy(op.grad(u), mat.mu, mat.lmbda)
        return op.integrate(energy_density)

    interface_mesh = extract_interface_mesh(mesh, bottom_interface_elements)
    line_op = Operator(interface_mesh, element.Line2())
    cohesive_mat = CohesiveMaterial(Gamma=Gamma, sigma_c=sigma_c, penalty=1e3)
    bottom_cohesive_nodes, top_cohesive_nodes = get_cohesive_nodes(
        mesh, bottom_interface_elements, top_interface_elements
    )

    @jax.jit
    def total_cohesive_energy(u_flat: Array) -> float:
        """Cohesive energy integrated over the interface mesh."""
        u = u_flat.reshape(-1, n_dofs_per_node)
        # Jump across the interface, nodes paired by x-coordinate.
        # NOTE(review): nodal values are in get_cohesive_nodes' x-sorted order,
        # while interface_mesh orders nodes by global id; both orders coincide
        # for this structured mesh (ids along y = 0 increase with x).
        jump = u[top_cohesive_nodes] - u[bottom_cohesive_nodes]
        jump_quad = line_op.eval(jump)
        cohesive_energy_density = exponential_cohesive_energy(
            jump_quad, cohesive_mat.Gamma, cohesive_mat.sigma_c, cohesive_mat.penalty
        )
        return line_op.integrate(cohesive_energy_density)

    @jax.jit
    def total_energy(u_flat: Array) -> float:
        """Total potential energy: bulk elastic + interface cohesive."""
        u = u_flat.reshape(-1, n_dofs_per_node)
        return total_strain_energy(u) + total_cohesive_energy(u)

    y_max = jnp.max(mesh.coords[:, 1])
    y_min = jnp.min(mesh.coords[:, 1])
    x_min = jnp.min(mesh.coords[:, 0])
    height = y_max - y_min

    upper_nodes = jnp.where(jnp.isclose(mesh.coords[:, 1], y_max))[0]
    lower_nodes = jnp.where(jnp.isclose(mesh.coords[:, 1], y_min))[0]
    left_nodes = jnp.where(jnp.isclose(mesh.coords[:, 0], x_min))[0]

    # Top/bottom edges: both components fixed; left edge: x-component only.
    # jnp.unique drops the corner x-dofs listed twice (left edge and top/bottom),
    # so the scatter-add of load increments below touches each dof once.
    fixed_dofs = jnp.unique(
        jnp.concatenate(
            [
                2 * upper_nodes,
                2 * upper_nodes + 1,
                2 * lower_nodes,
                2 * lower_nodes + 1,
                2 * left_nodes,
            ]
        )
    )

    applied_disp = applied_strain_percent * prestrain * height
    # Top edge moves up by applied_disp / 2, bottom edge down by the same;
    # every other fixed dof (x-components) is held at zero.
    prescribed_values = (
        jnp.zeros(n_dofs)
        .at[2 * upper_nodes + 1]
        .set(applied_disp / 2.0)
        .at[2 * lower_nodes + 1]
        .set(-applied_disp / 2.0)
    )

    # Internal force vector = gradient of the scalar total energy.
    gradient = jax.grad(total_energy)

    @jax.jit
    def compute_tangent(du, u_prev):
        """Hessian-vector product of total_energy at u_prev with fixed dofs projected out."""
        du_projected = du.at[fixed_dofs].set(0)
        tangent = jax.jvp(gradient, (u_prev,), (du_projected,))[1]
        return tangent.at[fixed_dofs].set(0)

    def newton_krylov_solver(u, fext, gradient, fixed_dofs, tol=1e-8, max_iter=80):
        """Solve gradient(u) = fext on the free dofs with Newton + CG.

        Fixed dofs keep the values they have in ``u`` on entry, because both
        the residual and the tangent action are zeroed there.

        Returns:
            ``(u, norm_res)``: the updated displacement and the final residual norm.
        """
        fint = gradient(u)
        norm_res = float("inf")
        iiter = 0
        while norm_res > tol and iiter < max_iter:
            residual = (fext - fint).at[fixed_dofs].set(0)
            # Matrix-free tangent at the current iterate. eqx.Partial exposes u
            # as a pytree leaf, so the filter-jitted CG traces it as an array.
            A = eqx.Partial(compute_tangent, u_prev=u)
            du, _ = conjugate_gradient(A=A, b=residual, atol=1e-8, max_iter=100)
            u = u + du
            fint = gradient(u)
            residual = (fext - fint).at[fixed_dofs].set(0)
            norm_res = jnp.linalg.norm(residual)
            print(f" Residual: {norm_res:.2e}")
            iiter += 1
        if norm_res > tol:
            print(
                f" Warning: Newton did not converge in {max_iter} iterations "
                f"(residual {norm_res:.2e})"
            )
        return u, norm_res

    u_prev = jnp.zeros(n_dofs)
    fext = jnp.zeros(n_dofs)
    n_steps = 100  # loading steps
    n_hold_steps = 50  # extra steps solved at the final load
    n_total_steps = n_steps + n_hold_steps

    force_on_top = [0]
    displacement_on_top = [0]
    energies = {"elastic": [], "cohesive": []}

    def record_energies(u_flat):
        """Append the elastic and cohesive energies of the state u_flat."""
        u_nodal = u_flat.reshape(n_nodes, n_dofs_per_node)
        energies["elastic"].append(total_strain_energy(u_nodal))
        energies["cohesive"].append(total_cohesive_energy(u_nodal))

    record_energies(u_prev)

    du_total = prescribed_values / n_steps  # per-step increment of the prescribed dofs

    for step in range(n_total_steps):
        print(f"Step {step+1}/{n_total_steps}")
        if step < n_steps:
            u_prev = u_prev.at[fixed_dofs].add(du_total[fixed_dofs])
        u_prev, _ = newton_krylov_solver(u_prev, fext, gradient, fixed_dofs)

        force_on_top.append(jnp.sum(gradient(u_prev)[2 * upper_nodes + 1]))
        displacement_on_top.append(jnp.mean(u_prev[2 * upper_nodes + 1]))
        record_energies(u_prev)

    u_solution = u_prev.reshape(n_nodes, n_dofs_per_node)
    return u_solution, energies, force_on_top, displacement_on_top