"""PyVista visualization helper routines for Underworld3.
This module provides functions to convert Underworld meshes, swarms,
and variables to PyVista objects for interactive 3D visualization.
"""
import os
import underworld3 as uw
def initialise(jupyter_backend):
    """Initialize PyVista with Underworld-friendly defaults.

    Sets up the PyVista global theme (white background, MSAA
    anti-aliasing, smooth shading, a default camera) and selects a
    Jupyter backend.

    Parameters
    ----------
    jupyter_backend : str or None
        Jupyter backend to use (e.g. ``'trame'``, ``'client'``,
        ``'static'``). If None, ``'trame'`` is used, and the trame server
        proxy is enabled when a remote Jupyter environment (Binder,
        JupyterHub) is detected from its environment variables.

    Notes
    -----
    If setting the backend raises ``RuntimeError`` or ``ImportError``
    (for example because an optional dependency is not installed), the
    ``'static'`` backend is tried instead. If that also fails, the current
    backend setting is left unchanged. Other exceptions (e.g. for an
    unrecognised backend name) propagate to the caller.
    """
    import pyvista as pv

    theme = pv.global_theme
    theme.background = "white"
    theme.anti_aliasing = "msaa"
    theme.smooth_shading = True
    theme.camera["viewup"] = [0.0, 1.0, 0.0]
    theme.camera["position"] = [0.0, 0.0, 5.0]

    # Remote Jupyter environments (binder, JupyterHub, ...) are recognised
    # by the environment variables they define.
    is_remote = any(
        name in os.environ
        for name in ("BINDER_LAUNCH_HOST", "BINDER_REPO_URL", "JUPYTERHUB_SERVICE_PREFIX")
    )

    try:
        if jupyter_backend is not None:
            theme.jupyter_backend = jupyter_backend
        else:
            # Use trame backend for all Jupyter environments.
            theme.jupyter_backend = "trame"
            if is_remote:
                # Enable the server proxy but let PyVista/trame auto-detect
                # the prefix from the environment.
                theme.trame.server_proxy_enabled = True
    except (RuntimeError, ImportError):
        # "static" replaces the former "panel" fallback, which current
        # PyVista releases no longer list as a Jupyter backend.
        try:
            theme.jupyter_backend = "static"
        except (RuntimeError, ImportError, ValueError):
            pass  # best-effort only: keep whatever backend is configured
    return
[docs]
def mesh_to_pv_mesh(mesh, jupyter_backend=None):
    """Convert Underworld mesh to PyVista unstructured grid.

    The grid is built directly from the PETSc DMPlex cell closures (no
    intermediate file). ``meshio`` is used to assemble the grid when it is
    importable; otherwise ``pyvista.UnstructuredGrid`` is used directly.

    Parameters
    ----------
    mesh : Mesh
        Underworld mesh to convert.
    jupyter_backend : str, optional
        PyVista Jupyter backend to use (passed to ``initialise``).

    Returns
    -------
    pyvista.UnstructuredGrid
        PyVista mesh with unit metadata attached as ``_units`` and the
        original coordinate array attached as ``_coord_array``.

    Raises
    ------
    ValueError
        If the mesh's (simplex, dim) combination has no supported cell
        type (only triangles, quads, tetrahedra and hexahedra are handled).
    """
    initialise(jupyter_backend)

    import numpy as np
    import pyvista as pv

    # (is simplex, dim) -> (meshio cell type, name of the pyvista CellType member)
    cell_kinds = {
        (True, 2): ("triangle", "TRIANGLE"),
        (True, 3): ("tetra", "TETRA"),
        (False, 2): ("quad", "QUAD"),
        (False, 3): ("hexahedron", "HEXAHEDRON"),
    }
    key = (bool(mesh.dm.isSimplex()), mesh.dim)
    if key not in cell_kinds:
        raise ValueError(
            f"Unsupported mesh for PyVista conversion: simplex={key[0]}, dim={key[1]}"
        )
    meshio_cell_type, vtk_cell_name = cell_kinds[key]

    cells = _dmplex_cell_vertices(mesh)

    # Use non-dimensional [0-1] coordinates for PyVista.
    # PyVista only needs coordinates for spatial positioning (visualization);
    # evaluate() expects non-dimensional coords to query PETSc KDTrees.
    mesh_coordinates_nd = np.asarray(mesh.X.coords, dtype=np.double)

    # Only the import itself is guarded, so an ImportError raised elsewhere
    # is not mistaken for "meshio is unavailable".
    try:
        import meshio
    except ImportError:
        meshio = None

    if meshio is not None:
        mmesh = meshio.Mesh(
            points=mesh_coordinates_nd, cells=[(meshio_cell_type, cells)]
        )
        pv_mesh = pv.from_meshio(mmesh)
    else:
        n_cells, n_cell_vertices = cells.shape
        # Flat VTK connectivity: [n_vertices, v0, v1, ..., n_vertices, ...]
        connectivity = np.hstack(
            (np.full((n_cells, 1), n_cell_vertices, dtype=cells.dtype), cells)
        ).ravel()
        cell_types = np.full(n_cells, pv.CellType[vtk_cell_name], dtype=np.uint8)
        pv_mesh = pv.UnstructuredGrid(
            connectivity, cell_types, coords_to_pv_coords(mesh_coordinates_nd)
        )

    # Store units metadata for labeling and axis annotation
    pv_mesh._units = mesh.units if mesh.units is not None else uw.units.dimensionless
    # Store original coordinate array for proper evaluation
    pv_mesh._coord_array = mesh.X.coords
    return pv_mesh


def _dmplex_cell_vertices(mesh):
    """Return the VTK-ordered vertex connectivity of every local cell of *mesh*.

    For each cell in height stratum 0 of ``mesh.dm``, the last
    ``mesh.element.entities[mesh.dim]`` points of its transitive closure
    are taken as the cell's vertices and shifted by the start of depth
    stratum 0 so indices begin at zero.

    Returns
    -------
    numpy.ndarray
        Integer array of shape ``(n_cells, n_cell_vertices)`` (``n_cells``
        may be zero).
    """
    import numpy as np

    dm = mesh.dm
    c_start, c_end = dm.getHeightStratum(0)
    p_start, _ = dm.getDepthStratum(0)
    n_cell_vertices = mesh.element.entities[mesh.dim]

    rows = [
        dm.getTransitiveClosure(cell)[0][-n_cell_vertices:] - p_start
        for cell in range(c_start, c_end)
    ]
    # reshape keeps a 2-D shape even when there are no local cells
    cells = np.array(rows, dtype=np.int64).reshape(-1, n_cell_vertices)

    # PETSc DMPlex uses opposite winding to VTK for 3D cells.
    # Reorder vertices to match VTK convention (right-handed orientation)
    # so that face normals, shading, clipping, and back-face culling work correctly.
    #
    # Hexahedra: PETSc winds bottom face CCW from above, VTK expects CW from above.
    # Fix: swap vertices 1 and 3 on the bottom face -> [0,3,2,1,4,5,6,7]
    # Tetrahedra: PETSc uses left-handed orientation, VTK expects right-handed.
    # Fix: swap vertices 1 and 2 -> [0,2,1,3]
    if mesh.dim == 3:
        if dm.isSimplex():
            order = [0, 2, 1, 3]
        else:
            order = [0, 3, 2, 1, 4, 5, 6, 7]
        cells = cells[:, order]
    return cells
def coords_to_pv_coords(coords):
    """Return *coords* as a three-column array that PyVista can consume.

    This is the public entry point; the work is delegated to
    ``_vector_to_pv_vector``. Three-column input is returned as-is and
    two-column input gets a zero z column.

    Parameters
    ----------
    coords : numpy.ndarray
        Coordinate array of shape ``(n, 2)`` or ``(n, 3)``.

    Returns
    -------
    numpy.ndarray
        Coordinate array of shape ``(n, 3)``.
    """
    padded = _vector_to_pv_vector(coords)
    return padded
def _vector_to_pv_vector(vector):
"""Convert numpy coordinate array to pyvista compatible array"""
import numpy as np
if vector.shape[1] == 3:
return vector
else:
vector3 = np.zeros((vector.shape[0], 3))
vector3[:, 0:2] = vector[:]
return vector3
[docs]
def swarm_to_pv_cloud(swarm):
    """Build a PyVista point cloud from a swarm's local particle positions.

    Parameters
    ----------
    swarm : Swarm
        Underworld swarm object; ``swarm.data`` supplies the positions and
        ``swarm.local_size`` the number of particles.

    Returns
    -------
    pyvista.PolyData
        Point cloud with one 3D point per local particle (z is zero for
        2D meshes).
    """
    import numpy as np
    import pyvista as pv

    xyz = np.zeros((swarm.local_size, 3))
    # 2D swarms only fill x and y; z keeps the zero it was allocated with.
    n_axes = 2 if swarm.mesh.dim == 2 else 3
    for axis in range(n_axes):
        xyz[:, axis] = swarm.data[:, axis]

    return pv.PolyData(xyz)
[docs]
def meshVariable_to_pv_cloud(meshVar):
    """Build a PyVista point cloud at a mesh variable's nodal locations.

    Parameters
    ----------
    meshVar : MeshVariable
        Underworld mesh variable; ``meshVar.coords`` supplies the points.

    Returns
    -------
    pyvista.PolyData
        Point cloud at the nodal locations, carrying ``_units`` (mesh units
        or dimensionless) and ``_coord_array`` (the original
        ``meshVar.coords``) for later evaluation and labelling.
    """
    import numpy as np
    import pyvista as pv
    import underworld3 as uw

    # meshVar.coords may be dimensional (unit-aware) when units are active;
    # it is converted to a plain float64 array here. The default alpha in
    # meshVariable_to_pv_mesh_object is computed from these same values,
    # keeping the two scales consistent.
    raw = np.asarray(meshVar.coords, dtype=np.double)

    xyz = np.zeros((raw.shape[0], 3))
    n_axes = 2 if meshVar.mesh.dim == 2 else 3
    xyz[:, :n_axes] = raw[:, :n_axes]

    cloud = pv.PolyData(xyz)
    mesh_units = meshVar.mesh.units
    cloud._units = uw.units.dimensionless if mesh_units is None else mesh_units
    # Keep the original (possibly unit-aware) coordinates for evaluation.
    cloud._coord_array = meshVar.coords
    return cloud
[docs]
def meshVariable_to_pv_mesh_object(meshVar, alpha=None):
    """Convert mesh variable to Delaunay-triangulated PyVista mesh.

    Creates a mesh by triangulating the mesh variable's nodal points.
    Useful for higher-order elements where the base mesh doesn't
    capture all data points.

    Parameters
    ----------
    meshVar : MeshVariable
        Underworld mesh variable.
    alpha : float, optional
        Alpha parameter for the Delaunay triangulation. If None, it is
        estimated as ``2 * extent / max(10, n_points ** (1 / dim))``, where
        ``extent`` is the largest per-axis extent of the nodal points.

    Returns
    -------
    pyvista.UnstructuredGrid
        Triangulated mesh through the variable's nodal points
        (``delaunay_2d`` for 2D meshes, ``delaunay_3d`` otherwise), with
        ``_coord_array`` and ``_units`` copied from the point cloud.
    """
    import numpy as np

    dim = meshVar.mesh.dim
    point_cloud = meshVariable_to_pv_cloud(meshVar)

    if alpha is None:
        # Compute alpha from the point cloud coordinates themselves so it
        # shares their scale: mesh.get_max_radius() is non-dimensional while
        # the coords may be dimensional when units are active.
        points = np.asarray(point_cloud.points)
        # Largest per-axis extent. A global max() - min() over all entries
        # would mix axes (and, in 2D, the padded z = 0 column), inflating
        # alpha for any domain that does not contain the origin.
        extent = (points.max(axis=0) - points.min(axis=0)).max()
        # Rough number of elements along one axis
        n_per_axis = max(10, len(points) ** (1.0 / dim))
        alpha = extent / n_per_axis * 2.0  # 2x for safety margin

    if dim == 2:
        pv_mesh = point_cloud.delaunay_2d(alpha=alpha)
    else:
        pv_mesh = point_cloud.delaunay_3d(alpha=alpha)

    # The delaunay filters return a new object that doesn't carry the
    # custom attributes, so propagate them explicitly.
    if hasattr(point_cloud, "_coord_array"):
        pv_mesh._coord_array = point_cloud._coord_array
    if hasattr(point_cloud, "_units"):
        pv_mesh._units = point_cloud._units
    return pv_mesh
[docs]
def scalar_fn_to_pv_points(pv_mesh, uw_fn, dim=None):
    """Evaluate an Underworld scalar function at the points of a PyVista object.

    Parameters
    ----------
    pv_mesh : pyvista.DataSet
        PyVista mesh or point cloud to evaluate at. If it carries a
        ``_coord_array`` attribute, those coordinates are used for the
        evaluation instead of ``pv_mesh.points``.
    uw_fn : sympy.Expr
        Underworld scalar function to evaluate.
    dim : int, optional
        Number of coordinate columns to use (2 or 3). If None, 2 is chosen
        when the z coordinates span less than 1e-6, otherwise 3.

    Returns
    -------
    numpy.ndarray
        Scalar values at mesh points with units stripped for PyVista.
        The units string (or None) is stored as
        ``pv_mesh._last_scalar_units``.
    """
    import underworld3 as uw
    import numpy as np

    if dim is None:
        z = pv_mesh.points[:, 2]
        dim = 2 if z.max() - z.min() < 1.0e-6 else 3

    # Prefer the stored coordinate array: it keeps units / dimensional info.
    if hasattr(pv_mesh, "_coord_array"):
        source = pv_mesh._coord_array
    else:
        source = pv_mesh.points
    values = uw.function.evaluate(uw_fn, source[:, 0:dim])

    # Record the units (checked as .units, then ._units) for colorbar labels.
    units_label = None
    for attr in ("units", "_units"):
        found = getattr(values, attr, None)
        if found is not None:
            units_label = str(found)
            break
    pv_mesh._last_scalar_units = units_label

    # PyVista needs plain numbers.
    return values.magnitude if hasattr(values, "magnitude") else np.asarray(values)
[docs]
def vector_fn_to_pv_points(pv_mesh, uw_fn, dim=None):
    """Evaluate Underworld vector function at PyVista mesh points.

    Parameters
    ----------
    pv_mesh : pyvista.DataSet
        PyVista mesh or point cloud to evaluate at. If it carries a
        ``_coord_array`` attribute, those coordinates are used for the
        evaluation instead of ``pv_mesh.points``.
    uw_fn : sympy.Matrix
        Underworld vector function to evaluate; its number of components
        is taken from ``uw_fn.shape[1]``.
    dim : int, optional
        Ignored (kept for API compatibility); the dimension is derived
        from the function's shape.

    Returns
    -------
    numpy.ndarray
        float64 array of shape ``(n_points, 3)`` holding the evaluated
        components in the leading columns and zeros in the rest (units
        stripped). The units string (or None) is stored as
        ``pv_mesh._last_vector_units``.

    Warns
    -----
    UserWarning
        If the function does not have 2 or 3 components.
    """
    import warnings

    import numpy as np
    import underworld3 as uw

    dim = uw_fn.shape[1]
    if dim != 2 and dim != 3:
        warnings.warn("UW vector function should have dimension 2 or 3", stacklevel=2)

    if hasattr(pv_mesh, "_coord_array"):
        coords = pv_mesh._coord_array[:, 0:dim]
    else:
        coords = pv_mesh.points[:, 0:dim]
    vector_values_raw = uw.function.evaluate(uw_fn, coords)

    # Capture units before stripping
    vector_units = None
    if hasattr(vector_values_raw, "units") and vector_values_raw.units is not None:
        vector_units = str(vector_values_raw.units)
    elif hasattr(vector_values_raw, "_units") and vector_values_raw._units is not None:
        vector_units = str(vector_values_raw._units)
    pv_mesh._last_vector_units = vector_units

    if hasattr(vector_values_raw, "magnitude"):
        vector_values_raw = vector_values_raw.magnitude
    else:
        vector_values_raw = np.asarray(vector_values_raw)

    # Plain float64 buffer (independent of the dtype PyVista uses for its
    # points). reshape(-1, dim) rather than squeeze() so that one-point and
    # one-component results keep a (n_points, dim) shape.
    vector_values = np.zeros((len(pv_mesh.points), 3))
    vector_values[:, 0:dim] = np.reshape(vector_values_raw, (-1, dim))
    return vector_values
# def vector_fn_to_pv_arrows(coords, uw_fn, dim=None):
# """evaluate uw vector function on point cloud"""
# import numpy as np
# dim = uw_fn.shape[1]
# if dim != 2 and dim != 3:
# print(f"UW vector function should have dimension 2 or 3")
# coords = pv_mesh.points[:, 0 : dim - 1]
# vector_values = np.zeros_like(coords)
# for i in range(0, dim):
# vector_values[:, i] = uw.function.evaluate(uw_fn[i], coords, evalf=True)
# return vector_values
def clip_mesh(pvmesh, clip_angle):
    """Clip a mesh with two vertical planes through the origin.

    The plane normals lie in the x-y plane at ``+clip_angle`` and
    ``-clip_angle`` degrees from the x-axis, i.e. ``(cos a, sin a, 0)``
    and ``(cos a, -sin a, 0)``.

    Parameters
    ----------
    pvmesh : pyvista.DataSet
        The PyVista mesh object to be clipped.
    clip_angle : float
        Angle (in degrees) of the clipping-plane normals from the x-axis.

    Returns
    -------
    list
        The two clipped parts ``[clip1, clip2]``, each the result of
        ``pvmesh.clip(origin=(0, 0, 0), normal=..., invert=False,
        crinkle=False)`` with one of the two planes.
    """
    import numpy as np

    theta = np.deg2rad(clip_angle)
    cos_t, sin_t = np.cos(theta), np.sin(theta)
    # The y-component uses sin(): with cos() in both components the normals
    # would not depend on clip_angle (apart from sign) and would vanish at
    # 90 degrees. Both forms coincide at 45 degrees.
    normals = ((cos_t, sin_t, 0.0), (cos_t, -sin_t, 0.0))
    return [
        pvmesh.clip(origin=(0.0, 0.0, 0.0), normal=normal, invert=False, crinkle=False)
        for normal in normals
    ]
[docs]
def plot_mesh(
    mesh,
    title="",
    clip_angle=0.0,
    cpos="xy",
    window_size=(750, 750),
    show_edges=True,
    save_png=False,
    dir_fname="",
):
    """Plot a mesh with optional clipping, edge display, and PNG export.

    Parameters
    ----------
    mesh : Mesh
        The Underworld mesh to plot; converted with ``mesh_to_pv_mesh``.
    title : str, optional
        Title text added to the plot. Empty (default) means no title.
    clip_angle : float, optional
        If non-zero, the mesh is replaced by the two parts returned by
        ``clip_mesh(pvmesh, clip_angle)``. Default is 0.0 (no clipping).
    cpos : str or list, optional
        Camera position passed to ``Plotter.show``, e.g. ``'xy'``, ``'xz'``,
        ``'yz'`` or an explicit camera position. Default is ``'xy'``.
    window_size : tuple of int, optional
        Rendering window size in pixels as (width, height).
        Default is (750, 750).
    show_edges : bool, optional
        Whether to draw mesh edges (applies to clipped and unclipped
        plots). Default is True.
    save_png : bool, optional
        If True, zoom the camera by 1.4 and save a screenshot (scale 3.5)
        to ``dir_fname`` after showing the plot. Default is False.
    dir_fname : str, optional
        Path of the screenshot written when ``save_png`` is True.
        Default is an empty string.

    Returns
    -------
    None
    """
    import pyvista as pv

    pvmesh = mesh_to_pv_mesh(mesh)
    pl = pv.Plotter(window_size=window_size)

    targets = clip_mesh(pvmesh, clip_angle) if clip_angle != 0.0 else [pvmesh]
    for target in targets:
        # show_edges is honoured for clipped parts too (it used to be forced on).
        pl.add_mesh(
            target,
            edge_color="k",
            show_edges=show_edges,
            use_transparency=False,
            opacity=1.0,
        )

    if title:
        # NOTE(review): (950, 2100) lies outside the default 750x750 window;
        # confirm the title is visible for the window sizes in use.
        pl.add_text(title, font_size=18, position=(950, 2100))

    pl.show(cpos=cpos)

    if save_png:
        # NOTE(review): this runs after show(); verify the render window is
        # still available for screenshots in the target environment.
        pl.camera.zoom(1.4)
        pl.screenshot(dir_fname, scale=3.5)
    return
[docs]
def plot_scalar(
    mesh,
    scalar,
    scalar_name="",
    cmap="",
    clim="",
    window_size=(750, 750),
    title="",
    fmt="%10.7f",
    clip_angle=0.0,
    cpos="xy",
    show_edges=False,
    save_png=False,
    dir_fname="",
):
    """Plot a scalar quantity on a mesh with optional clipping and PNG export.

    Parameters
    ----------
    mesh : Mesh
        The Underworld mesh to plot; converted with ``mesh_to_pv_mesh``.
    scalar : sympy.Expr
        Scalar expression evaluated at the mesh points via
        ``scalar_fn_to_pv_points`` (e.g. ``T.sym`` for a mesh variable).
    scalar_name : str, optional
        Name of the point-data array; also the scalar-bar title, with the
        evaluated units appended in parentheses when available.
    cmap : str, optional
        Colormap name. An empty string (default) uses PyVista's default.
    clim : tuple of float, optional
        Colour range ``(min, max)``. An empty string (default) uses the
        data range.
    window_size : tuple of int, optional
        Rendering window size in pixels. Default is (750, 750).
    title : str, optional
        Title text added to the plot. Empty (default) means no title.
    fmt : str, optional
        Format string for scalar values. Currently unused.
    clip_angle : float, optional
        If non-zero, plot the two parts returned by
        ``clip_mesh(pvmesh, clip_angle)``. Default is 0.0 (no clipping).
    cpos : str or list, optional
        Camera position passed to ``Plotter.show``. Default is ``'xy'``.
    show_edges : bool, optional
        Whether to draw mesh edges. Default is False.
    save_png : bool, optional
        If True, zoom the camera by 1.4 and save a screenshot (scale 3.5)
        to ``dir_fname`` after showing the plot. Default is False.
    dir_fname : str, optional
        Path of the screenshot written when ``save_png`` is True.

    Returns
    -------
    None
    """
    import pyvista as pv

    # Empty-string defaults mean "use the default"; PyVista's own default
    # for these arguments is None.
    if isinstance(cmap, str) and not cmap:
        cmap = None
    if isinstance(clim, str) and not clim:
        clim = None

    pvmesh = mesh_to_pv_mesh(mesh)
    pvmesh.point_data[scalar_name] = scalar_fn_to_pv_points(pvmesh, scalar)

    # Build scalar bar label with units if available
    scalar_bar_title = scalar_name
    scalar_units = getattr(pvmesh, "_last_scalar_units", None)
    if scalar_units:
        scalar_bar_title = f"{scalar_name} ({scalar_units})"

    pl = pv.Plotter(window_size=window_size)

    targets = clip_mesh(pvmesh, clip_angle) if clip_angle != 0.0 else [pvmesh]
    for target in targets:
        pl.add_mesh(
            target,
            cmap=cmap,
            edge_color="k",
            scalars=scalar_name,
            show_edges=show_edges,
            use_transparency=False,
            opacity=1.0,
            clim=clim,
            show_scalar_bar=True,
            scalar_bar_args={"title": scalar_bar_title},
        )

    # The title must be added before show(), otherwise it never appears in
    # the displayed plot.
    if title:
        pl.add_text(title, font_size=18, position=(950, 2100))

    pl.show(cpos=cpos)

    if save_png:
        # NOTE(review): this runs after show(); verify the render window is
        # still available for screenshots in the target environment.
        pl.camera.zoom(1.4)
        pl.screenshot(dir_fname, scale=3.5)
    return
[docs]
def plot_vector(
    mesh,
    vector,
    vector_name="",
    cmap="",
    clim="",
    vmag="",
    vfreq="",
    save_png=False,
    dir_fname="",
    title="",
    fmt="%10.7f",
    clip_angle=0.0,
    show_arrows=False,
    cpos="xy",
    show_edges=False,
    window_size=(750, 750),
    scalar=None,
    scalar_name="",
):
    """Plot a vector field on a mesh, coloured by its magnitude or a scalar.

    Parameters
    ----------
    mesh : Mesh
        The Underworld mesh to plot; converted with ``mesh_to_pv_mesh``.
    vector : MeshVariable
        Vector mesh variable. Its ``.sym`` is evaluated on the mesh, and
        its nodal points are used for the arrows.
    vector_name : str, optional
        Name of the vector point-data array. When ``scalar`` is None, the
        magnitude array is named ``vector_name + "_mag"``.
    cmap : str, optional
        Colormap name. An empty string (default) uses PyVista's default.
    clim : tuple of float, optional
        Colour range ``(min, max)``. An empty string (default) uses the
        data range.
    vmag : float, optional
        Arrow scaling factor passed to ``add_arrows``. An empty string
        (default) means 1.0.
    vfreq : int, optional
        Plot every ``vfreq``-th arrow. An empty string (default) plots
        every point.
    save_png : bool, optional
        If True, zoom the camera by 1.4 and save a screenshot (scale 3.5)
        to ``dir_fname`` after showing the plot. Default is False.
    dir_fname : str, optional
        Path of the screenshot written when ``save_png`` is True.
    title : str, optional
        Title text added to the plot. Empty (default) means no title.
    fmt : str, optional
        Format string for scalar values. Currently unused.
    clip_angle : float, optional
        If non-zero, plot the two parts returned by
        ``clip_mesh(pvmesh, clip_angle)``. Default is 0.0 (no clipping).
    show_arrows : bool, optional
        Whether to draw arrows at the vector variable's nodal points.
        Default is False.
    cpos : str or list, optional
        Camera position passed to ``Plotter.show``. Default is ``'xy'``.
    show_edges : bool, optional
        Whether to draw mesh edges. Default is False.
    window_size : tuple of int, optional
        Rendering window size in pixels. Default is (750, 750).
    scalar : MeshVariable, optional
        Scalar mesh variable (its ``.sym`` is evaluated) used for colouring
        instead of the vector magnitude. Default is None.
    scalar_name : str, optional
        Name of the scalar array and scalar-bar title when ``scalar`` is
        given.

    Returns
    -------
    None

    Notes
    -----
    When the model uses physical units, arrows are drawn in the mesh
    coordinate space. A velocity of 1 cm/yr on a mesh in meters
    (extent ~1e6) produces arrows of length ~3e-10 in mesh units,
    which is effectively invisible. Adjust ``vmag`` to compensate::

        # Scale arrows to ~5% of mesh extent
        vmag = 0.05 * mesh_extent / max_velocity
    """
    import sympy
    import pyvista as pv

    # Empty-string defaults mean "use the default": PyVista's defaults are
    # cmap=None, clim=None and mag=1, and a None slice step keeps every point
    # (a "" step would raise TypeError).
    if isinstance(cmap, str) and not cmap:
        cmap = None
    if isinstance(clim, str) and not clim:
        clim = None
    if isinstance(vmag, str) and not vmag:
        vmag = 1.0
    if isinstance(vfreq, str) and not vfreq:
        vfreq = None

    pvmesh = mesh_to_pv_mesh(mesh)
    pvmesh.point_data[vector_name] = vector_fn_to_pv_points(pvmesh, vector.sym)

    if scalar is None:
        scalar_name = vector_name + "_mag"
        pvmesh.point_data[scalar_name] = scalar_fn_to_pv_points(
            pvmesh, sympy.sqrt(vector.sym.dot(vector.sym))
        )
    else:
        pvmesh.point_data[scalar_name] = scalar_fn_to_pv_points(pvmesh, scalar.sym)

    # Build scalar bar label with units if available
    scalar_bar_title = scalar_name
    scalar_units = getattr(pvmesh, "_last_scalar_units", None)
    if scalar_units:
        scalar_bar_title = f"{scalar_name} ({scalar_units})"

    pl = pv.Plotter(window_size=window_size)

    targets = clip_mesh(pvmesh, clip_angle) if clip_angle != 0.0 else [pvmesh]
    for target in targets:
        pl.add_mesh(
            target,
            cmap=cmap,
            edge_color="k",
            scalars=scalar_name,
            show_edges=show_edges,
            use_transparency=False,
            opacity=1.0,
            clim=clim,
            show_scalar_bar=True,
            scalar_bar_args={"title": scalar_bar_title},
        )

    if show_arrows:
        # Arrow data is only needed (and only computed) when arrows are drawn.
        velocity_points = meshVariable_to_pv_cloud(vector)
        velocity_points.point_data[vector_name] = vector_fn_to_pv_points(
            velocity_points, vector.sym
        )
        pl.add_arrows(
            velocity_points.points[::vfreq],
            velocity_points.point_data[vector_name][::vfreq],
            mag=vmag,
            color="k",
        )

    # The title must be added before show(), otherwise it never appears in
    # the displayed plot.
    if title:
        pl.add_text(title, font_size=18, position=(950, 1075))

    pl.show(cpos=cpos)

    if save_png:
        # NOTE(review): this runs after show(); verify the render window is
        # still available for screenshots in the target environment.
        pl.camera.zoom(1.4)
        pl.screenshot(dir_fname, scale=3.5)
    return
def save_colorbar(
    colormap="",
    cb_bounds=None,
    vmin=None,
    vmax=None,
    figsize_cb=(6, 1),
    primary_fs=18,
    cb_orient="vertical",
    cb_axis_label="",
    cb_label_xpos=0.5,
    cb_label_ypos=0.5,
    fformat="png",
    output_path="",
    fname="",
):
    """Save a standalone colorbar image with Matplotlib.

    A hidden image spanning the requested value range provides the
    colormap and limits for the colorbar. The file is written to
    ``f"{output_path}{fname}_cbvert.{fformat}"`` (vertical) or
    ``f"{output_path}{fname}_cbhorz.{fformat}"`` (horizontal) at 150 dpi.

    Parameters
    ----------
    colormap : str or matplotlib colormap, optional
        Colormap to use. An empty string (default) uses Matplotlib's
        default colormap.
    cb_bounds : array-like, optional
        Values whose range defines the colorbar limits. If None, the range
        ``[vmin, vmax]`` is used instead.
    vmin, vmax : float, optional
        Colorbar limits, used when ``cb_bounds`` is None.
    figsize_cb : tuple of float, optional
        Figure size in inches as (width, height). Default is (6, 1).
    primary_fs : int, optional
        Font size used for this figure (labels and title). Default is 18.
        Only affects this figure; global rcParams are left unchanged.
    cb_orient : {'vertical', 'horizontal'}, optional
        Colorbar orientation. Default is 'vertical'.
    cb_axis_label : str, optional
        Label drawn as the colorbar axis title. Default is an empty string.
    cb_label_xpos, cb_label_ypos : float, optional
        Position of the label in colorbar-axes coordinates. Default 0.5.
    fformat : str, optional
        File format/extension, e.g. 'png' or 'pdf'. Default is 'png'.
    output_path : str, optional
        Prefix prepended to the file name (typically a directory ending in
        a path separator). Default is an empty string.
    fname : str, optional
        Base file name without extension. Default is an empty string.

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If ``cb_orient`` is neither 'vertical' nor 'horizontal'.
    """
    import numpy as np
    import matplotlib.pyplot as plt

    if cb_orient not in ("vertical", "horizontal"):
        raise ValueError(
            f"cb_orient must be 'vertical' or 'horizontal', got {cb_orient!r}"
        )

    # An empty string means "use Matplotlib's default colormap".
    cmap = None if isinstance(colormap, str) and not colormap else colormap

    # Scope the font size to this figure instead of mutating global rcParams.
    with plt.rc_context({"font.size": primary_fs}):
        fig = plt.figure(figsize=figsize_cb)
        try:
            if cb_bounds is not None:
                limits = np.array([cb_bounds])
            else:
                limits = np.array([[vmin, vmax]])
            # The image only supplies the mappable (colormap + value range)
            # that plt.colorbar() picks up; its axes are hidden.
            plt.imshow(limits, cmap=cmap)
            plt.gca().set_visible(False)

            if cb_orient == "vertical":
                cax = plt.axes([0.1, 0.2, 0.06, 1.15])
                cb = plt.colorbar(orientation="vertical", cax=cax)
                cb.ax.set_title(
                    cb_axis_label,
                    fontsize=primary_fs,
                    x=cb_label_xpos,
                    y=cb_label_ypos,
                    rotation=90,
                )
                suffix = "cbvert"
            else:
                cax = plt.axes([0.1, 0.2, 1.15, 0.06])
                cb = plt.colorbar(orientation="horizontal", cax=cax)
                cb.ax.set_title(
                    cb_axis_label, fontsize=primary_fs, x=cb_label_xpos, y=cb_label_ypos
                )
                suffix = "cbhorz"

            plt.savefig(
                f"{output_path}{fname}_{suffix}.{fformat}", dpi=150, bbox_inches="tight"
            )
        finally:
            # Release the figure so repeated calls do not accumulate open figures.
            plt.close(fig)
    return