Files
solver/kindred_solver/expr.py
forbes-0023 64b1e24467 feat(solver): compile symbolic Jacobian to flat Python for fast evaluation
Add a code generation pipeline that compiles Expr DAGs into flat Python
functions, eliminating recursive tree-walk dispatch in the Newton-Raphson
inner loop.

Key changes:
- Add to_code() method to all 11 Expr node types (expr.py)
- New codegen.py module with CSE (common subexpression elimination),
  sparsity detection, and compile()/exec() compilation pipeline
- Add ParamTable.env_ref() to avoid dict copies per iteration (params.py)
- Newton and BFGS solvers accept pre-built jac_exprs and compiled_eval
  to avoid redundant diff/simplify and enable compiled evaluation
- count_dof() and diagnostics accept pre-built jac_exprs
- solver.py builds symbolic Jacobian once, compiles once, passes to all
  consumers (_monolithic_solve, count_dof, diagnostics)
- Automatic fallback: if codegen fails, tree-walk eval is used

Expected performance impact:
- ~10-20x faster Jacobian evaluation (no recursive dispatch)
- ~2-5x additional from CSE on quaternion-heavy systems
- ~3x fewer entries evaluated via sparsity detection
- Eliminates redundant diff().simplify() in DOF/diagnostics
2026-02-21 11:22:36 -06:00

468 lines
11 KiB
Python

"""Immutable expression DAG with eval, symbolic differentiation, and simplification."""
from __future__ import annotations
import math
from typing import Dict
class Expr:
    """Abstract node of an immutable expression DAG.

    Concrete subclasses implement numeric evaluation, symbolic
    differentiation, variable collection and Python code emission.  The
    arithmetic operator overloads build new nodes, coercing plain numbers
    through ``_wrap``.
    """

    __slots__ = ()

    # -- interface every node provides ---------------------------------------

    def eval(self, env: Dict[str, float]) -> float:
        """Evaluate this node numerically, looking variables up in *env*."""
        raise NotImplementedError

    def diff(self, var: str) -> Expr:
        """Return a new expression for the partial derivative w.r.t. *var*."""
        raise NotImplementedError

    def simplify(self) -> Expr:
        """Return a simplified equivalent; by default the node itself."""
        return self

    def vars(self) -> set[str]:
        """Return the set of variable names in this expression."""
        raise NotImplementedError

    def to_code(self) -> str:
        """Emit a Python arithmetic expression string.

        Evaluating the returned string with a dict ``env`` that maps
        parameter names to floats (and with ``_sin``, ``_cos``, ``_sqrt``
        bound to their ``math`` equivalents) must give the same value as
        ``self.eval(env)``.
        """
        raise NotImplementedError

    # -- operator overloads (self on the left) -------------------------------

    def __add__(self, other):
        rhs = _wrap(other)
        return Add(self, rhs)

    def __sub__(self, other):
        rhs = _wrap(other)
        return Sub(self, rhs)

    def __mul__(self, other):
        rhs = _wrap(other)
        return Mul(self, rhs)

    def __truediv__(self, other):
        rhs = _wrap(other)
        return Div(self, rhs)

    def __pow__(self, other):
        rhs = _wrap(other)
        return Pow(self, rhs)

    # -- reflected overloads (self on the right) -----------------------------

    def __radd__(self, other):
        lhs = _wrap(other)
        return Add(lhs, self)

    def __rsub__(self, other):
        lhs = _wrap(other)
        return Sub(lhs, self)

    def __rmul__(self, other):
        lhs = _wrap(other)
        return Mul(lhs, self)

    def __rtruediv__(self, other):
        lhs = _wrap(other)
        return Div(lhs, self)

    def __rpow__(self, other):
        lhs = _wrap(other)
        return Pow(lhs, self)

    # -- unary ---------------------------------------------------------------

    def __neg__(self):
        return Neg(self)
def _wrap(x) -> Expr:
    """Return *x* unchanged if it is already an Expr; lift an int/float to Const.

    Raises:
        TypeError: if *x* is neither an Expr nor an int/float.
    """
    if isinstance(x, Expr):
        return x
    if not isinstance(x, (int, float)):
        raise TypeError(f"Cannot coerce {type(x).__name__} to Expr")
    return Const(float(x))
# -- leaf nodes ---------------------------------------------------------------
class Const(Expr):
    """Leaf node holding a numeric constant, stored as a Python float."""

    __slots__ = ("value",)

    def __init__(self, value: float):
        self.value = float(value)

    def eval(self, env):
        return self.value

    def diff(self, var):
        # The derivative of a constant is zero w.r.t. every variable.
        return ZERO

    def simplify(self):
        return self

    def vars(self):
        return set()

    def to_code(self):
        """Emit a literal that evaluates to ``self.value`` in any context.

        ``repr`` alone is not enough:

        * ``repr(inf)`` / ``repr(nan)`` are the bare names ``inf`` / ``nan``,
          which are undefined in generated code (NameError).
        * A bare negative literal binds looser than ``**`` in a parent
          ``Pow``: ``-2.0 ** 2.0`` is ``-4.0``, while ``eval`` gives ``4.0``.
        """
        value = self.value
        if math.isnan(value):
            # inf - inf is NaN; built from literals so no builtins are needed.
            return "(1e999 - 1e999)"
        if math.isinf(value):
            # The literal 1e999 overflows to +inf when parsed.
            return "1e999" if value > 0 else "(-1e999)"
        text = repr(value)
        # Parenthesize anything with a leading minus (including -0.0) so the
        # unary minus can never be split from its operand by a parent op.
        return f"({text})" if text.startswith("-") else text

    def __repr__(self):
        return f"Const({self.value})"

    def __eq__(self, other):
        return isinstance(other, Const) and self.value == other.value

    def __hash__(self):
        return hash(("Const", self.value))
class Var(Expr):
    """Leaf node for a named free variable, resolved by lookup in ``env``."""

    __slots__ = ("name",)

    def __init__(self, name: str):
        self.name = name

    def eval(self, env):
        return env[self.name]

    def diff(self, var):
        # dx/dx = 1; the derivative w.r.t. any other variable is 0.
        if var == self.name:
            return ONE
        return ZERO

    def simplify(self):
        return self

    def vars(self):
        return {self.name}

    def to_code(self):
        return "env[{!r}]".format(self.name)

    def __repr__(self):
        return "Var({!r})".format(self.name)

    def __eq__(self, other):
        if not isinstance(other, Var):
            return False
        return self.name == other.name

    def __hash__(self):
        return hash(("Var", self.name))
# -- unary nodes --------------------------------------------------------------
class Neg(Expr):
    """Unary negation node ``-child``."""

    __slots__ = ("child",)

    def __init__(self, child: Expr):
        self.child = child

    def eval(self, env):
        value = self.child.eval(env)
        return -value

    def diff(self, var):
        # Negation is linear: d(-f) = -(df).
        inner = self.child.diff(var)
        return Neg(inner)

    def simplify(self):
        inner = self.child.simplify()
        if isinstance(inner, Neg):
            # Double negation cancels: -(-x) -> x.
            return inner.child
        if isinstance(inner, Const):
            return Const(-inner.value)
        return Neg(inner)

    def vars(self):
        return self.child.vars()

    def to_code(self):
        return "(-{})".format(self.child.to_code())

    def __repr__(self):
        return "Neg({!r})".format(self.child)
class Sin(Expr):
    """Sine node ``sin(child)``, evaluated with ``math.sin``."""

    __slots__ = ("child",)

    def __init__(self, child: Expr):
        self.child = child

    def eval(self, env):
        return math.sin(self.child.eval(env))

    def diff(self, var):
        # d/dx sin(f) = cos(f) * f'
        return Mul(Cos(self.child), self.child.diff(var))

    def simplify(self):
        c = self.child.simplify()
        # Fold only finite constants: math.sin(+/-inf) raises ValueError, so
        # a non-finite argument stays symbolic. This mirrors how Sqrt and Div
        # refuse to fold out-of-domain constants.
        if isinstance(c, Const) and math.isfinite(c.value):
            return Const(math.sin(c.value))
        return Sin(c)

    def vars(self):
        return self.child.vars()

    def to_code(self):
        return f"_sin({self.child.to_code()})"

    def __repr__(self):
        return f"Sin({self.child!r})"
class Cos(Expr):
    """Cosine node ``cos(child)``, evaluated with ``math.cos``."""

    __slots__ = ("child",)

    def __init__(self, child: Expr):
        self.child = child

    def eval(self, env):
        return math.cos(self.child.eval(env))

    def diff(self, var):
        # d/dx cos(f) = -sin(f) * f'
        return Mul(Neg(Sin(self.child)), self.child.diff(var))

    def simplify(self):
        c = self.child.simplify()
        # Fold only finite constants: math.cos(+/-inf) raises ValueError, so
        # a non-finite argument stays symbolic. This mirrors how Sqrt and Div
        # refuse to fold out-of-domain constants.
        if isinstance(c, Const) and math.isfinite(c.value):
            return Const(math.cos(c.value))
        return Cos(c)

    def vars(self):
        return self.child.vars()

    def to_code(self):
        return f"_cos({self.child.to_code()})"

    def __repr__(self):
        return f"Cos({self.child!r})"
class Sqrt(Expr):
    """Square-root node ``sqrt(child)``, evaluated with ``math.sqrt``."""

    __slots__ = ("child",)

    def __init__(self, child: Expr):
        self.child = child

    def eval(self, env):
        return math.sqrt(self.child.eval(env))

    def diff(self, var):
        # Chain rule: d/dx sqrt(f) = f' / (2 * sqrt(f)).
        numerator = self.child.diff(var)
        denominator = Mul(Const(2.0), Sqrt(self.child))
        return Div(numerator, denominator)

    def simplify(self):
        inner = self.child.simplify()
        # Fold only constants that are >= 0; negative (and NaN) constants
        # stay symbolic.
        foldable = isinstance(inner, Const) and inner.value >= 0
        return Const(math.sqrt(inner.value)) if foldable else Sqrt(inner)

    def vars(self):
        return self.child.vars()

    def to_code(self):
        return "_sqrt({})".format(self.child.to_code())

    def __repr__(self):
        return "Sqrt({!r})".format(self.child)
# -- binary nodes -------------------------------------------------------------
class Add(Expr):
    """Binary sum node ``a + b``."""

    __slots__ = ("a", "b")

    def __init__(self, a: Expr, b: Expr):
        self.a = a
        self.b = b

    def eval(self, env):
        left = self.a.eval(env)
        right = self.b.eval(env)
        return left + right

    def diff(self, var):
        # Sum rule: (a + b)' = a' + b'.
        return Add(self.a.diff(var), self.b.diff(var))

    def simplify(self):
        lhs = self.a.simplify()
        rhs = self.b.simplify()
        lhs_const = isinstance(lhs, Const)
        rhs_const = isinstance(rhs, Const)
        if lhs_const and rhs_const:
            return Const(lhs.value + rhs.value)
        if lhs_const != rhs_const:
            # Exactly one side is constant: drop it if it is the identity 0.
            const, other = (lhs, rhs) if lhs_const else (rhs, lhs)
            if const.value == 0.0:
                return other
        return Add(lhs, rhs)

    def vars(self):
        return self.a.vars() | self.b.vars()

    def to_code(self):
        return "({} + {})".format(self.a.to_code(), self.b.to_code())

    def __repr__(self):
        return "Add({!r}, {!r})".format(self.a, self.b)
class Sub(Expr):
    """Binary difference node ``a - b``."""

    __slots__ = ("a", "b")

    def __init__(self, a: Expr, b: Expr):
        self.a = a
        self.b = b

    def eval(self, env):
        left = self.a.eval(env)
        right = self.b.eval(env)
        return left - right

    def diff(self, var):
        # Difference rule: (a - b)' = a' - b'.
        return Sub(self.a.diff(var), self.b.diff(var))

    def simplify(self):
        lhs = self.a.simplify()
        rhs = self.b.simplify()
        lhs_const = isinstance(lhs, Const)
        rhs_const = isinstance(rhs, Const)
        if lhs_const and rhs_const:
            return Const(lhs.value - rhs.value)
        if rhs_const and rhs.value == 0.0:
            return lhs  # x - 0 -> x
        if lhs_const and lhs.value == 0.0:
            return Neg(rhs).simplify()  # 0 - x -> -x
        return Sub(lhs, rhs)

    def vars(self):
        return self.a.vars() | self.b.vars()

    def to_code(self):
        return "({} - {})".format(self.a.to_code(), self.b.to_code())

    def __repr__(self):
        return "Sub({!r}, {!r})".format(self.a, self.b)
class Mul(Expr):
    """Binary product node ``a * b``."""

    __slots__ = ("a", "b")

    def __init__(self, a: Expr, b: Expr):
        self.a = a
        self.b = b

    def eval(self, env):
        left = self.a.eval(env)
        right = self.b.eval(env)
        return left * right

    def diff(self, var):
        # Product rule: (ab)' = a'b + ab'.
        da = self.a.diff(var)
        db = self.b.diff(var)
        return Add(Mul(da, self.b), Mul(self.a, db))

    def simplify(self):
        lhs = self.a.simplify()
        rhs = self.b.simplify()
        lhs_const = isinstance(lhs, Const)
        rhs_const = isinstance(rhs, Const)
        if lhs_const and rhs_const:
            return Const(lhs.value * rhs.value)
        if lhs_const != rhs_const:
            # Exactly one factor is a constant k: apply the k*x identities.
            k, other = (lhs.value, rhs) if lhs_const else (rhs.value, lhs)
            if k == 0.0:
                return ZERO
            if k == 1.0:
                return other
            if k == -1.0:
                return Neg(other).simplify()
        return Mul(lhs, rhs)

    def vars(self):
        return self.a.vars() | self.b.vars()

    def to_code(self):
        return "({} * {})".format(self.a.to_code(), self.b.to_code())

    def __repr__(self):
        return "Mul({!r}, {!r})".format(self.a, self.b)
class Div(Expr):
    """Binary quotient node ``a / b``."""

    __slots__ = ("a", "b")

    def __init__(self, a: Expr, b: Expr):
        self.a = a
        self.b = b

    def eval(self, env):
        num = self.a.eval(env)
        den = self.b.eval(env)
        return num / den

    def diff(self, var):
        # Quotient rule: (a'b - ab') / b^2, with b^2 written as b*b.
        da = self.a.diff(var)
        db = self.b.diff(var)
        numerator = Sub(Mul(da, self.b), Mul(self.a, db))
        return Div(numerator, Mul(self.b, self.b))

    def simplify(self):
        num = self.a.simplify()
        den = self.b.simplify()
        num_const = isinstance(num, Const)
        den_const = isinstance(den, Const)
        # Fold constant quotients, but never a division by a zero constant.
        if num_const and den_const and den.value != 0.0:
            return Const(num.value / den.value)
        if num_const and num.value == 0.0:
            return ZERO
        if den_const and den.value == 1.0:
            return num
        return Div(num, den)

    def vars(self):
        return self.a.vars() | self.b.vars()

    def to_code(self):
        return "({} / {})".format(self.a.to_code(), self.b.to_code())

    def __repr__(self):
        return "Div({!r}, {!r})".format(self.a, self.b)
class Pow(Expr):
    """Power node ``base ** exp`` with Python ``**`` semantics."""

    __slots__ = ("base", "exp")

    def __init__(self, base: Expr, exp: Expr):
        self.base = base
        self.exp = exp

    def eval(self, env):
        return self.base.eval(env) ** self.exp.eval(env)

    def diff(self, var):
        """Differentiate with respect to *var*.

        Supported whenever the exponent does not depend on *var*:
        d/dx f^g = g * f^(g-1) * f'.  The general case
        d/dx f^g = f^g * (g' * ln(f) + g * f'/f) needs a logarithm node,
        which this module does not define.

        Raises:
            NotImplementedError: if the exponent depends on *var*.
        """
        if isinstance(self.exp, Const):
            n = self.exp.value
            if n == 0.0:
                # f**0 is identically 1, so the derivative is 0. The general
                # formula would build 0 * f**-1, which raises
                # ZeroDivisionError at f == 0 if evaluated before simplify().
                return ZERO
            return Mul(
                Mul(Const(n), Pow(self.base, Const(n - 1.0))), self.base.diff(var)
            )
        if var not in self.exp.vars():
            # Exponent is symbolic but constant w.r.t. var (e.g. another
            # parameter), so the power rule still applies.
            return Mul(
                Mul(self.exp, Pow(self.base, Sub(self.exp, ONE))),
                self.base.diff(var),
            )
        raise NotImplementedError("diff of Pow with non-constant exponent")

    def simplify(self):
        base = self.base.simplify()
        exp = self.exp.simplify()
        if isinstance(base, Const) and isinstance(exp, Const):
            # Fold only when Python produces a real float. A negative base
            # with a fractional exponent yields a complex number, and
            # 0.0 ** -1.0 / huge ** big raise. Those stay symbolic, as Div and
            # Sqrt do for their invalid folds.
            try:
                folded = base.value ** exp.value
            except (ZeroDivisionError, OverflowError):
                folded = None
            if isinstance(folded, float):
                return Const(folded)
        if isinstance(exp, Const):
            if exp.value == 0.0:
                return ONE
            if exp.value == 1.0:
                return base
            if exp.value == 2.0:
                return Mul(base, base).simplify()
        return Pow(base, exp)

    def vars(self):
        return self.base.vars() | self.exp.vars()

    def to_code(self):
        base_code = self.base.to_code()
        # ``**`` binds tighter than a unary minus on its left, so a base
        # rendered with a leading "-" (e.g. a negative literal) must be
        # parenthesized: "-2.0 ** x" would mean -(2.0 ** x).
        if base_code.startswith("-"):
            base_code = f"({base_code})"
        return f"({base_code} ** {self.exp.to_code()})"

    def __repr__(self):
        return f"Pow({self.base!r}, {self.exp!r})"
# -- sentinels ----------------------------------------------------------------
# Shared constant nodes for exact 0 and 1. diff() and simplify() return
# these instead of building fresh Const objects.
ZERO, ONE = Const(0.0), Const(1.0)