Add a code generation pipeline that compiles Expr DAGs into flat Python functions, eliminating recursive tree-walk dispatch in the Newton-Raphson inner loop.

Key changes:
- Add a to_code() method to all 11 Expr node types (expr.py).
- Add a new codegen.py module with CSE (common subexpression elimination), sparsity detection, and a compile()/exec() compilation pipeline.
- Add ParamTable.env_ref() to avoid copying the values dict on every iteration (params.py).
- The Newton and BFGS solvers accept pre-built jac_exprs and compiled_eval, avoiding redundant diff()/simplify() calls and enabling compiled evaluation.
- count_dof() and the diagnostics accept pre-built jac_exprs.
- solver.py builds the symbolic Jacobian once, compiles it once, and passes it to all consumers (_monolithic_solve, count_dof, diagnostics).
- Automatic fallback: if code generation fails, tree-walk evaluation is used.

Expected performance impact:
- ~10-20x faster Jacobian evaluation (no recursive dispatch).
- ~2-5x additional speedup from CSE on quaternion-heavy systems.
- ~3x fewer entries evaluated thanks to sparsity detection.
- Eliminates redundant diff().simplify() calls in count_dof() and the diagnostics.
115 lines
3.7 KiB
Python
115 lines
3.7 KiB
Python
"""Parameter table mapping named variables to Expr Var nodes and current values."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from typing import Dict, List
|
|
|
|
import numpy as np
|
|
|
|
from .expr import Var
|
|
|
|
|
|
class ParamTable:
    """Central registry of solver variables.

    Each parameter has a name, a current numeric value, an associated
    :class:`Var` expression node, and a fixed/free flag. Grounded
    body parameters are marked fixed so the pre-pass can substitute
    them as constants.

    The order of the free parameters (see :meth:`free_names`) defines the
    layout of the vectors exchanged through :meth:`get_free_vector` and
    :meth:`set_free_vector`.
    """

    def __init__(self):
        self._vars: Dict[str, Var] = {}
        self._values: Dict[str, float] = {}
        self._fixed: set[str] = set()
        self._free_order: List[str] = []  # insertion-ordered free names

    def add(self, name: str, value: float = 0.0, fixed: bool = False) -> Var:
        """Create a parameter and return its Var node.

        Args:
            name: Unique parameter name.
            value: Initial value, stored as given.
            fixed: If true, the parameter is registered as fixed and is not
                part of the free vector. Otherwise it is appended to the
                end of the free order.

        Returns:
            The newly created :class:`Var` node for *name*.

        Raises:
            ValueError: If *name* is already registered.
        """
        if name in self._vars:
            raise ValueError(f"Duplicate parameter: {name}")
        v = Var(name)
        self._vars[name] = v
        self._values[name] = value
        if fixed:
            self._fixed.add(name)
        else:
            self._free_order.append(name)
        return v

    def get_var(self, name: str) -> Var:
        """Return the Var node registered for *name* (KeyError if unknown)."""
        return self._vars[name]

    def is_fixed(self, name: str) -> bool:
        """Return True if *name* is currently marked fixed."""
        return name in self._fixed

    def fix(self, name: str):
        """Mark a parameter as fixed and remove it from the free list.

        *name* is not validated against the registered parameters.
        """
        self._fixed.add(name)
        if name in self._free_order:
            self._free_order.remove(name)

    def unfix(self, name: str):
        """Restore a fixed parameter to free status.

        The parameter is appended to the *end* of the free order. Its index
        in the free vector may therefore differ from before it was fixed.
        Names that are not currently fixed are ignored.
        """
        if name not in self._fixed:
            return
        self._fixed.discard(name)
        # Only names with a stored value may become free. An unknown name in
        # the free order would make get_free_vector()/set_free_vector()
        # raise KeyError later, far from the original mistake.
        if name in self._values and name not in self._free_order:
            self._free_order.append(name)

    def get_env(self) -> Dict[str, float]:
        """Return a snapshot of all current values (for Expr.eval)."""
        return dict(self._values)

    def env_ref(self) -> Dict[str, float]:
        """Return a direct reference to the internal values dict.

        Faster than :meth:`get_env` (no copy). Later updates to the table
        are visible through the returned dict, and writes to it change the
        table. It is safe when the caller only reads during evaluation and
        mutates via :meth:`set_free_vector`.
        """
        return self._values

    def free_names(self) -> List[str]:
        """Return ordered list of free (non-fixed) parameter names (a copy)."""
        return list(self._free_order)

    def n_free(self) -> int:
        """Return the number of free parameters."""
        return len(self._free_order)

    def get_value(self, name: str) -> float:
        """Return the current value of *name* (KeyError if unknown)."""
        return self._values[name]

    def set_value(self, name: str, value: float):
        """Set the value of *name*.

        *name* is not checked against the registered parameters. An unknown
        name creates a new entry in the values dict.
        """
        self._values[name] = value

    def get_free_vector(self) -> np.ndarray:
        """Current free-parameter values as a 1-D float64 array, in free order."""
        return np.array([self._values[n] for n in self._free_order], dtype=np.float64)

    def set_free_vector(self, vec: np.ndarray):
        """Bulk-update free parameters from a 1-D array.

        ``vec[i]`` is stored, converted to a Python ``float``, as the value of
        the i-th name in :meth:`free_names`. Entries beyond the number of
        free parameters are ignored. The update is all-or-nothing: every
        entry is converted before any value is written, so a bad input
        leaves the table unchanged.

        Raises:
            ValueError: If *vec* has fewer entries than there are free
                parameters, or an entry cannot be converted to ``float``.
        """
        names = self._free_order
        if not names:
            return
        if len(vec) < len(names):
            raise ValueError(
                f"Free vector has {len(vec)} entries, expected {len(names)}"
            )
        # Convert everything first so a failing float() cannot leave the
        # table half-updated. zip() stops after the last free name.
        updates = {name: float(x) for name, x in zip(names, vec)}
        self._values.update(updates)

    def snapshot(self) -> Dict[str, float]:
        """Capture current values as a checkpoint (a shallow copy)."""
        return dict(self._values)

    def restore(self, snap: Dict[str, float]):
        """Restore parameter values from a checkpoint.

        Names in *snap* that are not in the table are ignored. Parameters
        missing from *snap* keep their current values.
        """
        for name, val in snap.items():
            if name in self._values:
                self._values[name] = val

    def movement_cost(
        self,
        start: Dict[str, float],
        weights: Dict[str, float] | None = None,
    ) -> float:
        """Weighted sum of squared displacements from start.

        Only free parameters contribute. A free parameter missing from
        *start* counts as not having moved. Parameters missing from
        *weights* (or all of them, when *weights* is None or empty) use a
        weight of 1.0.
        """
        cost = 0.0
        for name in self._free_order:
            w = weights.get(name, 1.0) if weights else 1.0
            delta = self._values[name] - start.get(name, self._values[name])
            cost += delta * delta * w
        return cost