"""WorkflowRunner — sequence workflow steps using produces/requires metadata.
The runner uses ``@workflow_step`` ``produces`` / ``requires`` lists to
resolve a target product by either (a) returning a cached object built
earlier in this session, (b) loading a persisted product from disk via
``WorkflowProducts``, or (c) running the step that produces it. When
running a step, each of its ``requires`` names is resolved recursively
through the same chain.
Live (non-persistable) objects such as solvers live in the in-memory
cache only. Persistable objects (Mesh, MeshVariable snapshots, Surface,
ndarray) are also saved to disk via ``WorkflowProducts`` and reloaded
across sessions. The runner does not distinguish — ``WorkflowProducts``
decides per type, and save attempts that fail are silently skipped.
Step argument resolution: each parameter of a step function is matched
by *name* to either ``config`` (always passed) or to a name in the
step's ``requires`` list. Steps that produce a single product return
the bare object; steps that produce multiple products return a dict
keyed by product name (preferred) or a tuple aligned with ``produces``.
"""
import inspect
import logging
from typing import Optional
def _collect_steps(module):
    """Walk ``module`` for ``@workflow_step`` functions.

    A callable counts as a step when its ``_is_workflow_step`` attribute
    is truthy. The same callable reachable under several module
    attribute names (an alias or re-export) is registered once only.

    Parameters
    ----------
    module : module
        Object whose callable members are inspected.

    Returns
    -------
    tuple
        ``(producers, steps)`` where *producers* maps product name →
        producing step and *steps* is the list of all decorated
        callables, in the (name-sorted) order ``inspect.getmembers``
        first yields them.

    Raises
    ------
    ValueError
        If two *different* steps declare the same product name.
    """
    producers = {}
    steps = []
    # ``getmembers`` yields one entry per attribute name, so an aliased
    # step would otherwise be listed twice and collide with itself.
    seen_ids = set()
    for _name, obj in inspect.getmembers(module, callable):
        if not getattr(obj, "_is_workflow_step", False):
            continue
        if id(obj) in seen_ids:
            continue
        seen_ids.add(id(obj))
        steps.append(obj)
        for prod in getattr(obj, "workflow_produces", None) or []:
            if prod in producers:
                raise ValueError(
                    f"Two workflow steps produce {prod!r}: "
                    f"{producers[prod].__name__} and {obj.__name__}"
                )
            producers[prod] = obj
    return producers, steps
class WorkflowRunner:
    """Resolve and run workflow steps in dependency order.

    Parameters
    ----------
    module : module
        A workflow module (e.g. ``import convection_config as convection``).
    config : WorkflowConfig
        Configuration object passed to every step whose signature
        includes a ``config`` parameter.
    products : WorkflowProducts, optional
        Persistence layer. If omitted, products live in memory only.

    Examples
    --------
    >>> runner = WorkflowRunner(convection, config, products)
    >>> runner.build("evolution_log")           # resolves all dependencies
    >>> runner.dag()                            # show steps + status
    >>> runner.rebuild("temperature_initial")   # invalidate + rebuild
    """

    # Best-effort load/save failures are reported here at DEBUG level
    # instead of vanishing completely.
    _log = logging.getLogger(__name__)

    def __init__(self, module, config, products=None):
        self.module = module
        self.config = config
        self.products = products
        self.cache = {}
        # Per-product cache keys recorded during this session. Used to
        # propagate upstream invalidation when computing the expected
        # cache key for a downstream product.
        self._product_cache_keys: dict[str, str] = {}
        # Observers — callbacks registered via .observe() that fire on
        # cache hit / disk load / build events. Used by UI layers
        # (widgets, dashboards) that want to track DAG progress.
        self._observers: list = []
        self._producers, self._steps = _collect_steps(module)

    # ------------------------------------------------------------------
    # Build
    # ------------------------------------------------------------------

    def get(self, name: str):
        """Return product *name*, building (and caching) if needed.

        Resolution order: in-memory cache, then the persistence layer,
        then the producing step. A product found on disk is only used
        when its recorded cache key matches the key expected for the
        current config; a mismatch triggers a rebuild. When no expected
        key can be computed (``_expected_cache_key`` returns ``None``)
        mere existence on disk is enough (legacy behaviour).

        Raises
        ------
        KeyError
            If the product has to be built but no step produces it.
        """
        if name in self.cache:
            self._emit("cached", name)
            return self.cache[name]

        if self.products is not None and self.products.exists(name):
            expected = self._expected_cache_key(name)
            # ``None`` → untrackable product: existence is enough.
            is_fresh = expected is None or self.products.fresh(name, expected)
            if is_fresh:
                try:
                    obj = self.products.load(name)
                except Exception:
                    # Loading failed (e.g. needs a mesh argument we
                    # don't have). Fall through and rebuild from the
                    # producing step.
                    self._log.debug(
                        "Could not load product %r from disk; rebuilding",
                        name,
                        exc_info=True,
                    )
                else:
                    self.cache[name] = obj
                    if expected is not None:
                        self._product_cache_keys[name] = expected
                    self._emit("loaded", name)
                    return obj

        return self._run_step_for(name)

    def _expected_cache_key(self, name: str) -> Optional[str]:
        """Cache key the product *name* should have given the current
        config + upstream products.

        Returns ``None`` when the key cannot be computed — the config
        has no ``cache_key`` method, no step produces *name*, or any
        upstream product is itself untrackable. Callers should then
        fall back to existence-based caching.
        """
        # Checked first so an unhashable config never pays for walking
        # the whole upstream chain.
        if not hasattr(self.config, "cache_key"):
            return None
        step = self._producers.get(name)
        if step is None:
            # No producer registered. Can happen for "external"
            # products that were saved manually outside the runner's
            # DAG; treat as untrackable.
            return None
        upstream = {}
        for req in getattr(step, "workflow_requires", None) or []:
            key = self._expected_cache_key(req)
            if key is None:
                # Incomplete chain — this upstream can't be hashed.
                # Fall back to no tracking for this product.
                return None
            upstream[req] = key
        return self.config.cache_key(requires=upstream)

    # Make ``build`` a synonym of ``get`` — handy in notebooks.
    build = get

    def build_all(self):
        """Build every leaf product (one nothing else requires).

        Returns the sorted list of leaf product names that were built.
        """
        all_produces = set()
        all_requires = set()
        for step in self._steps:
            all_produces.update(getattr(step, "workflow_produces", None) or [])
            all_requires.update(getattr(step, "workflow_requires", None) or [])
        leaves = sorted(all_produces - all_requires)
        for name in leaves:
            self.get(name)
        return leaves

    def _run_step_for(self, name: str):
        """Run the step producing *name*; return the freshly cached product.

        Emits ``"building"`` before the step's inputs are resolved and
        ``"built"`` once its outputs are recorded.

        Raises
        ------
        KeyError
            If no step produces *name*, or the producing step's result
            did not include *name*.
        """
        if name not in self._producers:
            available = sorted(self._producers.keys())
            raise KeyError(
                f"No workflow step produces {name!r}. "
                f"Available products: {available}"
            )
        step = self._producers[name]
        self._emit("building", name)
        kwargs = self._resolve_kwargs(step)
        result = step(**kwargs)
        self._record_outputs(step, result)
        if name not in self.cache:
            # A multi-product step returned a dict lacking this key.
            # Fail with context and without announcing "built".
            raise KeyError(
                f"Step {step.__name__} did not return product {name!r}; "
                f"declared produces="
                f"{list(getattr(step, 'workflow_produces', None) or [])}"
            )
        self._emit("built", name)
        return self.cache[name]

    def _resolve_kwargs(self, step):
        """Build the keyword arguments for calling *step*.

        Each parameter is matched by name: ``config`` receives the
        runner's config, a name listed in the step's ``requires`` is
        resolved through :meth:`get`, and any other name already in
        the in-memory cache is passed as a convenience. Everything else
        is left unbound so the function's own default applies — or
        Python raises its usual ``TypeError`` for a missing argument.
        """
        requires = set(getattr(step, "workflow_requires", None) or [])
        kwargs = {}
        for pname in inspect.signature(step).parameters:
            if pname == "config":
                kwargs["config"] = self.config
            elif pname in requires:
                kwargs[pname] = self.get(pname)
            elif pname in self.cache:
                # Convenience: fall back to cache if the function names
                # a parameter that isn't in `requires` but is available.
                kwargs[pname] = self.cache[pname]
        return kwargs

    def _record_outputs(self, step, result):
        """Cache (and try to persist) what *step* returned.

        A step with one declared product stores *result* as-is. With
        several declared products, *result* must be a dict keyed by
        product name or a list/tuple aligned with ``produces``.

        Raises
        ------
        ValueError
            If a list/tuple result has the wrong length.
        TypeError
            If a multi-product step returns neither dict nor list/tuple.
        """
        produces = getattr(step, "workflow_produces", None) or []
        if not produces:
            return
        if len(produces) == 1:
            (name,) = produces
            self.cache[name] = result
            self._save_quietly(name, result)
            return

        if isinstance(result, dict):
            items = result
        elif isinstance(result, (list, tuple)):
            if len(result) != len(produces):
                raise ValueError(
                    f"Step {step.__name__} declares {len(produces)} "
                    f"produces but returned {len(result)} values"
                )
            items = dict(zip(produces, result))
        else:
            raise TypeError(
                f"Step {step.__name__} declares produces={produces} "
                f"but returned a {type(result).__name__}; "
                "expected dict or tuple aligned with produces"
            )
        for key, value in items.items():
            self.cache[key] = value
            self._save_quietly(key, value)

    def _save_quietly(self, name, obj):
        """Persist *obj* as product *name*; never raise.

        No-op without a persistence layer. Any failure is logged at
        DEBUG level and otherwise ignored, because non-persistable
        products (e.g. live solvers) are expected to fail here.
        """
        if self.products is None:
            return
        try:
            expected = self._expected_cache_key(name)
            if expected is None:
                # Legacy / no identity fields declared — skip cache_key.
                self.products.save(name, obj)
                return
            # Build the inputs block for the manifest entry, so a
            # human reader can audit what changed when a key shifts.
            # Use _expected_cache_key (not _product_cache_keys) so
            # non-persistable upstream products (live solvers) still
            # contribute their computed digest to the audit.
            step = self._producers.get(name)
            requires = (
                (getattr(step, "workflow_requires", None) or []) if step else []
            )
            inputs = {
                "config": self._identity_snapshot(),
                "requires": {
                    req: self._expected_cache_key(req) for req in requires
                },
            }
            self.products.save(name, obj, cache_key=expected, inputs=inputs)
            self._product_cache_keys[name] = expected
        except Exception:
            # Not persistable (e.g. live solver) — that's fine.
            self._log.debug(
                "Product %r was not persisted", name, exc_info=True
            )

    def _identity_snapshot(self) -> dict:
        """Snapshot of the config's identity fields, for manifest audit.

        Returns an empty dict when the config has no
        ``_identity_fields`` attribute (or it is ``None``).
        """
        ids = getattr(self.config, "_identity_fields", None)
        if ids is None:
            return {}
        return {f: getattr(self.config, f) for f in ids}

    # ------------------------------------------------------------------
    # Invalidation
    # ------------------------------------------------------------------

    def invalidate(self, name: str):
        """Drop *name* from cache and remove its on-disk product if any.

        The cache key recorded for *name* during this session is
        forgotten as well, so no stale key outlives the product.

        Does **not** invalidate downstream products. Caller is
        responsible for rebuilding anything that depended on this.
        """
        self.cache.pop(name, None)
        self._product_cache_keys.pop(name, None)
        if self.products is not None and self.products.exists(name):
            self.products.remove(name)

    def rebuild(self, name: str):
        """Invalidate and rebuild a single product."""
        self.invalidate(name)
        return self.get(name)

    # ------------------------------------------------------------------
    # Inspection
    # ------------------------------------------------------------------

    def status(self, name: str) -> str:
        """Return one of ``'cached'``, ``'on_disk'``, ``'missing'``."""
        if name in self.cache:
            return "cached"
        if self.products is not None and self.products.exists(name):
            return "on_disk"
        return "missing"

    # ------------------------------------------------------------------
    # Observability hooks (for UI layers — widgets, dashboards)
    # ------------------------------------------------------------------

    def observe(self, callback) -> None:
        """Register a callback fired on product cache/load/build events.

        Each registered *callback* is invoked as
        ``callback(event, name)`` where *event* is one of:

        * ``"cached"`` — returned from in-memory cache
        * ``"loaded"`` — read from on-disk products
        * ``"building"`` — about to run the producing step
        * ``"built"`` — step finished, product cached and (if
          persistable) saved

        Callbacks that raise are silently ignored so a buggy observer
        can't break the runner.

        Used by interactive UI layers (Jupyter widgets, panel
        dashboards) that want to display DAG progress without polling.
        """
        self._observers.append(callback)

    def _emit(self, event: str, name: str) -> None:
        """Call every observer with ``(event, name)``, ignoring errors."""
        for cb in self._observers:
            try:
                cb(event, name)
            except Exception:
                # Deliberate: observers must never break a build.
                pass

    def diagram(self, *, rankdir: str = "LR") -> str:
        """Generate a Graphviz DOT diagram of this runner's DAG.

        Nodes are colour-coded by current status (cached / on-disk /
        missing) using :meth:`status` as the status provider.

        Returns the DOT source string; render with the ``dot`` command::

            Path("dag.dot").write_text(runner.diagram())
            # then in a shell: dot -Tpng dag.dot -o dag.png

        Or use :func:`underworld3.workflows.render` for a one-call
        wrapper that invokes ``dot`` directly.
        """
        from ._diagram import diagram as _diagram

        return _diagram(self.module, status_provider=self.status, rankdir=rankdir)

    def what_invalidates(self, name: str) -> set:
        """Set of products that would rebuild if *name* changed.

        Walks the produces/requires DAG forward from *name* and
        returns every product (transitively) that lists *name* in
        its requires chain. Useful for UI layers showing
        "if you change X, these will rebuild".

        ``name`` itself is *not* included in the returned set, even
        when the declared graph contains a cycle through it.
        """
        # Reverse adjacency built once: product -> products whose
        # producing step lists it in ``requires``.
        dependents = {}
        for prod, step in self._producers.items():
            for req in getattr(step, "workflow_requires", None) or []:
                dependents.setdefault(req, set()).add(prod)

        # Transitively walk forward.
        seen = set()
        frontier = list(dependents.get(name, ()))
        while frontier:
            current = frontier.pop()
            if current in seen:
                continue
            seen.add(current)
            frontier.extend(dependents.get(current, set()) - seen)
        # A cycle can lead back to the starting product.
        seen.discard(name)
        return seen

    def dag(self):
        """Display all steps with their produces / requires / status.

        Renders an HTML table through ``IPython.display`` when IPython
        is importable; otherwise prints a plain-text table.
        """
        rows = self._dag_rows()
        try:
            from IPython.display import HTML, display
        except ImportError:
            print(self._dag_text(rows))
            return
        display(HTML(self._dag_html(rows)))

    def _dag_rows(self):
        """Return ``(step, product, requires, status)`` string rows,
        one per product, with steps ordered by function name."""
        rows = []
        for step in sorted(self._steps, key=lambda f: f.__name__):
            produces = getattr(step, "workflow_produces", None) or []
            requires = getattr(step, "workflow_requires", None) or []
            req_str = ", ".join(requires) if requires else "—"
            if not produces:
                rows.append((step.__name__, "—", req_str, "(no product)"))
                continue
            for prod in produces:
                rows.append((step.__name__, prod, req_str, self.status(prod)))
        return rows

    def _dag_html(self, rows) -> str:
        """Render *rows* as an HTML table with status colour-coding."""
        th = (
            '<th style="text-align:left; padding:4px 12px 4px 0; '
            'border-bottom:2px solid #ccc;">'
        )
        td = '<td style="padding:2px 12px 2px 0;'
        colours = {
            "cached": "#060",
            "on_disk": "#063",
            "missing": "#888",
            "(no product)": "#888",
        }
        parts = [
            f"<h4>{self.module.__name__} — runner status</h4>"
            f'<table style="border-collapse:collapse;"><tr>'
            f"{th}Step</th>{th}Produces</th>{th}Requires</th>"
            f"{th}Status</th></tr>"
        ]
        for step_name, prod, reqs, status in rows:
            colour = colours.get(status, "#888")
            parts.append(
                "<tr>"
                f'{td} font-family:monospace;">{step_name}</td>'
                f'{td} font-family:monospace; color:#555;">{prod}</td>'
                f'{td} font-family:monospace; color:#555;">{reqs}</td>'
                f'{td} color:{colour}; font-weight:bold;">{status}</td>'
                "</tr>"
            )
        parts.append("</table>")
        return "".join(parts)

    def _dag_text(self, rows) -> str:
        """Render *rows* as an aligned plain-text table."""
        headers = ("Step", "Produces", "Requires")
        # Seed each width with its header so an empty DAG still renders
        # and short names never leave the header overhanging.
        name_w, prod_w, req_w = (
            max([len(header), *(len(row[col]) for row in rows)])
            for col, header in enumerate(headers)
        )
        lines = [
            f"{self.module.__name__} — runner status",
            f" {'Step':<{name_w}} {'Produces':<{prod_w}} "
            f"{'Requires':<{req_w}} Status",
            f" {'-' * name_w} {'-' * prod_w} " f"{'-' * req_w} ------",
        ]
        lines.extend(
            f" {step_name:<{name_w}} {prod:<{prod_w}} "
            f"{reqs:<{req_w}} {status}"
            for step_name, prod, reqs, status in rows
        )
        return "\n".join(lines)