Source code for underworld3.utilities.nd_array_callback

"""
NDArray_With_Callback: A numpy ndarray subclass with modification callbacks.

This class is designed to help wrap underworld data that require us to
do parallel sync or PETSc object refreshing.

Key Features:
- Callbacks triggered when array data is modified
- Delayed callback execution for batch operations
- MPI synchronization in parallel contexts
- Global reduction operations (MPI-aware): global_max, global_min, global_sum,
  global_mean, global_size, global_norm, global_rms
- Weak reference ownership tracking

This is the base class for UnitAwareArray which adds unit preservation.
"""

import numpy as np
import weakref
import logging
from typing import Callable, Any, Dict, List, Optional, Union
import threading

logger = logging.getLogger(__name__)

# Try to import underworld MPI - fall back gracefully if not available
try:
    import underworld3 as uw

    _has_uw_mpi = hasattr(uw, "mpi") and hasattr(uw.mpi, "barrier")
except ImportError:
    _has_uw_mpi = False
    uw = None


class DelayedCallbackManager:
    """
    Per-thread bookkeeping for callbacks whose execution has been postponed.

    Each thread owns a stack of "delay contexts" and a flat queue of pending
    callbacks. A context remembers how long the queue was when it was opened,
    so that closing it hands back exactly the callbacks queued since then.
    """

    def __init__(self):
        self._local = threading.local()

    def _get_state(self):
        """Return this thread's state object, initialising it on first use."""
        local = self._local
        try:
            local.delay_stack
        except AttributeError:
            local.delay_stack = []
            local.delayed_callbacks = []
        return local

    def is_delaying(self):
        """Return True while at least one delay context is open on this thread."""
        return bool(self._get_state().delay_stack)

    def push_delay_context(self, context_info=None):
        """Open a delay context, recording the current queue length as its start."""
        state = self._get_state()
        frame = dict(
            context_info=context_info,
            callback_count=len(state.delayed_callbacks),
        )
        state.delay_stack.append(frame)

    def pop_delay_context(self):
        """
        Close the innermost delay context.

        Returns the list of callback records queued since that context was
        opened, or an empty list when no context is open.
        """
        state = self._get_state()
        frames = state.delay_stack
        if not frames:
            return []

        first = frames.pop()["callback_count"]

        # Slice copy: the returned list is independent of the queue below.
        pending = state.delayed_callbacks[first:]

        if frames:
            # An outer context is still open: drop only this context's entries.
            state.delayed_callbacks = state.delayed_callbacks[:first]
        else:
            # Outermost context closed: empty the whole queue.
            state.delayed_callbacks.clear()

        return pending

    def add_delayed_callback(self, array, callback_func, change_info):
        """Queue ``callback_func`` for ``array`` with a copy of ``change_info``."""
        record = dict(
            array=array,
            callback=callback_func,
            change_info=change_info.copy(),
        )
        self._get_state().delayed_callbacks.append(record)


# Module-wide DelayedCallbackManager used by NDArray_With_Callback to queue
# callbacks while a delay context is open. The manager keeps its context
# stack and callback queue in threading.local storage, i.e. per thread.
_delayed_callback_manager = DelayedCallbackManager()


class _DelayedCallbackContext:
    """
    Context manager that queues callbacks on entry and runs them on exit.

    Shared implementation behind ``NDArray_With_Callback.delay_callback`` and
    ``NDArray_With_Callback.delay_callbacks_global``. ``label`` only changes
    the wording of MPI-barrier warning messages ("" or "global ").

    On exit, the callbacks queued since this context was entered are executed
    (nested contexts therefore flush their own callbacks when they close).
    Callbacks run even if the ``with`` body raised; that exception is never
    suppressed. Errors raised by callbacks or barriers are logged as warnings.
    """

    def __init__(self, context_info=None, label=""):
        self.context_info = context_info
        self._label = label

    @staticmethod
    def _barrier(failure_message):
        """Call ``uw.mpi.barrier()`` when available; log instead of raising."""
        if not _has_uw_mpi:
            return
        try:
            uw.mpi.barrier()
        except Exception as e:
            logger.warning(f"{failure_message}: {e}")

    def __enter__(self):
        # Barrier so that all processes enter the delay context together.
        self._barrier(f"MPI barrier failed on {self._label}delay context enter")
        _delayed_callback_manager.push_delay_context(self.context_info)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        delayed_callbacks = _delayed_callback_manager.pop_delay_context()

        # Barrier so that every process has finished its delayed operations
        # before any process starts executing callbacks.
        self._barrier(
            f"MPI barrier failed before {self._label}delayed callback execution"
        )

        for item in delayed_callbacks:
            try:
                item["callback"](item["array"], item["change_info"])
            except Exception as e:
                logger.warning(f"Delayed callback error: {e}")

        # Barrier so that every process completes its callbacks before any
        # process leaves the context.
        self._barrier(
            f"MPI barrier failed after {self._label}delayed callback execution"
        )

        # Never suppress an exception raised inside the ``with`` body.
        return False


class NDArray_With_Callback(np.ndarray):
    """
    A numpy ndarray subclass that triggers callbacks when array data is modified.

    **Callback Function Signature**::

        def callback(array: NDArray_With_Callback, change_info: dict) -> None:
            pass

    The ``change_info`` dictionary contains:

    - ``operation`` (str): Operation name ('setitem', 'iadd', 'fill', etc.)
    - ``indices`` (tuple/slice/None): Location of change (for setitem operations)
    - ``old_value`` (array-like/None): Previous values (when available)
    - ``new_value`` (array-like): New values being assigned
    - ``array_shape`` (tuple): Current shape of the array
    - ``array_dtype`` (np.dtype): Data type of the array
    - ``data_has_changed`` (bool): False for ``sync_data`` notifications

    **Features**:

    - **Multiple callbacks**: ``add_callback()``, ``remove_callback()``,
      ``clear_callbacks()``
    - **Enable/disable**: ``enable_callbacks()``, ``disable_callbacks()``
    - **Delayed execution**: ``delay_callback()``, ``delay_callbacks_global()``
    - **MPI synchronization**: barriers around delayed execution when
      ``underworld3.mpi`` is available
    - **Weak references**: the owner is stored as a ``weakref.ref``
    - **Global reductions**: ``global_max()``, ``global_min()``, ``global_sum()``,
      ``global_mean()``, ``global_size()``, ``global_norm()``, ``global_rms()``

    The global reductions use MPI collective operations (``allreduce``); all
    ranks must call them. Results of ordinary numpy operations on this class
    are returned as plain ``numpy.ndarray`` (see ``__array_wrap__``).
    """

    # dunder stem -> (wording used in the "disabled" error, operator symbol)
    _INPLACE_OPERATORS = {
        "iadd": ("addition", "+"),
        "isub": ("subtraction", "-"),
        "imul": ("multiplication", "*"),
        "itruediv": ("division", "/"),
        "ifloordiv": ("floor division", "//"),
        "imod": ("modulo", "%"),
        "ipow": ("power", "**"),
        "iand": ("bitwise and", "&"),
        "ior": ("bitwise or", "|"),
        "ixor": ("bitwise xor", "^"),
        "ilshift": ("left shift", "<<"),
        "irshift": ("right shift", ">>"),
    }

    def __new__(cls, input_array=None, owner=None, callback=None,
                disable_inplace_operators=False):
        """
        Create new NDArray_With_Callback instance.

        Parameters
        ----------
        input_array : array-like, optional
            Input data to create array from (defaults to empty array if None)
        owner : object, optional
            The object that owns this array (stored as weak reference)
        callback : callable, optional
            Initial callback function to register
        disable_inplace_operators : bool, optional
            If True, in-place operators (``+=``, ``-=``, ``*=``, ``/=``, etc.)
            raise RuntimeError. Default is False.
        """
        if input_array is None:
            input_array = []

        obj = np.asarray(input_array).view(cls)

        obj._callbacks = []
        obj._owner = weakref.ref(owner) if owner is not None else None
        obj._callback_enabled = True
        obj._disable_inplace_operators = disable_inplace_operators

        if callback is not None:
            obj._callbacks.append(callback)

        return obj

    def __array_finalize__(self, obj):
        """
        Initialise callback state whenever numpy creates an instance.

        State is inherited from ``obj`` (the template array) when it has it.
        When ``obj`` is None (explicit ``ndarray.__new__`` construction, which
        is also the path used while unpickling) defaults are installed so the
        instance is always usable.
        """
        self._callbacks = list(getattr(obj, "_callbacks", ()))
        self._owner = getattr(obj, "_owner", None)
        self._callback_enabled = getattr(obj, "_callback_enabled", True)
        self._disable_inplace_operators = getattr(
            obj, "_disable_inplace_operators", False
        )

    # === numpy.ma (masked array) compatibility ===
    # These attributes are needed when numpy's masked array operations
    # interact with our array subclass.

    @property
    def _mask(self):
        """For numpy.ma compatibility - we have no mask."""
        return np.ma.nomask

    @_mask.setter
    def _mask(self, value):
        """For numpy.ma compatibility - masking is unsupported, so ignore."""
        pass

    @property
    def mask(self):
        """Public alias of ``_mask``; always ``np.ma.nomask``."""
        return self._mask

    @mask.setter
    def mask(self, value):
        """Public mask setter for numpy.ma compatibility (no effect)."""
        self._mask = value

    def filled(self, fill_value=None):
        """
        Return a copy of the data as a plain numpy array.

        Provided for numpy.ma compatibility. ``fill_value`` is ignored because
        there is no mask.
        """
        return np.asarray(self).copy()

    def _update_from(self, obj):
        """For numpy.ma compatibility - copy ``obj``'s data into this array.

        Uses ``np.copyto``, so no callbacks are triggered.
        """
        if hasattr(obj, "__array__"):
            np.copyto(self, np.asarray(obj))
        elif obj is not None:
            np.copyto(self, obj)

    def __array_wrap__(self, result, context=None, return_scalar=False):
        """
        Called after numpy operations to wrap results.

        Parameters
        ----------
        result : ndarray
            Array produced by the numpy operation.
        context : optional
            Information about the ufunc that produced the result (unused).
        return_scalar : bool, optional
            NumPy 2 flag requesting a scalar instead of a 0-d array.

        Returns
        -------
        A Python scalar (``result.item()``) for 0-d results, ``self`` when the
        operation wrote into this array, otherwise a plain ``numpy.ndarray``.
        """
        if return_scalar or result.shape == ():
            return result.item()

        # In-place operation: numpy handed our own array back.
        if result is self:
            return self

        # Result is another view onto the same underlying buffer. The base
        # must be a real object: two arrays that each own their data both
        # have ``base is None`` and are unrelated.
        own_base = self.base
        if own_base is not None and getattr(result, "base", None) is own_base:
            return self

        # New array results are deliberately not wrapped, so callbacks do not
        # leak onto derived data.
        return np.asarray(result)

    def set_callback(self, callback: Callable):
        """
        Set a single callback function (replaces any existing callbacks).

        Parameters
        ----------
        callback : callable or None
            Function with signature ``callback(array, change_info)``. None
            leaves the array with no callbacks.
        """
        self._callbacks = [callback] if callback is not None else []

    def add_callback(self, callback: Callable):
        """Register ``callback`` unless it is None or already registered."""
        if callback is not None and callback not in self._callbacks:
            self._callbacks.append(callback)

    def remove_callback(self, callback: Callable):
        """Unregister ``callback`` if it is registered (no error otherwise)."""
        if callback in self._callbacks:
            self._callbacks.remove(callback)

    def clear_callbacks(self):
        """Remove all registered callbacks."""
        self._callbacks.clear()

    def enable_callbacks(self):
        """Enable callback triggering."""
        self._callback_enabled = True

    def disable_callbacks(self):
        """Disable callback triggering (useful for batch operations)."""
        self._callback_enabled = False

    @property
    def owner(self):
        """Get the owner object (None if unset or already garbage collected)."""
        return self._owner() if self._owner is not None else None

    def delay_callback(self, context_info=None):
        """
        Context manager to delay callback execution until context exit.

        While the context is open, callbacks from this array and from every
        other NDArray_With_Callback on the same thread are queued. When the
        context exits, the callbacks queued since it was entered are executed
        (a nested context runs its own callbacks when it closes).

        Parameters
        ----------
        context_info : str, optional
            Optional information about the context (for debugging)

        Example
        -------
        >>> with arr.delay_callback("batch update"):
        ...     arr[0] = 1
        ...     arr[1] = 2
        ...     arr[2] = 3
        # All callbacks fire here at context exit
        """
        return _DelayedCallbackContext(context_info)

    @staticmethod
    def delay_callbacks_global(context_info=None):
        """
        Create a delay context without needing a specific array instance.

        Behaves like ``delay_callback``; only the wording of MPI-barrier
        warning messages differs.

        Example
        -------
        >>> with NDArray_With_Callback.delay_callbacks_global("mesh update"):
        ...     mesh.data[0] = new_pos
        ...     swarm.data += displacement
        # All callbacks from all arrays fire here
        """
        return _DelayedCallbackContext(context_info, label="global ")

    def _trigger_callback(
        self, operation: str, indices=None, old_value=None, new_value=None,
        data_has_changed=True
    ):
        """
        Internal method to trigger all registered callbacks.

        Parameters
        ----------
        operation : str
            Name of the operation that triggered the callback
        indices : tuple or slice, optional
            Indices that were modified
        old_value : array-like, optional
            Previous value(s) at the modified location
        new_value : array-like, optional
            New value(s) at the modified location
        data_has_changed : bool, optional
            Whether this operation may have changed the array data
        """
        if not self._callback_enabled or not self._callbacks:
            return

        change_info = {
            "operation": operation,
            "indices": indices,
            "old_value": old_value,
            "new_value": new_value,
            "array_shape": self.shape,
            "array_dtype": self.dtype,
            "data_has_changed": data_has_changed,
        }

        if _delayed_callback_manager.is_delaying():
            # Inside a delay context: queue instead of running now.
            for callback in self._callbacks:
                _delayed_callback_manager.add_delayed_callback(
                    self, callback, change_info
                )
        else:
            # Iterate over a copy in case a callback edits the list.
            for callback in self._callbacks.copy():
                try:
                    callback(self, change_info)
                except Exception as e:
                    logger.warning(f"Callback error in {callback}: {e}")

    def _snapshot_for_callback(self):
        """Return a copy of the array if callbacks would fire, else None."""
        if self._callback_enabled and self._callbacks:
            return self.copy()
        return None

    def __setitem__(self, key, value):
        """Override setitem to trigger callbacks on assignment."""
        old_value = None
        if self._callback_enabled and self._callbacks:
            try:
                current = self[key]
                old_value = current.copy() if hasattr(current, "copy") else current
            except (IndexError, ValueError):
                old_value = None

        # Values exposing ``.magnitude`` (e.g. UnitAwareArray) are assigned via
        # their raw numeric data; otherwise numpy raises "only length-1 arrays
        # can be converted to Python scalars".
        actual_value = value.magnitude if hasattr(value, "magnitude") else value

        super().__setitem__(key, actual_value)

        # The callback receives the value as passed in, not the magnitude.
        self._trigger_callback(
            "setitem", indices=key, old_value=old_value, new_value=value
        )

    def _apply_inplace(self, name, other):
        """
        Shared implementation of every in-place operator.

        ``name`` is the dunder stem (e.g. ``"iadd"``) and doubles as the
        ``operation`` reported to callbacks.

        Raises
        ------
        RuntimeError
            If the array was created with ``disable_inplace_operators=True``.
        """
        description, symbol = self._INPLACE_OPERATORS[name]
        if self._disable_inplace_operators:
            raise RuntimeError(
                f"In-place {description} ({symbol}=) is disabled for parallel safety. "
                f"Use explicit assignment instead: arr = arr {symbol} other"
            )

        old_value = self._snapshot_for_callback()
        result = getattr(super(), f"__{name}__")(other)
        self._trigger_callback(name, old_value=old_value, new_value=other)
        return result

    def __iadd__(self, other):
        """In-place addition with callback."""
        return self._apply_inplace("iadd", other)

    def __isub__(self, other):
        """In-place subtraction with callback."""
        return self._apply_inplace("isub", other)

    def __imul__(self, other):
        """In-place multiplication with callback."""
        return self._apply_inplace("imul", other)

    def __itruediv__(self, other):
        """In-place true division with callback."""
        return self._apply_inplace("itruediv", other)

    def __ifloordiv__(self, other):
        """In-place floor division with callback."""
        return self._apply_inplace("ifloordiv", other)

    def __imod__(self, other):
        """In-place modulo with callback."""
        return self._apply_inplace("imod", other)

    def __ipow__(self, other):
        """In-place power with callback."""
        return self._apply_inplace("ipow", other)

    def __iand__(self, other):
        """In-place bitwise and with callback."""
        return self._apply_inplace("iand", other)

    def __ior__(self, other):
        """In-place bitwise or with callback."""
        return self._apply_inplace("ior", other)

    def __ixor__(self, other):
        """In-place bitwise xor with callback."""
        return self._apply_inplace("ixor", other)

    def __ilshift__(self, other):
        """In-place left shift with callback."""
        return self._apply_inplace("ilshift", other)

    def __irshift__(self, other):
        """In-place right shift with callback."""
        return self._apply_inplace("irshift", other)

    def fill(self, value):
        """Fill array with scalar value, triggering callback."""
        old_value = self._snapshot_for_callback()
        super().fill(value)
        self._trigger_callback("fill", old_value=old_value, new_value=value)

    def sort(self, axis=-1, kind=None, order=None):
        """Sort array in-place, triggering callback."""
        old_value = self._snapshot_for_callback()
        super().sort(axis=axis, kind=kind, order=order)
        self._trigger_callback("sort", old_value=old_value)

    def resize(self, new_shape, refcheck=True):
        """Resize array in-place, triggering callback.

        The callback's ``new_value`` is the requested ``new_shape``.
        """
        old_value = self._snapshot_for_callback()
        super().resize(new_shape, refcheck=refcheck)
        self._trigger_callback("resize", old_value=old_value, new_value=new_shape)

    def copy(self, order="C"):
        """
        Return a copy of the array.

        The copy is an ``NDArray_With_Callback`` with its own data and its own
        callback list (initially holding the same callbacks).
        """
        result = super().copy(order=order).view(NDArray_With_Callback)
        result._callbacks = self._callbacks.copy()
        result._owner = self._owner
        result._callback_enabled = self._callback_enabled
        result._disable_inplace_operators = self._disable_inplace_operators
        return result

    def view(self, dtype=None, type=None):
        """
        Return a view of the array.

        Parameters
        ----------
        dtype : data-type or ndarray subclass, optional
            Data type of the view. As with ``numpy.ndarray.view``, an ndarray
            subclass passed here is treated as ``type``.
        type : ndarray subclass, optional
            Python type of the returned view.

        Returns
        -------
        A view of the requested ``type``; ``NDArray_With_Callback`` when no
        type is given. Views that are instances of this class share the
        callback list with the original (not a copy).
        """
        # numpy's positional form ``arr.view(SomeArrayClass)`` arrives in the
        # ``dtype`` slot. Honour it as a type request so ``view(np.ndarray)``
        # really returns a plain ndarray. issubclass raises TypeError for
        # non-class arguments such as "f8" or np.dtype instances.
        if type is None and dtype is not None:
            try:
                dtype_is_array_type = issubclass(dtype, np.ndarray)
            except TypeError:
                dtype_is_array_type = False
            if dtype_is_array_type:
                type, dtype = dtype, None

        # Call numpy's ndarray.view directly to avoid recursion.
        if type is None and dtype is None:
            result = np.ndarray.view(self, NDArray_With_Callback)
        elif type is None:
            reinterpreted = np.ndarray.view(self, dtype)
            result = np.ndarray.view(reinterpreted, NDArray_With_Callback)
        elif dtype is None:
            # Do not forward an explicit dtype=None to numpy.
            result = np.ndarray.view(self, type=type)
        else:
            result = np.ndarray.view(self, dtype, type)

        if isinstance(result, NDArray_With_Callback):
            result._callbacks = self._callbacks  # Share callbacks (not copy)
            result._owner = self._owner
            result._callback_enabled = self._callback_enabled
            result._disable_inplace_operators = self._disable_inplace_operators

        return result

    def sync_data(self, new_data):
        """
        Update array with new data, preserving callbacks and metadata.

        Parameters
        ----------
        new_data : array-like
            New data to sync into this array. Can be different size/shape.

        Returns
        -------
        result : NDArray_With_Callback
            ``self`` when shape and dtype match (data copied in place);
            otherwise a new object of the same class carrying the same owner,
            callbacks and settings.

        Notes
        -----
        A 'sync_data' callback is triggered on the returned object with
        ``data_has_changed=False``.

        Examples
        --------
        >>> arr = NDArray_With_Callback([1, 2, 3])
        >>> result = arr.sync_data([4, 5, 6])  # Same size: returns same object
        >>> assert result is arr
        >>> result = arr.sync_data([7, 8, 9, 10, 11])  # Different size: new object
        >>> assert result is not arr
        """
        new_array = np.asarray(new_data)
        old_data = self._snapshot_for_callback()

        if new_array.shape == self.shape and new_array.dtype == self.dtype:
            # np.copyto bypasses __setitem__, so only 'sync_data' is reported.
            np.copyto(self, new_array)
            target = self
        else:
            # Different size/dtype: build a new object with the same metadata.
            target = type(self)(
                new_array,
                owner=self._owner() if self._owner is not None else None,
                disable_inplace_operators=self._disable_inplace_operators,
            )
            target._callbacks = self._callbacks.copy()
            target._callback_enabled = self._callback_enabled

        target._trigger_callback(
            "sync_data",
            old_value=old_data,
            new_value=new_array,
            indices=None,  # Full array update
            data_has_changed=False,  # Sync doesn't represent a user data change
        )
        return target

    def __reduce__(self):
        """
        Support for pickling.

        Appends the callback list, an owner placeholder, and the two flags to
        numpy's state tuple. The owner is a weak reference, which cannot be
        pickled, so None is stored in its slot and the restored array has no
        owner. Callbacks are pickled as-is and must themselves be picklable.
        """
        pickled_state = super().__reduce__()
        new_state = pickled_state[2] + (
            self._callbacks,
            None,  # owner slot: weakref.ref objects are not picklable
            self._callback_enabled,
            self._disable_inplace_operators,
        )
        return (pickled_state[0], pickled_state[1], new_state)

    def __setstate__(self, state):
        """Support for unpickling (inverse of ``__reduce__``)."""
        # The last four entries are ours; the rest belongs to numpy.
        parent_state = state[:-4]
        (
            self._callbacks,
            self._owner,
            self._callback_enabled,
            self._disable_inplace_operators,
        ) = state[-4:]
        super().__setstate__(parent_state)

    def __repr__(self):
        """String representation showing callback information."""
        base_repr = super().__repr__()
        callback_info = f", callbacks={len(self._callbacks)}"

        # Insert callback info before the closing parenthesis
        if base_repr.startswith("array(") and base_repr.endswith(")"):
            return base_repr[:-1] + callback_info + ")"
        return base_repr + callback_info

    # === GLOBAL REDUCTION OPERATIONS (MPI-aware) ===
    # These operations reduce across all MPI ranks.
    # Subclasses (like UnitAwareArray) can override to add unit preservation.

    @staticmethod
    def _mpi_context():
        """Return ``(MPI, comm)``: underworld3's communicator, else COMM_WORLD."""
        from mpi4py import MPI

        try:
            import underworld3 as _uw3

            comm = _uw3.mpi.comm
        except (ImportError, AttributeError):
            comm = MPI.COMM_WORLD
        return MPI, comm

    def _reduced_shape(self, axis, keepdims):
        """
        Shape numpy would give a reduction of this array over ``axis``.

        Used for ranks holding no data, which must contribute an identity
        array shaped like the results of ranks that do hold data.
        """
        if axis is None:
            axes = tuple(range(self.ndim))
        elif isinstance(axis, (tuple, list)):
            axes = tuple(axis)
        else:
            axes = (axis,)

        normalized = set()
        for ax in axes:
            if not -self.ndim <= ax < self.ndim:
                raise IndexError(
                    f"axis {ax} is out of bounds for array of dimension {self.ndim}"
                )
            normalized.add(ax % self.ndim)

        if keepdims:
            return tuple(1 if i in normalized else s for i, s in enumerate(self.shape))
        return tuple(s for i, s in enumerate(self.shape) if i not in normalized)

    @staticmethod
    def _allreduce(comm, local, op, scalar_result):
        """
        Combine ``local`` across ranks with the MPI operation ``op``.

        Scalars and 1-D results are reduced through ``comm.allreduce`` on
        Python floats (so 1-D results come back as float64); other shapes use
        the buffer-based ``comm.Allreduce``.
        """
        if scalar_result:
            return comm.allreduce(float(local), op=op)

        local_arr = np.asarray(local)
        if local_arr.ndim == 1:
            return np.array([comm.allreduce(float(v), op=op) for v in local_arr])

        global_arr = np.empty_like(local_arr)
        comm.Allreduce(local_arr, global_arr, op=op)
        return global_arr

    def _global_extremum(self, find_max, axis, out, keepdims):
        """Shared implementation of ``global_max`` and ``global_min``."""
        MPI, comm = self._mpi_context()
        mpi_op = MPI.MAX if find_max else MPI.MIN
        scalar_result = axis is None and not keepdims

        if self.size == 0:
            # Empty rank: contribute the identity (-inf for max, +inf for min).
            identity = -np.inf if find_max else np.inf
            if scalar_result:
                local = identity
            else:
                local = np.full(self._reduced_shape(axis, keepdims), identity)
        else:
            data = np.asarray(self)
            reducer = data.max if find_max else data.min
            local = reducer(axis=axis, out=out, keepdims=keepdims)

        return self._allreduce(comm, local, mpi_op, scalar_result)

    def global_max(self, axis=None, out=None, keepdims=False):
        """
        Return maximum across all MPI ranks.

        Parameters
        ----------
        axis : None or int or tuple of ints, optional
            Axis along which to operate (default: None = reduce all dimensions)
        out : ndarray, optional
            Alternative output array for the local reduction
        keepdims : bool, optional
            Keep reduced dimensions as size 1 (default: False)

        Returns
        -------
        scalar or ndarray
            Global maximum value(s); component-wise for array results.
        """
        return self._global_extremum(True, axis, out, keepdims)

    def global_min(self, axis=None, out=None, keepdims=False):
        """
        Return minimum across all MPI ranks.

        Parameters
        ----------
        axis : None or int or tuple of ints, optional
            Axis along which to operate (default: None = reduce all dimensions)
        out : ndarray, optional
            Alternative output array for the local reduction
        keepdims : bool, optional
            Keep reduced dimensions as size 1 (default: False)

        Returns
        -------
        scalar or ndarray
            Global minimum value(s); component-wise for array results.
        """
        return self._global_extremum(False, axis, out, keepdims)

    def global_sum(self, axis=None, dtype=None, out=None, keepdims=False):
        """
        Return sum across all MPI ranks.

        Parameters
        ----------
        axis : None or int or tuple of ints, optional
            Axis along which to operate (default: None = reduce all dimensions)
        dtype : data-type, optional
            Type used for the local sum
        out : ndarray, optional
            Alternative output array for the local sum
        keepdims : bool, optional
            Keep reduced dimensions as size 1 (default: False)

        Returns
        -------
        scalar or ndarray
            Global sum value(s); component-wise for array results.
        """
        MPI, comm = self._mpi_context()
        local_sum = np.asarray(self).sum(
            axis=axis, dtype=dtype, out=out, keepdims=keepdims
        )
        scalar_result = axis is None and not keepdims
        return self._allreduce(comm, local_sum, MPI.SUM, scalar_result)

    def global_mean(self, axis=None, dtype=None, out=None, keepdims=False):
        """
        Return mean across all MPI ranks (global sum / global count).

        Parameters
        ----------
        axis : None or int or tuple of ints, optional
            Axis along which to operate (default: None = reduce all dimensions)
        dtype : data-type, optional
            Type used for the local sum
        out : ndarray, optional
            Accepted for signature compatibility; not used.
        keepdims : bool, optional
            Keep reduced dimensions as size 1 (default: False)

        Returns
        -------
        scalar or ndarray
            Global mean value(s)
        """
        MPI, comm = self._mpi_context()

        # Number of local elements folded into each output value.
        if axis is None:
            local_count = self.size
        elif isinstance(axis, int):
            local_count = self.shape[axis]
        else:
            local_count = np.prod([self.shape[ax] for ax in axis])

        global_sum = self.global_sum(axis=axis, dtype=dtype, keepdims=keepdims)
        global_count = comm.allreduce(local_count, op=MPI.SUM)

        if axis is None and not keepdims:
            return float(global_sum) / global_count
        return np.asarray(global_sum) / global_count

    def global_size(self):
        """
        Return total number of elements across all MPI ranks.

        Returns
        -------
        int
            Sum of ``self.size`` over all ranks
        """
        MPI, comm = self._mpi_context()
        return comm.allreduce(self.size, op=MPI.SUM)

    def global_norm(self, ord=None):
        """
        Return 2-norm across all MPI ranks: sqrt(sum of squares).

        Parameters
        ----------
        ord : {None, 2}, optional
            Order of the norm (only the 2-norm is supported)

        Returns
        -------
        float
            Global 2-norm value

        Raises
        ------
        NotImplementedError
            If ``ord`` is neither None nor 2.
        """
        MPI, comm = self._mpi_context()

        if ord is not None and ord != 2:
            raise NotImplementedError(
                f"global_norm() only supports ord=None or ord=2 (2-norm), got ord={ord}"
            )

        local_sq_sum = np.sum(np.asarray(self) ** 2)
        global_sq_sum = comm.allreduce(float(local_sq_sum), op=MPI.SUM)
        return np.sqrt(global_sq_sum)

    def global_rms(self):
        """
        Return root mean square across all MPI ranks.

        Computed as ``global_norm() / sqrt(global_size())``.

        Returns
        -------
        float
            Global RMS value
        """
        norm = self.global_norm()
        size = self.global_size()
        return norm / np.sqrt(size)