Files
solver/kindred_solver/codegen.py
forbes-0023 64b1e24467 feat(solver): compile symbolic Jacobian to flat Python for fast evaluation
Add a code generation pipeline that compiles Expr DAGs into flat Python
functions, eliminating recursive tree-walk dispatch in the Newton-Raphson
inner loop.

Key changes:
- Add to_code() method to all 11 Expr node types (expr.py)
- New codegen.py module with CSE (common subexpression elimination),
  sparsity detection, and compile()/exec() compilation pipeline
- Add ParamTable.env_ref() to avoid dict copies per iteration (params.py)
- Newton and BFGS solvers accept pre-built jac_exprs and compiled_eval
  to avoid redundant diff/simplify and enable compiled evaluation
- count_dof() and diagnostics accept pre-built jac_exprs
- solver.py builds symbolic Jacobian once, compiles once, passes to all
  consumers (_monolithic_solve, count_dof, diagnostics)
- Automatic fallback: if codegen fails, tree-walk eval is used

Expected performance impact:
- ~10-20x faster Jacobian evaluation (no recursive dispatch)
- ~2-5x additional from CSE on quaternion-heavy systems
- ~3x fewer entries evaluated via sparsity detection
- Eliminates redundant diff().simplify() in DOF/diagnostics
2026-02-21 11:22:36 -06:00

309 lines
9.5 KiB
Python

"""Compile Expr DAGs into flat Python functions for fast evaluation.

The compilation pipeline:

1. Collect all Expr nodes to be evaluated (residuals + Jacobian entries).
2. Identify common subexpressions (CSE) by ``id()`` — the Expr DAG
   already shares node objects via ParamTable's Var instances.
3. Generate a single Python function body that computes CSE temps,
   then fills ``r_vec`` and ``J`` arrays in-place.
4. Compile with ``compile()`` + ``exec()`` and return the callable.

The generated function signature is::

    fn(env: dict[str, float], r_vec: ndarray, J: ndarray) -> None
"""
from __future__ import annotations
import logging
import math
from collections import Counter
from typing import Callable, List
import numpy as np
from .expr import Const, Expr, Var
log = logging.getLogger(__name__)

# Helper callables exposed to generated code.  ``compile_system`` copies
# this mapping and uses the copy as the globals of the ``exec()``'d
# source, so the emitted ``_sin(...)``, ``_cos(...)`` and ``_sqrt(...)``
# calls resolve to the ``math`` functions.
_CODEGEN_NS = dict(
    _sin=math.sin,
    _cos=math.cos,
    _sqrt=math.sqrt,
)
# ---------------------------------------------------------------------------
# CSE (Common Subexpression Elimination)
# ---------------------------------------------------------------------------
def _collect_nodes(expr: Expr, counts: Counter, visited: set[int]) -> None:
    """Tally how often each node under *expr* is reached, keyed by ``id()``.

    Every arrival at a node bumps its entry in *counts*.  The node's
    operands are only descended into on the first arrival (tracked in
    *visited*), so a shared sub-tree is walked once no matter how many
    parents point at it.  Both *counts* and *visited* are updated in place.
    """
    key = id(expr)
    counts[key] += 1
    if key in visited:
        return
    visited.add(key)

    if isinstance(expr, (Const, Var)):
        # Leaves have nothing to descend into.
        return

    # Operand attribute groups, probed in order.  Only the first group
    # whose leading attribute exists on the node is followed.
    for group in (("child",), ("a", "b"), ("base", "exp")):
        if hasattr(expr, group[0]):
            for attr in group:
                _collect_nodes(getattr(expr, attr), counts, visited)
            return
def _build_cse(
    exprs: list[Expr],
) -> tuple[dict[int, str], list[tuple[str, Expr]]]:
    """Pick out sub-expressions reached more than once and name them.

    A node becomes a temporary when it is reached more than once while
    walking *exprs* and is not a ``Const`` or ``Var``.

    Returns:
        id_to_temp: mapping from ``id(node)`` to its temp name (``_c0``,
            ``_c1``, ...).
        temps_ordered: ``(temp_name, expr)`` pairs in post-order, so every
            temp appears after the temps nested inside it.

    Both results are empty when nothing is shared.
    """
    ref_counts: Counter = Counter()
    seen: set[int] = set()
    nodes_by_id: dict[int, Expr] = {}
    for root in exprs:
        _collect_nodes(root, ref_counts, seen)
        _map_ids(root, nodes_by_id)

    def _is_shared(node_id: int, refs: int) -> bool:
        # Only non-leaf nodes reached at least twice are worth a temp.
        if refs <= 1:
            return False
        node = nodes_by_id.get(node_id)
        return node is not None and not isinstance(node, (Const, Var))

    shared = {nid for nid, refs in ref_counts.items() if _is_shared(nid, refs)}
    if not shared:
        return {}, []

    # Post-order walk: operands are emitted before the node that uses them.
    post_order: list[int] = []
    done: set[int] = set()

    def _visit(node: Expr) -> None:
        nid = id(node)
        if nid in done:
            return
        done.add(nid)
        if isinstance(node, (Const, Var)):
            return
        for group in (("child",), ("a", "b"), ("base", "exp")):
            if hasattr(node, group[0]):
                for attr in group:
                    _visit(getattr(node, attr))
                break
        if nid in shared:
            post_order.append(nid)

    for root in exprs:
        _visit(root)

    id_to_temp: dict[int, str] = {
        nid: f"_c{pos}" for pos, nid in enumerate(post_order)
    }
    temps_ordered: list[tuple[str, Expr]] = [
        (id_to_temp[nid], nodes_by_id[nid]) for nid in post_order
    ]
    return id_to_temp, temps_ordered
def _map_ids(expr: Expr, mapping: dict[int, Expr]) -> None:
    """Record ``id(node) -> node`` in *mapping* for *expr* and its operands.

    Nodes already present in *mapping* are left untouched and not
    descended into again.  *mapping* is updated in place.
    """
    key = id(expr)
    if key in mapping:
        return
    mapping[key] = expr

    if isinstance(expr, (Const, Var)):
        # Leaves carry no operands.
        return

    # Follow the first operand group whose leading attribute is present.
    for group in (("child",), ("a", "b"), ("base", "exp")):
        if hasattr(expr, group[0]):
            for attr in group:
                _map_ids(getattr(expr, attr), mapping)
            return
def _expr_to_code(expr: Expr, id_to_temp: dict[int, str]) -> str:
    """Return the temp name registered for *expr*, else ``expr.to_code()``.

    Only *expr* itself is looked up in *id_to_temp*; the fallback
    ``to_code()`` call is not given the temp map.
    """
    name = id_to_temp.get(id(expr))
    return expr.to_code() if name is None else name
def _expr_to_code_recursive(expr: Expr, id_to_temp: dict[int, str]) -> str:
    """Render *expr* as Python source, using temp names wherever registered.

    A node whose ``id()`` is in *id_to_temp* is replaced by its temp name.
    ``Const``/``Var`` leaves are rendered by their own ``to_code()``.  Every
    other known node type is rendered here so that its operands go through
    the same temp substitution.
    """
    name = id_to_temp.get(id(expr))
    if name is not None:
        return name

    if isinstance(expr, (Const, Var)):
        return expr.to_code()

    from .expr import Add, Cos, Div, Mul, Neg, Pow, Sin, Sqrt, Sub

    def emit(node: Expr) -> str:
        return _expr_to_code_recursive(node, id_to_temp)

    if isinstance(expr, Neg):
        return f"(-{emit(expr.child)})"

    # Unary functions map onto the helper names provided to generated code.
    for node_type, helper in ((Sin, "_sin"), (Cos, "_cos"), (Sqrt, "_sqrt")):
        if isinstance(expr, node_type):
            return f"{helper}({emit(expr.child)})"

    # Binary arithmetic: the left operand is rendered before the right one.
    for node_type, symbol in ((Add, "+"), (Sub, "-"), (Mul, "*"), (Div, "/")):
        if isinstance(expr, node_type):
            lhs = emit(expr.a)
            rhs = emit(expr.b)
            return f"({lhs} {symbol} {rhs})"

    if isinstance(expr, Pow):
        base = emit(expr.base)
        exponent = emit(expr.exp)
        return f"({base} ** {exponent})"

    # Unknown node type: defer to the node's own rendering.
    return expr.to_code()
# ---------------------------------------------------------------------------
# Sparsity detection
# ---------------------------------------------------------------------------
def _find_nonzero_entries(
    jac_exprs: list[list[Expr]],
) -> list[tuple[int, int]]:
    """List the ``(row, col)`` positions whose entry is not a constant zero.

    An entry is skipped only when it is a ``Const`` whose ``value`` equals
    ``0.0``; everything else is kept.  Positions come back in row-major
    order.
    """
    return [
        (row_idx, col_idx)
        for row_idx, row in enumerate(jac_exprs)
        for col_idx, entry in enumerate(row)
        if not (isinstance(entry, Const) and entry.value == 0.0)
    ]
# ---------------------------------------------------------------------------
# Code generation
# ---------------------------------------------------------------------------
def compile_system(
    residuals: list[Expr],
    jac_exprs: list[list[Expr]],
    n_res: int,
    n_free: int,
) -> Callable[[dict, np.ndarray, np.ndarray], None]:
    """Compile residuals + Jacobian into a single evaluation function.

    Args:
        residuals: residual expressions; entry ``i`` is written to
            ``r_vec[i]``.
        jac_exprs: Jacobian expressions indexed ``[row][col]``; entry
            ``(i, j)`` is written to ``J[i, j]`` unless it is a constant
            zero.
        n_res: residual count, used only in the debug log message.
        n_free: free-parameter count, used only in the debug log message.

    Returns:
        A callable ``fn(env, r_vec, J)`` that fills *r_vec* and *J*
        in-place.  *J* must be pre-zeroed by the caller (only non-zero
        entries are written).  For a system with no residuals and no
        non-zero Jacobian entries the callable does nothing.

    Raises:
        Whatever ``compile()``/``exec()`` or the expression renderers
        raise; ``try_compile_system`` wraps this with a fallback.
    """
    # Detect non-zero Jacobian entries
    nz_entries = _find_nonzero_entries(jac_exprs)

    # Collect all expressions for CSE analysis
    nz_jac_exprs: list[Expr] = [jac_exprs[i][j] for i, j in nz_entries]
    all_exprs: list[Expr] = [*residuals, *nz_jac_exprs]

    # CSE
    id_to_temp, temps_ordered = _build_cse(all_exprs)

    # Generate function body.  Every body line uses the same one-space
    # indent so the suite is consistently indented.
    lines: list[str] = ["def _eval(env, r_vec, J):"]

    # Temporaries — temporarily remove each temp's own id so its RHS
    # is expanded rather than self-referencing.
    for temp_name, temp_expr in temps_ordered:
        eid = id(temp_expr)
        saved = id_to_temp.pop(eid)
        code = _expr_to_code_recursive(temp_expr, id_to_temp)
        id_to_temp[eid] = saved
        lines.append(f" {temp_name} = {code}")

    # Residuals
    for i, r in enumerate(residuals):
        code = _expr_to_code_recursive(r, id_to_temp)
        lines.append(f" r_vec[{i}] = {code}")

    # Jacobian (sparse)
    for (i, j), entry in zip(nz_entries, nz_jac_exprs):
        code = _expr_to_code_recursive(entry, id_to_temp)
        lines.append(f" J[{i}, {j}] = {code}")

    # A ``def`` header with no statements is a syntax error, so an empty
    # (or all-zero) system still needs a body.
    if len(lines) == 1:
        lines.append(" pass")

    source = "\n".join(lines)

    # Compile.  The namespace is copied so the generated ``_eval`` does
    # not leak into the shared helper table.
    code_obj = compile(source, "<kindred_codegen>", "exec")
    ns = dict(_CODEGEN_NS)
    exec(code_obj, ns)
    fn = ns["_eval"]

    n_temps = len(temps_ordered)
    n_nz = len(nz_entries)
    n_total = n_res * n_free
    log.debug(
        "codegen: compiled %d residuals + %d/%d Jacobian entries, %d CSE temps",
        n_res,
        n_nz,
        n_total,
        n_temps,
    )
    return fn
def try_compile_system(
    residuals: list[Expr],
    jac_exprs: list[list[Expr]],
    n_res: int,
    n_free: int,
) -> Callable[[dict, np.ndarray, np.ndarray], None] | None:
    """Best-effort wrapper around ``compile_system``.

    Returns the compiled evaluator, or ``None`` when compilation raises
    any ``Exception``.  The failure is logged at debug level with its
    traceback, leaving the caller to use the tree-walk evaluation.
    """
    compiled: Callable[[dict, np.ndarray, np.ndarray], None] | None
    try:
        compiled = compile_system(residuals, jac_exprs, n_res, n_free)
    except Exception:
        # Deliberately broad: any codegen failure triggers the fallback.
        log.debug(
            "codegen: compilation failed, falling back to tree-walk eval", exc_info=True
        )
        compiled = None
    return compiled