Source code for underworld3.meshing.faults

"""
Fault Surface Module for Underworld3.

This module provides classes for representing and manipulating fault surfaces
within 3D meshes. Faults are represented as triangulated 2D manifolds that can
be used for:

- Computing distance fields from mesh points to fault surfaces
- Transferring fault orientations (normals) to mesh variables
- Applying anisotropic rheology in fault zones

The workflow typically involves:
1. Creating FaultSurface objects from point clouds or VTK files
2. Triangulating point clouds using pyvista (optional dependency)
3. Collecting faults into a FaultCollection
4. Computing distance fields and transferring normals to mesh variables
5. Using the data with TransverseIsotropicFlowModel for fault-weakened rheology

Example:
    >>> # Load faults from VTK files
    >>> faults = uw.meshing.FaultCollection()
    >>> faults.add_from_vtk("fault1.vtk")
    >>> faults.add_from_vtk("fault2.vtk")
    >>>
    >>> # Compute distance field
    >>> fault_distance = faults.compute_distance_field(mesh)
    >>>
    >>> # Transfer normals to mesh
    >>> fault_normals = faults.transfer_normals(mesh)
    >>>
    >>> # Apply to rheology
    >>> stokes.constitutive_model.Parameters.director = fault_normals.sym

Notes:
    - pyvista is required for triangulation and distance computation
    - All pyvista operations run redundantly on each MPI rank
    - pyvista is also required to load and save VTK files (``from_vtk`` / ``to_vtk``)
"""

from __future__ import annotations

import os
from pathlib import Path
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

import numpy as np
import sympy

import underworld3 as uw
from underworld3 import mpi

if TYPE_CHECKING:
    from underworld3.discretisation import Mesh, MeshVariable


def _require_pyvista():
    """Check pyvista availability with helpful error message."""
    try:
        import pyvista

        return pyvista
    except ImportError:
        raise ImportError(
            "Fault triangulation and distance computation require pyvista. "
            "Install with: pixi install -e runtime"
        )


class FaultSurface:
    """A triangulated fault surface with orientation data.

    Represents a single fault segment as a 2D surface embedded in 3D space.
    Can be created from:

    - A point cloud (requires triangulation via pyvista)
    - A VTK file (pre-triangulated surface)

    Attributes:
        name: Identifier for this fault segment
        points: (N, 3) array of surface points
        triangles: (M, 3) array of triangle vertex indices (after triangulation)
        normals: (M, 3) array of face normals (after triangulation)
        pv_mesh: PyVista PolyData object (if pyvista available)
        is_triangulated: Whether the surface has been triangulated

    Example:
        >>> # Create from points and triangulate
        >>> points = np.array([[0, 0, 0], [1, 0, 0], [0.5, 1, 0], [0.5, 0.5, 1]])
        >>> fault = uw.meshing.FaultSurface("fault1", points)
        >>> fault.triangulate()
        >>> fault.to_vtk("fault1.vtk")
        >>>
        >>> # Load from VTK
        >>> fault2 = uw.meshing.FaultSurface.from_vtk("fault1.vtk")
    """

    def __init__(self, name: str, points: np.ndarray = None):
        """Create a fault surface.

        Args:
            name: Identifier for this fault segment.
            points: (N, 3) array of 3D points, or None for an empty fault.

        Raises:
            ValueError: If ``points`` is given but is not an (N, 3) array.
        """
        self.name = name
        self._points = None
        self._triangles = None
        self._normals = None
        self._pv_mesh = None
        self._kdtree = None

        if points is not None:
            # Go through the property setter so the shape is validated.
            self.points = points

    @property
    def points(self) -> Optional[np.ndarray]:
        """(N, 3) array of surface points."""
        return self._points

    @points.setter
    def points(self, value: np.ndarray):
        """Set points and invalidate cached data."""
        if value is not None:
            value = np.asarray(value)
            if value.ndim != 2 or value.shape[1] != 3:
                raise ValueError(
                    f"Points must be (N, 3) array, got shape {value.shape}"
                )
        self._points = value
        # Everything below was derived from the previous point set.
        self._triangles = None
        self._normals = None
        self._pv_mesh = None
        self._kdtree = None

    @property
    def triangles(self) -> Optional[np.ndarray]:
        """(M, 3) array of triangle vertex indices."""
        return self._triangles

    @property
    def normals(self) -> Optional[np.ndarray]:
        """(M, 3) array of face normals."""
        return self._normals

    @property
    def pv_mesh(self):
        """PyVista PolyData mesh (None if not triangulated or pyvista unavailable)."""
        return self._pv_mesh

    @property
    def is_triangulated(self) -> bool:
        """Whether the surface has been triangulated."""
        return self._triangles is not None and self._normals is not None

    @property
    def n_points(self) -> int:
        """Number of points in the surface."""
        return 0 if self._points is None else self._points.shape[0]

    @property
    def n_triangles(self) -> int:
        """Number of triangles in the surface."""
        return 0 if self._triangles is None else self._triangles.shape[0]

    @staticmethod
    def _points_are_collinear(points: np.ndarray, rtol: float = 1e-10) -> bool:
        """Return True when ``points`` span at most a line (or a single point).

        The test compares the two largest singular values of the centred
        point cloud, so it does not depend on how the line is oriented
        relative to the coordinate axes.
        """
        points = np.asarray(points, dtype=float)
        centred = points - points.mean(axis=0)
        singular = np.linalg.svd(centred, compute_uv=False)
        if singular.size < 2:
            return True
        # Singular values are returned in descending order; a negligible
        # second value means there is no spread away from the main axis.
        return bool(singular[1] <= rtol * singular[0])

    @staticmethod
    def _triangles_from_faces(faces, fault_name: str) -> np.ndarray:
        """Convert a flat VTK face array into an (M, 3) index array.

        The flat layout is ``[n_verts, v0, v1, v2, n_verts, ...]``. Only
        all-triangle layouts (every ``n_verts == 3``) are accepted.

        Raises:
            ValueError: If the array does not describe triangles only.
        """
        faces = np.asarray(faces)
        if faces.size % 4 != 0:
            raise ValueError(
                f"Fault '{fault_name}' contains non-triangular faces; "
                "expected an all-triangle surface."
            )
        records = faces.reshape(-1, 4)
        if np.any(records[:, 0] != 3):
            raise ValueError(
                f"Fault '{fault_name}' contains non-triangular faces; "
                "expected an all-triangle surface."
            )
        return records[:, 1:4]

    @classmethod
    def from_vtk(cls, filename: str, name: str = None) -> "FaultSurface":
        """Load fault surface from VTK file.

        Args:
            filename: Path to VTK file (.vtk or .vtp)
            name: Name for the fault. If None, uses filename stem.

        Returns:
            FaultSurface: Loaded fault surface with triangulation and normals

        Raises:
            FileNotFoundError: If file doesn't exist
            ImportError: If pyvista not available
            ValueError: If the file's faces are not all triangles
        """
        pv = _require_pyvista()

        filepath = Path(filename)
        if not filepath.exists():
            raise FileNotFoundError(f"VTK file not found: {filename}")

        if name is None:
            name = filepath.stem

        fault = cls(name)

        # Load the VTK file
        mesh = pv.read(str(filepath))
        fault._pv_mesh = mesh
        fault._points = np.array(mesh.points)

        # Extract triangles from faces
        if mesh.n_cells > 0:
            faces = mesh.faces
            if len(faces) > 0:
                fault._triangles = cls._triangles_from_faces(faces, name)

        # Extract or compute normals
        if "Normals" in mesh.cell_data:
            fault._normals = np.array(mesh.cell_data["Normals"])
        elif mesh.n_cells > 0:
            mesh.compute_normals(inplace=True)
            fault._normals = np.array(mesh.cell_data["Normals"])

        return fault

    def triangulate(self, offset: float = 0.01) -> None:
        """Triangulate point cloud using pyvista delaunay_2d.

        On success the triangles, face normals and pyvista mesh of this
        fault are replaced and any cached KD-tree is discarded. On failure
        the previous state is left untouched.

        Args:
            offset: Value forwarded to pyvista's ``delaunay_2d(offset=...)``.
                NOTE(review): pyvista documents this as a multiplier on the
                size of the initial bounding triangulation - confirm the
                intended effect before tuning it.

        Raises:
            ImportError: If pyvista not available
            ValueError: If points too sparse for triangulation (< 3 points)
                or the points are (nearly) collinear
            RuntimeError: If triangulation fails
        """
        pv = _require_pyvista()

        if self._points is None or self.n_points == 0:
            raise ValueError(f"Fault '{self.name}' has no points to triangulate")

        if self.n_points < 3:
            raise ValueError(
                f"Fault '{self.name}' has only {self.n_points} points. "
                "Need at least 3 points for triangulation."
            )

        # Reject degenerate clouds up front, whatever direction the line runs.
        if self._points_are_collinear(self._points):
            raise ValueError(
                f"Fault '{self.name}' points appear to be nearly collinear. "
                "Cannot create a 2D surface from a 1D line."
            )

        # This runs on all ranks redundantly (pyvista doesn't work in parallel)
        polydata = pv.PolyData(self._points)
        surface = polydata.delaunay_2d(offset=offset)

        if surface.n_cells == 0:
            raise RuntimeError(
                f"Triangulation failed for fault '{self.name}'. "
                "Try adjusting the offset parameter or check point distribution."
            )

        surface.compute_normals(inplace=True)

        # Commit only once every step has succeeded.
        self._pv_mesh = surface
        self._triangles = self._triangles_from_faces(surface.faces, self.name)
        self._normals = np.array(surface.cell_data["Normals"])
        # Face centres may have changed, so a cached KD-tree is stale.
        self._kdtree = None

    def compute_normals(self, consistent_normals: bool = True) -> None:
        """Recompute face normals for triangulated surface.

        Args:
            consistent_normals: If True, attempt to make normals consistently
                oriented

        Raises:
            RuntimeError: If surface not triangulated
            ImportError: If pyvista not available
        """
        if not self.is_triangulated:
            raise RuntimeError(
                f"Fault '{self.name}' must be triangulated before computing normals"
            )

        # Called only for its availability check; the module is not needed here.
        _require_pyvista()

        if self._pv_mesh is not None:
            self._pv_mesh.compute_normals(
                inplace=True, consistent_normals=consistent_normals
            )
            self._normals = np.array(self._pv_mesh.cell_data["Normals"])

    def flip_normals(self) -> None:
        """Flip the direction of all face normals."""
        if self._normals is not None:
            self._normals = -self._normals
            if self._pv_mesh is not None:
                # Keep the pyvista copy in step with the numpy array.
                self._pv_mesh.cell_data["Normals"] = self._normals

    def to_vtk(self, filename: str) -> None:
        """Export triangulated surface to VTK file.

        Args:
            filename: Output path (.vtk or .vtp)

        Raises:
            RuntimeError: If surface not triangulated
            ImportError: If pyvista not available
        """
        pv = _require_pyvista()

        if not self.is_triangulated:
            raise RuntimeError(
                f"Fault '{self.name}' must be triangulated before saving to VTK"
            )

        # Ensure we have a pyvista mesh
        if self._pv_mesh is None:
            # Rebuild the flat VTK face layout: a leading 3 before each triangle.
            faces = np.column_stack([
                np.full(self.n_triangles, 3),
                self._triangles
            ]).flatten()
            self._pv_mesh = pv.PolyData(self._points, faces)
            self._pv_mesh.cell_data["Normals"] = self._normals

        self._pv_mesh.save(str(filename))

    def build_kdtree(self) -> "uw.kdtree.KDTree":
        """Build KDTree for nearest-neighbor queries on face centers.

        The tree is built on first use and cached on the instance.

        Returns:
            KDTree built from triangle centroids

        Raises:
            RuntimeError: If surface not triangulated
        """
        if not self.is_triangulated:
            raise RuntimeError(
                f"Fault '{self.name}' must be triangulated before building KDTree"
            )

        if self._kdtree is None:
            centroids = self.face_centers
            self._kdtree = uw.kdtree.KDTree(centroids)

        return self._kdtree

    @property
    def face_centers(self) -> np.ndarray:
        """(M, 3) array of triangle centroids."""
        if not self.is_triangulated:
            raise RuntimeError(
                f"Fault '{self.name}' must be triangulated to get face centers"
            )

        if self._pv_mesh is not None:
            return np.array(self._pv_mesh.cell_centers().points)
        else:
            # No pyvista mesh: average the three vertices of each triangle.
            v0 = self._points[self._triangles[:, 0]]
            v1 = self._points[self._triangles[:, 1]]
            v2 = self._points[self._triangles[:, 2]]
            return (v0 + v1 + v2) / 3.0

    def __repr__(self) -> str:
        status = "triangulated" if self.is_triangulated else "not triangulated"
        return (
            f"FaultSurface(name='{self.name}', "
            f"n_points={self.n_points}, "
            f"n_triangles={self.n_triangles}, "
            f"status={status})"
        )
class FaultCollection:
    """Collection of fault surfaces for mesh integration.

    Manages multiple FaultSurface objects and provides methods to:

    - Compute minimum distance from mesh points to any fault
    - Transfer fault normals to mesh variables via nearest-neighbor
    - Create rheology functions for fault-weakened zones

    Example:
        >>> faults = uw.meshing.FaultCollection()
        >>> faults.add_from_vtk("fault1.vtk")
        >>> faults.add_from_vtk("fault2.vtk")
        >>>
        >>> # Compute distance field
        >>> fault_distance = faults.compute_distance_field(mesh)
        >>>
        >>> # Transfer normals
        >>> fault_normals = faults.transfer_normals(mesh)
        >>>
        >>> # Create weakness function for rheology
        >>> eta_weak = faults.create_weakness_function(
        ...     fault_distance,
        ...     fault_width=mesh.get_min_radius() * 5,
        ...     eta_weak=0.01,
        ... )
    """

    def __init__(self):
        """Create an empty fault collection."""
        # Faults keyed by their unique name, in insertion order.
        self.faults: Dict[str, FaultSurface] = {}

    def add(self, fault: FaultSurface) -> None:
        """Add a fault surface to the collection.

        Args:
            fault: FaultSurface to add

        Raises:
            ValueError: If fault with same name already exists
        """
        if fault.name in self.faults:
            raise ValueError(
                f"Fault '{fault.name}' already exists in collection. "
                "Use a different name or remove the existing fault."
            )
        self.faults[fault.name] = fault

    def add_from_vtk(self, filename: str, name: str = None) -> FaultSurface:
        """Load and add a fault from VTK file.

        Args:
            filename: Path to VTK file
            name: Name for the fault. If None, uses filename stem.

        Returns:
            The loaded FaultSurface
        """
        fault = FaultSurface.from_vtk(filename, name)
        self.add(fault)
        return fault

    def remove(self, name: str) -> FaultSurface:
        """Remove and return a fault from the collection.

        Args:
            name: Name of fault to remove

        Returns:
            The removed FaultSurface

        Raises:
            KeyError: If fault not found
        """
        return self.faults.pop(name)

    def __getitem__(self, name: str) -> FaultSurface:
        """Get a fault by name."""
        return self.faults[name]

    def __iter__(self):
        """Iterate over fault names."""
        return iter(self.faults)

    def __len__(self):
        """Number of faults in collection."""
        return len(self.faults)

    @property
    def names(self) -> List[str]:
        """List of fault names."""
        return list(self.faults.keys())

    def _require_triangulated_faults(self, action: str, purpose: str) -> None:
        """Raise ValueError unless the collection is non-empty and fully triangulated.

        ``action`` and ``purpose`` are only used to word the error messages
        (e.g. "transfer normals" / "transferring normals").
        """
        if len(self.faults) == 0:
            raise ValueError(f"Cannot {action}: no faults in collection")

        for name, fault in self.faults.items():
            if not fault.is_triangulated:
                raise ValueError(
                    f"Fault '{name}' must be triangulated before {purpose}"
                )

    @staticmethod
    def _plain_coordinates(coords):
        """Strip a unit-aware or array-like wrapper from ``coords``.

        Objects exposing ``magnitude`` yield that attribute; objects exposing
        ``__array__`` go through ``np.asarray``; anything else is returned
        unchanged.
        """
        if hasattr(coords, 'magnitude'):
            return coords.magnitude
        if hasattr(coords, '__array__'):
            return np.asarray(coords)
        return coords

    def compute_distance_field(
        self,
        mesh: "Mesh",
        distance_var: "MeshVariable" = None,
        variable_name: str = "fault_distance",
    ) -> "MeshVariable":
        """Compute minimum distance from mesh points to any fault surface.

        Uses pyvista's compute_implicit_distance for accurate signed distance
        computation. The returned field contains the absolute distance to the
        nearest fault surface at each mesh point.

        Parameters
        ----------
        mesh : Mesh
            The mesh to compute distances on.
        distance_var : MeshVariable, optional
            Existing MeshVariable to store results. If None, creates a new variable.
            NOTE(review): distances are evaluated at ``mesh.X.coords``; a
            supplied variable is assumed to have one row per mesh
            coordinate - confirm for variables of a different degree.
        variable_name : str, optional
            Name for new variable if distance_var is None.

        Returns
        -------
        MeshVariable
            Scalar variable with distance values (1 component).

        Raises
        ------
        ValueError
            If collection is empty or any fault is not triangulated.
        ImportError
            If pyvista is not available.
        """
        pv = _require_pyvista()

        self._require_triangulated_faults(
            "compute distance field", "computing distances"
        )

        # Create or validate output variable
        if distance_var is None:
            distance_var = uw.discretisation.MeshVariable(
                variable_name, mesh, 1, degree=mesh.degree
            )

        # Build a bare pyvista point cloud from the mesh coordinates
        # (avoids visualisation module which initializes trame)
        coords = self._plain_coordinates(mesh.X.coords)
        query_points = pv.PolyData(coords)

        # Start from a large distance and keep the minimum over all faults.
        nearest = np.full(distance_var.data.shape[0], 1e10)
        for fault in self.faults.values():
            result = query_points.compute_implicit_distance(fault.pv_mesh)
            fault_dist = np.abs(result.point_data["implicit_distance"])
            nearest = np.minimum(nearest, fault_dist)

        # Single write to the variable, inside the synchronised context.
        with uw.synchronised_array_update():
            distance_var.data[:, 0] = nearest

        return distance_var

    def transfer_normals(
        self,
        mesh: "Mesh",
        coords: np.ndarray = None,
        normal_var: "MeshVariable" = None,
        variable_name: str = "fault_normals",
    ) -> "MeshVariable":
        """Transfer fault normals to mesh points via nearest-neighbor lookup.

        For each mesh point, finds the closest fault face (from any fault in
        the collection) and copies that face's normal vector.

        Parameters
        ----------
        mesh : Mesh
            The mesh to transfer normals to.
        coords : ndarray, optional
            Coordinates to query. If None, uses mesh.X.coords.
        normal_var : MeshVariable, optional
            Existing MeshVariable to store results. If None, creates a new variable.
        variable_name : str, optional
            Name for new variable if normal_var is None.

        Returns
        -------
        MeshVariable
            Variable with normal vectors (3 components).

        Raises
        ------
        ValueError
            If collection is empty or any fault is not triangulated.
        """
        self._require_triangulated_faults(
            "transfer normals", "transferring normals"
        )

        # Get query coordinates as a plain array
        if coords is None:
            coords = mesh.X.coords
        coords = self._plain_coordinates(coords)

        # Create or validate output variable
        if normal_var is None:
            normal_var = uw.discretisation.MeshVariable(
                variable_name, mesh, 3, degree=mesh.degree
            )

        # Stack face centres and normals of all faults in the same order so
        # one index addresses both arrays.
        faults = list(self.faults.values())
        combined_centers = np.vstack([fault.face_centers for fault in faults])
        combined_normals = np.vstack([fault.normals for fault in faults])

        # Build KDTree and query
        kdtree = uw.kdtree.KDTree(combined_centers)
        _, closest_idx = kdtree.query(coords)

        # Transfer normals
        with uw.synchronised_array_update():
            normal_var.data[:] = combined_normals[closest_idx.flatten()]

        return normal_var

    def create_weakness_function(
        self,
        distance_var: "MeshVariable",
        fault_width: float,
        eta_weak: float = 0.01,
        eta_background: float = 1.0,
    ) -> sympy.Expr:
        """Create Piecewise viscosity function for fault weakness.

        Creates a sympy Piecewise expression that gives:

        - eta_weak when distance < fault_width
        - eta_background otherwise

        This can be used directly with
        TransverseIsotropicFlowModel.Parameters.eta_1 for creating
        anisotropic weakness along fault zones.

        Args:
            distance_var: MeshVariable containing fault distances
            fault_width: Width of the weak zone around faults
            eta_weak: Viscosity within fault zone (default 0.01)
            eta_background: Viscosity outside fault zone (default 1.0)

        Returns:
            sympy.Piecewise expression for use in constitutive models

        Example:
            >>> eta_1 = faults.create_weakness_function(
            ...     fault_distance,
            ...     fault_width=mesh.get_min_radius() * 5,
            ...     eta_weak=0.01,
            ... )
            >>> stokes.constitutive_model.Parameters.eta_1 = eta_1
        """
        return sympy.Piecewise(
            (eta_weak, distance_var.sym[0] < fault_width),
            (eta_background, True),
        )

    def __repr__(self) -> str:
        fault_strs = [f" {name}: {fault}" for name, fault in self.faults.items()]
        faults_repr = "\n".join(fault_strs) if fault_strs else " (empty)"
        return f"FaultCollection(\n{faults_repr}\n)"