"""Compile Expr DAGs into flat Python functions for fast evaluation.

The compilation pipeline:

1. Collect all Expr nodes to be evaluated (residuals + Jacobian entries).
2. Identify common subexpressions (CSE) by ``id()`` — the Expr DAG already
   shares node objects via ParamTable's Var instances.
3. Generate a single Python function body that computes CSE temps, then
   fills ``r_vec`` and ``J`` arrays in-place.
4. Compile with ``compile()`` + ``exec()`` and return the callable.

The generated function signature is::

    fn(env: dict[str, float], r_vec: ndarray, J: ndarray) -> None
"""

from __future__ import annotations

import logging
import math
from collections import Counter
from typing import Callable, List

import numpy as np

from .expr import Const, Expr, Var

log = logging.getLogger(__name__)

# Namespace injected into compiled functions.
_CODEGEN_NS = {
    "_sin": math.sin,
    "_cos": math.cos,
    "_sqrt": math.sqrt,
}


# ---------------------------------------------------------------------------
# CSE (Common Subexpression Elimination)
# ---------------------------------------------------------------------------


def _children(expr: Expr) -> tuple:
    """Return the direct operands of *expr* (empty for leaves).

    Node kinds are recognised structurally, in this order: ``child`` (unary),
    ``a``/``b`` (binary), ``base``/``exp`` (power).  ``Const`` and ``Var`` are
    always treated as leaves, as is any node with none of these attributes.
    """
    if isinstance(expr, (Const, Var)):
        return ()
    if hasattr(expr, "child"):
        return (expr.child,)
    if hasattr(expr, "a"):
        return (expr.a, expr.b)
    if hasattr(expr, "base"):
        return (expr.base, expr.exp)
    return ()


def _collect_nodes(expr: Expr, counts: Counter, visited: set[int]) -> None:
    """Walk *expr* and count how many times each node ``id()`` is referenced.

    Args:
        expr: root of the (sub-)DAG to walk.
        counts: reference counter keyed by ``id(node)``; updated in place.
        visited: ids whose children have already been walked; updated in place.
    """
    eid = id(expr)
    # Count every reference, but only descend the first time a node is seen
    # so a shared sub-tree does not inflate the counts of its descendants.
    counts[eid] += 1
    if eid in visited:
        return
    visited.add(eid)
    for child in _children(expr):
        _collect_nodes(child, counts, visited)


def _build_cse(
    exprs: list[Expr],
) -> tuple[dict[int, str], list[tuple[str, Expr]]]:
    """Identify shared sub-trees and assign temporary variable names.

    Args:
        exprs: every root expression that will be emitted.

    Returns:
        id_to_temp: mapping from ``id(node)`` to temp variable name
        temps_ordered: ``(temp_name, expr)`` pairs in dependency order
    """
    counts: Counter = Counter()
    visited: set[int] = set()
    id_to_expr: dict[int, Expr] = {}

    for expr in exprs:
        _collect_nodes(expr, counts, visited)

    # Map id -> Expr for nodes we visited
    for expr in exprs:
        _map_ids(expr, id_to_expr)

    # Nodes referenced more than once and not trivial (Const/Var) become temps
    shared_ids: set[int] = set()
    for eid, cnt in counts.items():
        if cnt <= 1:
            continue
        node = id_to_expr.get(eid)
        if node is not None and not isinstance(node, (Const, Var)):
            shared_ids.add(eid)

    if not shared_ids:
        return {}, []

    # Topological order: a temp must be computed before any temp that uses it.
    # Walk each root and collect shared nodes in post-order.
    ordered_ids: list[int] = []
    order_visited: set[int] = set()

    def _topo(expr: Expr) -> None:
        eid = id(expr)
        if eid in order_visited:
            return
        order_visited.add(eid)
        for child in _children(expr):
            _topo(child)
        # Appended after the children, so dependencies always come first.
        if eid in shared_ids:
            ordered_ids.append(eid)

    for expr in exprs:
        _topo(expr)

    id_to_temp: dict[int, str] = {}
    temps_ordered: list[tuple[str, Expr]] = []
    for i, eid in enumerate(ordered_ids):
        name = f"_c{i}"
        id_to_temp[eid] = name
        temps_ordered.append((name, id_to_expr[eid]))

    return id_to_temp, temps_ordered


def _map_ids(expr: Expr, mapping: dict[int, Expr]) -> None:
    """Populate the ``id -> Expr`` *mapping* for every node reachable from *expr*."""
    eid = id(expr)
    if eid in mapping:
        return
    mapping[eid] = expr
    for child in _children(expr):
        _map_ids(child, mapping)


def _expr_to_code(expr: Expr, id_to_temp: dict[int, str]) -> str:
    """Emit code for *expr*, substituting a temp name if *expr* itself is shared.

    Unlike :func:`_expr_to_code_recursive` this does not look inside *expr*:
    when the node is not a temp, its own ``to_code()`` output is returned
    unchanged.
    """
    temp = id_to_temp.get(id(expr))
    if temp is not None:
        return temp
    return expr.to_code()


def _expr_to_code_recursive(expr: Expr, id_to_temp: dict[int, str]) -> str:
    """Emit code for *expr*, recursing into children but respecting temps.

    Args:
        expr: expression to render.
        id_to_temp: ``id(node) -> temp name``; any node found here (including
            *expr* itself) is rendered as that name instead of being expanded.

    Returns:
        A Python expression string that uses the ``_sin``/``_cos``/``_sqrt``
        helpers from the codegen namespace.
    """
    # Imported once per top-level call rather than once per visited node.
    from .expr import Add, Cos, Div, Mul, Neg, Pow, Sin, Sqrt, Sub

    # Checked in this order with isinstance(), so subclasses are honoured.
    unary_forms = (
        (Neg, "(-{})"),
        (Sin, "_sin({})"),
        (Cos, "_cos({})"),
        (Sqrt, "_sqrt({})"),
    )
    binary_ops = ((Add, "+"), (Sub, "-"), (Mul, "*"), (Div, "/"))

    def emit(node: Expr) -> str:
        temp = id_to_temp.get(id(node))
        if temp is not None:
            return temp

        # For leaf nodes, just use to_code() directly
        if isinstance(node, (Const, Var)):
            return node.to_code()

        for cls, template in unary_forms:
            if isinstance(node, cls):
                return template.format(emit(node.child))

        for cls, op in binary_ops:
            if isinstance(node, cls):
                lhs = emit(node.a)
                rhs = emit(node.b)
                return f"({lhs} {op} {rhs})"

        if isinstance(node, Pow):
            base = emit(node.base)
            exp = emit(node.exp)
            # Parenthesise the base: ``**`` binds tighter than unary minus, so
            # a leaf rendered as a bare negative literal would otherwise be
            # evaluated as ``-(2.0 ** e)``.
            # NOTE(review): whether Const.to_code() can emit such a literal is
            # not visible from this module — the parentheses are harmless
            # either way.
            return f"(({base}) ** {exp})"

        # Fallback — should not happen for known node types
        return node.to_code()

    return emit(expr)


# ---------------------------------------------------------------------------
# Sparsity detection
# ---------------------------------------------------------------------------


def _find_nonzero_entries(
    jac_exprs: list[list[Expr]],
) -> list[tuple[int, int]]:
    """Return ``(row, col)`` pairs for non-zero Jacobian entries.

    An entry counts as zero only when it is a ``Const`` whose value equals
    ``0.0``; no symbolic simplification is attempted.
    """
    return [
        (i, j)
        for i, row in enumerate(jac_exprs)
        for j, expr in enumerate(row)
        if not (isinstance(expr, Const) and expr.value == 0.0)
    ]


# ---------------------------------------------------------------------------
# Code generation
# ---------------------------------------------------------------------------


def compile_system(
    residuals: list[Expr],
    jac_exprs: list[list[Expr]],
    n_res: int,
    n_free: int,
) -> Callable[[dict, np.ndarray, np.ndarray], None]:
    """Compile residuals + Jacobian into a single evaluation function.

    Args:
        residuals: residual expressions; entry ``i`` is written to ``r_vec[i]``.
        jac_exprs: Jacobian expressions indexed ``[row][col]``.
        n_res: number of residuals (used only for the debug log message).
        n_free: number of free parameters (used only for the debug log message).

    Returns:
        A callable ``fn(env, r_vec, J)`` that fills *r_vec* and *J* in-place.
        *J* must be pre-zeroed by the caller (only non-zero entries are
        written).  An empty system compiles to a function that does nothing.
    """
    # Detect non-zero Jacobian entries
    nz_entries = _find_nonzero_entries(jac_exprs)

    # Collect all expressions for CSE analysis
    nz_jac_exprs: list[Expr] = [jac_exprs[i][j] for i, j in nz_entries]
    all_exprs: list[Expr] = list(residuals)
    all_exprs.extend(nz_jac_exprs)

    # CSE
    id_to_temp, temps_ordered = _build_cse(all_exprs)

    # Generate function body
    body: list[str] = []

    # Temporaries — temporarily remove each temp's own id so its RHS
    # is expanded rather than self-referencing.
    for temp_name, temp_expr in temps_ordered:
        eid = id(temp_expr)
        saved = id_to_temp.pop(eid)
        code = _expr_to_code_recursive(temp_expr, id_to_temp)
        id_to_temp[eid] = saved
        body.append(f"    {temp_name} = {code}")

    # Residuals
    for i, r in enumerate(residuals):
        code = _expr_to_code_recursive(r, id_to_temp)
        body.append(f"    r_vec[{i}] = {code}")

    # Jacobian (sparse)
    for (i, j), entry in zip(nz_entries, nz_jac_exprs):
        code = _expr_to_code_recursive(entry, id_to_temp)
        body.append(f"    J[{i}, {j}] = {code}")

    # A ``def`` with no statements is a syntax error, so an empty system
    # (no residuals and an all-zero / empty Jacobian) needs a placeholder.
    if not body:
        body.append("    pass")

    source = "\n".join(["def _eval(env, r_vec, J):", *body])

    # Compile in a private copy of the helper namespace.
    code_obj = compile(source, "", "exec")
    ns = dict(_CODEGEN_NS)
    exec(code_obj, ns)
    fn = ns["_eval"]

    n_temps = len(temps_ordered)
    n_nz = len(nz_entries)
    n_total = n_res * n_free
    log.debug(
        "codegen: compiled %d residuals + %d/%d Jacobian entries, %d CSE temps",
        n_res,
        n_nz,
        n_total,
        n_temps,
    )

    return fn


def try_compile_system(
    residuals: list[Expr],
    jac_exprs: list[list[Expr]],
    n_res: int,
    n_free: int,
) -> Callable[[dict, np.ndarray, np.ndarray], None] | None:
    """Compile with automatic fallback. Returns ``None`` on failure.

    Any exception raised by :func:`compile_system` is logged at debug level
    (with traceback) and swallowed, so the caller can fall back to another
    evaluation strategy.
    """
    try:
        return compile_system(residuals, jac_exprs, n_res, n_free)
    except Exception:
        log.debug(
            "codegen: compilation failed, falling back to tree-walk eval",
            exc_info=True,
        )
        return None