Source code for underworld3.utilities.create_dmplex_from_medit

#!/usr/bin/env python
# coding: utf-8

from __future__ import print_function
import sys, petsc4py

petsc4py.init(sys.argv)
from petsc4py import PETSc
import numpy as np
import underworld3 as uw
from mpi4py import MPI

from underworld3.utilities.read_medit_ascii import read_medit_ascii
from underworld3.utilities.read_medit_ascii import print_medit_mesh_info


[docs] def create_dmplex_from_medit(medit_file, print_info=False): """ Reads a medit (.mesh) file and returns a distributed DMPlex with labels. """ if print_info: print_medit_mesh_info(medit_file) # Read mesh data vertices, _ = read_medit_ascii(medit_file, "Vertices") cells, cells_indx = read_medit_ascii(medit_file, "Tetrahedra") triangles, triangles_indx = read_medit_ascii(medit_file, "Triangles") edges, edges_indx = read_medit_ascii(medit_file, "Edges") dim = vertices.shape[1] plex = PETSc.DMPlex().createFromCellList(dim, cells, vertices) # Label Tetrahedra plex.createLabel("TetraLabels") tet_Start, tet_End = plex.getDepthStratum(3) for i in range(tet_Start, tet_End): plex.setLabelValue("TetraLabels", i, 400 + cells_indx[i]) # Define boundary labels vertices_labels_values = ( (("LeftPts", 101), ("RightPts", 102)), (("FrontPts", 103), ("BackPts", 104)), (("BottomPts", 105), ("TopPts", 106)), ) def label_vertices(dm, labels_values, dim=3, atol=1e-6): """ Efficiently label vertices in a distributed way. """ pStart, pEnd = dm.getDepthStratum(0) coords = dm.getCoordinates().array.reshape(pEnd - pStart, dim) # PETSc requires all ranks to call getBoundingBox() uw.mpi.comm.barrier() bb_values = dm.getBoundingBox() for j in range(dim): for k in range(2): dm.createLabel(labels_values[j][k][0]) # Label points for i, point in enumerate(range(pStart, pEnd)): coord = coords[i] for j in range(dim): for k in range(2): if np.isclose(coord[j], bb_values[j][k], atol=atol): label, value = labels_values[j][k] dm.setLabelValue(label, point, value) # Label vertices label_vertices(plex, vertices_labels_values) # Fetch Label Indices Once (Cache them) def get_label_indices(dm, label_name, label_value): label = dm.getLabel(label_name) if label is None: return None indices = dm.getStratumIS(label_name, label_value) return indices.getIndices() if indices else None # Cache face and vertex labels to avoid repeated function calls label_cache = {} for dim_labels in vertices_labels_values: for name, value in dim_labels: indices = get_label_indices(plex, name, value) - tet_End label_cache[(name, value)] = indices # Label Triangles face_labels_values = ( (("Left", 391), ("Right", 392)), (("Front", 393), ("Back", 394)), (("Bottom", 395), ("Top", 396)), ) plex.createLabel("TriangleLabels") for j in range(dim): for k in range(2): plex.createLabel(face_labels_values[j][k][0]) sorted_triangles = np.sort(triangles, axis=1) tri_Start, tri_End = plex.getDepthStratum(2) for i in range(tri_Start, tri_End): coneclose, _ = plex.getTransitiveClosure(i) conepoints = np.sort(coneclose[-3:] - tet_End) if np.any(np.all(sorted_triangles == conepoints, axis=1)): index_tri = np.where(np.all(sorted_triangles == conepoints, axis=1))[0][ 0 ] # fetch the where conepoints matches in sorted_triangles list plex.setLabelValue("TriangleLabels", i, 300 + triangles_indx[index_tri]) # boundary labeling using cached indices for j in range(dim): for k in range(2): vertices_label, vertices_value = vertices_labels_values[j][k] face_label, face_value = face_labels_values[j][k] vertices_indices = label_cache.get((vertices_label, vertices_value), []) if ( conepoints[0] in vertices_indices and conepoints[1] in vertices_indices and conepoints[2] in vertices_indices ): plex.setLabelValue(face_label, i, face_value) # Label Edges sorted_edges = np.sort(edges, axis=1) edge_Start, edge_End = plex.getDepthStratum(1) plex.createLabel("LineLabels") for i in range(edge_Start, edge_End): coneclose, _ = plex.getTransitiveClosure(i) conepoints = np.sort(coneclose[-2:] - tet_End) if np.any(np.all(sorted_edges == conepoints, axis=1)): index_edge = np.where(np.all(sorted_edges == conepoints, axis=1))[0][ 0 ] # fetch the where conepoints matches in sorted_edges list plex.setLabelValue("LineLabels", i, 200 + edges_indx[index_edge]) return plex