Add a code generation pipeline that compiles Expr DAGs into flat Python
functions, eliminating recursive tree-walk dispatch in the Newton-Raphson
inner loop.

Key changes:
- Add to_code() method to all 11 Expr node types (expr.py)
- New codegen.py module with CSE (common subexpression elimination),
  sparsity detection, and a compile()/exec() compilation pipeline
- Add ParamTable.env_ref() to avoid dict copies per iteration (params.py)
- Newton and BFGS solvers accept pre-built jac_exprs and compiled_eval to
  avoid redundant diff/simplify and enable compiled evaluation
- count_dof() and diagnostics accept pre-built jac_exprs
- solver.py builds the symbolic Jacobian once, compiles once, and passes
  the result to all consumers (_monolithic_solve, count_dof, diagnostics)
- Automatic fallback: if codegen fails, tree-walk eval is used

Expected performance impact:
- ~10-20x faster Jacobian evaluation (no recursive dispatch)
- ~2-5x additional speedup from CSE on quaternion-heavy systems
- ~3x fewer entries evaluated via sparsity detection
- Eliminates redundant diff().simplify() in DOF/diagnostics
309 lines
9.5 KiB
Python
309 lines
9.5 KiB
Python
"""Compile Expr DAGs into flat Python functions for fast evaluation.
|
|
|
|
The compilation pipeline:

1. Collect all Expr nodes to be evaluated (residuals + Jacobian entries).
2. Identify common subexpressions (CSE) by ``id()`` — the Expr DAG
   already shares node objects via ParamTable's Var instances.
3. Generate a single Python function body that computes CSE temps,
   then fills ``r_vec`` and ``J`` arrays in-place.
4. Compile with ``compile()`` + ``exec()`` and return the callable.

The generated function signature is::

    fn(env: dict[str, float], r_vec: ndarray, J: ndarray) -> None
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
import math
|
|
from collections import Counter
|
|
from typing import Callable, List
|
|
|
|
import numpy as np
|
|
|
|
from .expr import Const, Expr, Var
|
|
|
|
# Module-level logger used for codegen statistics and fallback reports.
log = logging.getLogger(__name__)

# Globals made available to generated code.  Each math function is exposed
# under its own name with a leading underscore: ``_sin``, ``_cos``, ``_sqrt``.
_CODEGEN_NS = {
    f"_{func.__name__}": func for func in (math.sin, math.cos, math.sqrt)
}
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# CSE (Common Subexpression Elimination)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def _collect_nodes(expr: Expr, counts: Counter, visited: set[int]) -> None:
    """Tally references to every node reachable from *expr*, keyed by ``id()``.

    Every arrival at a node increments ``counts[id(node)]``.  A node's
    children are descended into only on the first arrival, so a count
    above one means the same node object is referenced from more than
    one place.

    Args:
        expr: Root of the (sub)expression to scan.
        counts: Reference tally, updated in place.
        visited: Ids of nodes whose children were already scanned;
            updated in place.
    """
    key = id(expr)
    counts[key] += 1
    if key in visited:
        return
    visited.add(key)

    # Leaves have nothing below them.
    if isinstance(expr, (Const, Var)):
        return

    # Child layout is detected by attribute name; the first matching
    # probe decides which fields are followed.
    layouts = (
        ("child", ("child",)),
        ("a", ("a", "b")),
        ("base", ("base", "exp")),
    )
    for probe, fields in layouts:
        if hasattr(expr, probe):
            for field in fields:
                _collect_nodes(getattr(expr, field), counts, visited)
            break
|
|
|
|
|
|
def _build_cse(
    exprs: list[Expr],
) -> tuple[dict[int, str], list[tuple[str, Expr]]]:
    """Find node objects referenced more than once and name a temp for each.

    Sharing is detected by object identity (``id()``), not by structural
    equality.  ``Const`` and ``Var`` nodes never become temps.

    Args:
        exprs: Root expressions that will be emitted together.

    Returns:
        id_to_temp: mapping from ``id(node)`` to temp variable name
        temps_ordered: ``(temp_name, expr)`` pairs in dependency order,
            i.e. a temp appears after every temp nested inside it.
        Both are empty when nothing is shared.
    """
    # Pass 1: reference counts per node id.
    ref_counts: Counter = Counter()
    seen: set[int] = set()
    nodes_by_id: dict[int, Expr] = {}
    for root in exprs:
        _collect_nodes(root, ref_counts, seen)

    # Pass 2: recover the node object behind each id.
    for root in exprs:
        _map_ids(root, nodes_by_id)

    # A node is worth a temp when it is referenced at least twice and is
    # not a leaf (leaves are already as cheap as a temp lookup).
    shared_ids = set()
    for node_id, refs in ref_counts.items():
        if refs <= 1:
            continue
        node = nodes_by_id.get(node_id)
        if node is None or isinstance(node, (Const, Var)):
            continue
        shared_ids.add(node_id)

    if not shared_ids:
        return {}, []

    # Post-order walk: children are handled before their parent, so the
    # resulting list is a valid evaluation order for the temps.
    emit_order: list[int] = []
    done: set[int] = set()
    layouts = (
        ("child", ("child",)),
        ("a", ("a", "b")),
        ("base", ("base", "exp")),
    )

    def _visit(node: Expr) -> None:
        node_id = id(node)
        if node_id in done:
            return
        done.add(node_id)
        if isinstance(node, (Const, Var)):
            return
        for probe, fields in layouts:
            if hasattr(node, probe):
                for field in fields:
                    _visit(getattr(node, field))
                break
        if node_id in shared_ids:
            emit_order.append(node_id)

    for root in exprs:
        _visit(root)

    # Temps are numbered in evaluation order: _c0, _c1, ...
    id_to_temp: dict[int, str] = {
        node_id: f"_c{index}" for index, node_id in enumerate(emit_order)
    }
    temps_ordered: list[tuple[str, Expr]] = [
        (id_to_temp[node_id], nodes_by_id[node_id]) for node_id in emit_order
    ]
    return id_to_temp, temps_ordered
|
|
|
|
|
|
def _map_ids(expr: Expr, mapping: dict[int, Expr]) -> None:
    """Record ``id(node) -> node`` for *expr* and everything below it.

    Nodes already present in *mapping* are skipped together with their
    subtrees, so shared sub-DAGs are walked only once.

    Args:
        expr: Root of the (sub)expression to index.
        mapping: Id-to-node table, updated in place.
    """
    key = id(expr)
    if key in mapping:
        return
    mapping[key] = expr

    # Leaves have nothing below them.
    if isinstance(expr, (Const, Var)):
        return

    # Pick the child fields from the first attribute probe that matches.
    if hasattr(expr, "child"):
        fields = ("child",)
    elif hasattr(expr, "a"):
        fields = ("a", "b")
    elif hasattr(expr, "base"):
        fields = ("base", "exp")
    else:
        fields = ()

    for field in fields:
        _map_ids(getattr(expr, field), mapping)
|
|
|
|
|
|
def _expr_to_code(expr: Expr, id_to_temp: dict[int, str]) -> str:
    """Return the temp name registered for *expr*, else ``expr.to_code()``.

    Only the root node is checked against *id_to_temp*; children are
    rendered by ``to_code()`` without temp substitution.

    Args:
        expr: Node to render.
        id_to_temp: Mapping from ``id(node)`` to temp variable name.

    Returns:
        Python source text for *expr*.
    """
    name = id_to_temp.get(id(expr))
    return expr.to_code() if name is None else name
|
|
|
|
|
|
def _expr_to_code_recursive(expr: Expr, id_to_temp: dict[int, str]) -> str:
    """Render *expr* as Python source, replacing shared nodes by temp names.

    Unlike a plain ``to_code()`` call, the substitution is applied at
    every level of the tree, not just at the root.

    Args:
        expr: Node to render.
        id_to_temp: Mapping from ``id(node)`` to temp variable name.

    Returns:
        Python source text for *expr*.
    """
    name = id_to_temp.get(id(expr))
    if name is not None:
        return name

    # Leaves render themselves.
    if isinstance(expr, (Const, Var)):
        return expr.to_code()

    # Imported here rather than at module level, as in the original code.
    from .expr import Add, Cos, Div, Mul, Neg, Pow, Sin, Sqrt, Sub

    # Unary nodes: text emitted before the operand; ")" always closes it.
    unary_openers = (
        (Neg, "(-"),
        (Sin, "_sin("),
        (Cos, "_cos("),
        (Sqrt, "_sqrt("),
    )
    for node_type, opener in unary_openers:
        if isinstance(expr, node_type):
            operand = _expr_to_code_recursive(expr.child, id_to_temp)
            return f"{opener}{operand})"

    # Binary nodes: infix operator between the two rendered operands.
    binary_operators = (
        (Add, "+"),
        (Sub, "-"),
        (Mul, "*"),
        (Div, "/"),
    )
    for node_type, operator in binary_operators:
        if isinstance(expr, node_type):
            lhs = _expr_to_code_recursive(expr.a, id_to_temp)
            rhs = _expr_to_code_recursive(expr.b, id_to_temp)
            return f"({lhs} {operator} {rhs})"

    if isinstance(expr, Pow):
        lhs = _expr_to_code_recursive(expr.base, id_to_temp)
        rhs = _expr_to_code_recursive(expr.exp, id_to_temp)
        return f"({lhs} ** {rhs})"

    # Unknown node type: let the node render itself (no temp substitution
    # below this point).
    return expr.to_code()
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Sparsity detection
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def _find_nonzero_entries(
    jac_exprs: list[list[Expr]],
) -> list[tuple[int, int]]:
    """List the positions of Jacobian entries that are not a literal zero.

    An entry counts as zero only when it is a ``Const`` whose ``value``
    equals ``0.0``; every other expression is kept.

    Args:
        jac_exprs: Jacobian expressions indexed ``[row][col]``.

    Returns:
        ``(row, col)`` pairs in row-major order.
    """
    return [
        (row_idx, col_idx)
        for row_idx, row in enumerate(jac_exprs)
        for col_idx, entry in enumerate(row)
        if not (isinstance(entry, Const) and entry.value == 0.0)
    ]
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Code generation
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def compile_system(
    residuals: list[Expr],
    jac_exprs: list[list[Expr]],
    n_res: int,
    n_free: int,
) -> Callable[[dict, np.ndarray, np.ndarray], None]:
    """Compile residuals + Jacobian into a single evaluation function.

    The generated function first assigns the CSE temporaries, then writes
    each residual into ``r_vec[i]`` and each non-zero Jacobian entry into
    ``J[i, j]``.

    Args:
        residuals: Residual expressions; entry ``i`` is written to
            ``r_vec[i]``.
        jac_exprs: Jacobian expressions indexed ``[row][col]``.  Entries
            that are a ``Const`` equal to ``0.0`` are not emitted.
        n_res: Number of residuals.  Used only in the debug log message.
        n_free: Number of free parameters.  Used only in the debug log
            message.

    Returns:
        A callable ``fn(env, r_vec, J)`` that fills *r_vec* and *J*
        in-place.  *J* must be pre-zeroed by the caller (only non-zero
        entries are written).

    Raises:
        Exception: Anything raised while rendering the expressions or by
            ``compile()``/``exec()`` on the generated source propagates
            to the caller.
    """
    # Sparsity: only structurally non-zero Jacobian entries are emitted.
    nz_entries = _find_nonzero_entries(jac_exprs)
    nz_jac_exprs: list[Expr] = [jac_exprs[i][j] for i, j in nz_entries]

    # CSE runs over residuals and Jacobian together so that sub-trees
    # shared between the two are computed once.
    all_exprs: list[Expr] = [*residuals, *nz_jac_exprs]
    id_to_temp, temps_ordered = _build_cse(all_exprs)

    body: list[str] = []

    # Temporaries — temporarily remove each temp's own id so its RHS
    # is expanded rather than self-referencing.
    for temp_name, temp_expr in temps_ordered:
        eid = id(temp_expr)
        saved = id_to_temp.pop(eid)
        code = _expr_to_code_recursive(temp_expr, id_to_temp)
        id_to_temp[eid] = saved
        body.append(f"    {temp_name} = {code}")

    # Residuals
    for i, r in enumerate(residuals):
        code = _expr_to_code_recursive(r, id_to_temp)
        body.append(f"    r_vec[{i}] = {code}")

    # Jacobian (sparse)
    for (i, j), entry in zip(nz_entries, nz_jac_exprs):
        code = _expr_to_code_recursive(entry, id_to_temp)
        body.append(f"    J[{i}, {j}] = {code}")

    # A ``def`` with no statements is a SyntaxError; an empty system
    # (no residuals, no non-zero Jacobian entries) compiles to a no-op.
    if not body:
        body.append("    pass")

    source = "\n".join(["def _eval(env, r_vec, J):", *body])

    # Compile into a private copy of the helper namespace so the module
    # level table is never mutated by exec().
    code_obj = compile(source, "<kindred_codegen>", "exec")
    ns = dict(_CODEGEN_NS)
    exec(code_obj, ns)

    fn = ns["_eval"]

    log.debug(
        "codegen: compiled %d residuals + %d/%d Jacobian entries, %d CSE temps",
        n_res,
        len(nz_entries),
        n_res * n_free,
        len(temps_ordered),
    )

    return fn
|
|
|
|
|
|
def try_compile_system(
    residuals: list[Expr],
    jac_exprs: list[list[Expr]],
    n_res: int,
    n_free: int,
) -> Callable[[dict, np.ndarray, np.ndarray], None] | None:
    """Best-effort wrapper around ``compile_system``.

    Arguments are forwarded unchanged.  Any ``Exception`` raised during
    compilation is logged at DEBUG level with its traceback and then
    suppressed.

    Returns:
        The compiled callable, or ``None`` when compilation failed.
    """
    try:
        compiled = compile_system(residuals, jac_exprs, n_res, n_free)
    except Exception:
        # Deliberately broad: a codegen failure is not an error for the
        # caller, which is expected to handle the ``None`` result.
        log.debug(
            "codegen: compilation failed, falling back to tree-walk eval",
            exc_info=True,
        )
        return None
    return compiled
|