Add a code-generation pipeline that compiles Expr DAGs into flat Python functions, eliminating recursive tree-walk dispatch in the Newton-Raphson inner loop.

Key changes:
- Add a to_code() method to all 11 Expr node types (expr.py).
- New codegen.py module with CSE (common-subexpression elimination), sparsity detection, and a compile()/exec() compilation pipeline.
- Add ParamTable.env_ref() to avoid copying the parameter dict on every iteration (params.py).
- The Newton and BFGS solvers accept pre-built jac_exprs and compiled_eval, avoiding redundant diff()/simplify() calls and enabling compiled evaluation.
- count_dof() and diagnostics accept pre-built jac_exprs.
- solver.py builds the symbolic Jacobian once, compiles it once, and passes it to all consumers (_monolithic_solve, count_dof, diagnostics).
- Automatic fallback: if codegen fails, tree-walk eval is used.

Expected performance impact:
- ~10-20x faster Jacobian evaluation (no recursive dispatch).
- ~2-5x additional speedup from CSE on quaternion-heavy systems.
- ~3x fewer entries evaluated, thanks to sparsity detection.
- Eliminates redundant diff().simplify() calls in DOF counting and diagnostics.
468 lines
11 KiB
Python
468 lines
11 KiB
Python
"""Immutable expression DAG with eval, symbolic differentiation, and simplification."""
|
|
|
|
from __future__ import annotations

import math
import numbers
from typing import Dict
|
|
|
|
|
|
class Expr:
    """Base class for all expression nodes.

    Subclasses implement ``eval``, ``diff``, ``vars`` and ``to_code`` and may
    override ``simplify`` (the default returns the node unchanged).  The
    arithmetic operators below build new expression nodes rather than
    computing numbers, so ``Var("x") * 2 + 1`` yields a tree.
    """

    __slots__ = ()

    def eval(self, env: Dict[str, float]) -> float:
        """Evaluate the expression, reading variable values from *env*."""
        raise NotImplementedError

    def diff(self, var: str) -> Expr:
        """Return the symbolic derivative with respect to *var*.

        The result is not simplified; call ``simplify()`` on it as needed.
        """
        raise NotImplementedError

    def simplify(self) -> Expr:
        """Return an equivalent, simplified expression (default: ``self``)."""
        return self

    def vars(self) -> set[str]:
        """Return the set of variable names in this expression."""
        raise NotImplementedError

    def to_code(self) -> str:
        """Emit a Python arithmetic expression string.

        The returned string must be evaluated in a namespace where the name
        ``env`` is bound to a dict mapping variable names to floats, and the
        names ``_sin``, ``_cos`` and ``_sqrt`` are bound to their ``math``
        equivalents.  Evaluated that way it produces the same result as
        ``self.eval(env)``.
        """
        raise NotImplementedError

    # -- operator overloads --------------------------------------------------
    #
    # The other operand is coerced with ``_wrap``.  When it is of an
    # unsupported type the operator returns ``NotImplemented`` instead of
    # letting ``_wrap``'s TypeError escape: that lets Python try the other
    # operand's reflected method (e.g. a vector type's ``__rmul__``), and if
    # that also declines, Python itself raises TypeError.

    @staticmethod
    def _coerce(other):
        """Return *other* as an Expr, or ``NotImplemented`` if unsupported."""
        try:
            return _wrap(other)
        except TypeError:
            return NotImplemented

    def __add__(self, other):
        rhs = self._coerce(other)
        return rhs if rhs is NotImplemented else Add(self, rhs)

    def __radd__(self, other):
        lhs = self._coerce(other)
        return lhs if lhs is NotImplemented else Add(lhs, self)

    def __sub__(self, other):
        rhs = self._coerce(other)
        return rhs if rhs is NotImplemented else Sub(self, rhs)

    def __rsub__(self, other):
        lhs = self._coerce(other)
        return lhs if lhs is NotImplemented else Sub(lhs, self)

    def __mul__(self, other):
        rhs = self._coerce(other)
        return rhs if rhs is NotImplemented else Mul(self, rhs)

    def __rmul__(self, other):
        lhs = self._coerce(other)
        return lhs if lhs is NotImplemented else Mul(lhs, self)

    def __truediv__(self, other):
        rhs = self._coerce(other)
        return rhs if rhs is NotImplemented else Div(self, rhs)

    def __rtruediv__(self, other):
        lhs = self._coerce(other)
        return lhs if lhs is NotImplemented else Div(lhs, self)

    def __neg__(self):
        return Neg(self)

    def __pow__(self, other):
        rhs = self._coerce(other)
        return rhs if rhs is NotImplemented else Pow(self, rhs)

    def __rpow__(self, other):
        lhs = self._coerce(other)
        return lhs if lhs is NotImplemented else Pow(lhs, self)
|
|
|
|
|
|
def _wrap(x) -> Expr:
    """Coerce a real number to ``Const``; pass ``Expr`` instances through.

    Any ``numbers.Real`` is accepted -- ``int``, ``float``, ``bool``,
    ``fractions.Fraction`` and any other type registered with that ABC --
    and is stored as a Python float via ``float(x)``.

    Raises:
        TypeError: if *x* is neither an ``Expr`` nor a real number.
    """
    if isinstance(x, Expr):
        return x
    if isinstance(x, numbers.Real):
        return Const(float(x))
    raise TypeError(f"Cannot coerce {type(x).__name__} to Expr")
|
|
|
|
|
|
# -- leaf nodes ---------------------------------------------------------------
|
|
|
|
|
|
class Const(Expr):
    """A numeric constant, stored as a Python float."""

    __slots__ = ("value",)

    def __init__(self, value: float):
        self.value = float(value)

    def eval(self, env):
        return self.value

    def diff(self, var):
        # A constant's derivative is zero for every variable.
        return ZERO

    def simplify(self):
        return self

    def vars(self):
        return set()

    def to_code(self):
        """Return Python source that evaluates to ``self.value``.

        Negative values (including ``-0.0``) are parenthesized.  Python's
        unary minus binds more loosely than ``**``, so a bare ``-2.0`` used as
        a power base would parse as ``-(2.0 ** e)``.

        ``repr`` spells infinities and NaN as ``inf``/``nan``, which are not
        Python names.  They are therefore written as overflowing float
        literals that need no builtins: ``1e999`` is +inf and
        ``1e999 - 1e999`` is NaN.
        """
        v = self.value
        if math.isnan(v):
            return "(1e999 - 1e999)"
        if math.isinf(v):
            return "1e999" if v > 0 else "(-1e999)"
        text = repr(v)
        # copysign also flags -0.0, whose repr starts with '-'.
        return f"({text})" if math.copysign(1.0, v) < 0 else text

    def __repr__(self):
        return f"Const({self.value})"

    def __eq__(self, other):
        return isinstance(other, Const) and self.value == other.value

    def __hash__(self):
        return hash(("Const", self.value))
|
|
|
|
|
|
class Var(Expr):
    """A named variable whose value is looked up in the evaluation env."""

    __slots__ = ("name",)

    def __init__(self, name: str):
        self.name = name

    def eval(self, env):
        # A missing name raises whatever env[...] raises (KeyError for a dict).
        value = env[self.name]
        return value

    def diff(self, var):
        # d(name)/d(var) is 1 for the same variable, 0 otherwise.
        if var != self.name:
            return ZERO
        return ONE

    def simplify(self):
        return self

    def vars(self):
        names = set()
        names.add(self.name)
        return names

    def to_code(self):
        key = repr(self.name)
        return "env[" + key + "]"

    def __repr__(self):
        return "Var(" + repr(self.name) + ")"

    def __eq__(self, other):
        if not isinstance(other, Var):
            return False
        return self.name == other.name

    def __hash__(self):
        return hash(("Var", self.name))
|
|
|
|
|
|
# -- unary nodes --------------------------------------------------------------
|
|
|
|
|
|
class Neg(Expr):
    """Unary negation ``-child``."""

    __slots__ = ("child",)

    def __init__(self, child: Expr):
        self.child = child

    def eval(self, env):
        value = self.child.eval(env)
        return -value

    def diff(self, var):
        # d/dx (-f) = -(f')
        inner = self.child.diff(var)
        return Neg(inner)

    def simplify(self):
        inner = self.child.simplify()
        if isinstance(inner, Neg):
            # -(-f) collapses to f.
            return inner.child
        if isinstance(inner, Const):
            return Const(-inner.value)
        return Neg(inner)

    def vars(self):
        return self.child.vars()

    def to_code(self):
        return "(-" + self.child.to_code() + ")"

    def __repr__(self):
        return "Neg(" + repr(self.child) + ")"
|
|
|
|
|
|
class Sin(Expr):
    """Sine of the child expression (argument in radians, as for math.sin)."""

    __slots__ = ("child",)

    def __init__(self, child: Expr):
        self.child = child

    def eval(self, env):
        angle = self.child.eval(env)
        return math.sin(angle)

    def diff(self, var):
        # Chain rule: d/dx sin(f) = cos(f) * f'
        outer = Cos(self.child)
        return Mul(outer, self.child.diff(var))

    def simplify(self):
        arg = self.child.simplify()
        if not isinstance(arg, Const):
            return Sin(arg)
        return Const(math.sin(arg.value))

    def vars(self):
        return self.child.vars()

    def to_code(self):
        return "_sin(" + self.child.to_code() + ")"

    def __repr__(self):
        return "Sin(" + repr(self.child) + ")"
|
|
|
|
|
|
class Cos(Expr):
    """Cosine of the child expression (argument in radians, as for math.cos)."""

    __slots__ = ("child",)

    def __init__(self, child: Expr):
        self.child = child

    def eval(self, env):
        angle = self.child.eval(env)
        return math.cos(angle)

    def diff(self, var):
        # Chain rule: d/dx cos(f) = -sin(f) * f'
        outer = Neg(Sin(self.child))
        return Mul(outer, self.child.diff(var))

    def simplify(self):
        arg = self.child.simplify()
        if not isinstance(arg, Const):
            return Cos(arg)
        return Const(math.cos(arg.value))

    def vars(self):
        return self.child.vars()

    def to_code(self):
        return "_cos(" + self.child.to_code() + ")"

    def __repr__(self):
        return "Cos(" + repr(self.child) + ")"
|
|
|
|
|
|
class Sqrt(Expr):
    """Square root of the child expression."""

    __slots__ = ("child",)

    def __init__(self, child: Expr):
        self.child = child

    def eval(self, env):
        radicand = self.child.eval(env)
        return math.sqrt(radicand)

    def diff(self, var):
        # d/dx sqrt(f) = f' / (2 * sqrt(f)); evaluating this where f == 0 may
        # raise ZeroDivisionError.
        numerator = self.child.diff(var)
        denominator = Mul(Const(2.0), Sqrt(self.child))
        return Div(numerator, denominator)

    def simplify(self):
        arg = self.child.simplify()
        # Fold only constants that compare >= 0.  Negative values (where
        # math.sqrt raises ValueError) and NaN are left for eval.
        if isinstance(arg, Const) and arg.value >= 0:
            return Const(math.sqrt(arg.value))
        return Sqrt(arg)

    def vars(self):
        return self.child.vars()

    def to_code(self):
        return "_sqrt(" + self.child.to_code() + ")"

    def __repr__(self):
        return "Sqrt(" + repr(self.child) + ")"
|
|
|
|
|
|
# -- binary nodes -------------------------------------------------------------
|
|
|
|
|
|
class Add(Expr):
    """Binary sum ``a + b``."""

    __slots__ = ("a", "b")

    def __init__(self, a: Expr, b: Expr):
        self.a = a
        self.b = b

    def eval(self, env):
        left = self.a.eval(env)
        right = self.b.eval(env)
        return left + right

    def diff(self, var):
        # Sum rule: (a + b)' = a' + b'
        da = self.a.diff(var)
        db = self.b.diff(var)
        return Add(da, db)

    def simplify(self):
        lhs = self.a.simplify()
        rhs = self.b.simplify()
        lhs_const = isinstance(lhs, Const)
        rhs_const = isinstance(rhs, Const)
        if lhs_const and rhs_const:
            return Const(lhs.value + rhs.value)
        # Drop an additive identity on either side.
        if lhs_const and lhs.value == 0.0:
            return rhs
        if rhs_const and rhs.value == 0.0:
            return lhs
        return Add(lhs, rhs)

    def vars(self):
        return set().union(self.a.vars(), self.b.vars())

    def to_code(self):
        return "({} + {})".format(self.a.to_code(), self.b.to_code())

    def __repr__(self):
        return "Add({!r}, {!r})".format(self.a, self.b)
|
|
|
|
|
|
class Sub(Expr):
    """Binary difference ``a - b``."""

    __slots__ = ("a", "b")

    def __init__(self, a: Expr, b: Expr):
        self.a = a
        self.b = b

    def eval(self, env):
        return self.a.eval(env) - self.b.eval(env)

    def diff(self, var):
        # Difference rule: (a - b)' = a' - b'
        return Sub(self.a.diff(var), self.b.diff(var))

    def simplify(self):
        """Simplify children, then fold constants and drop zero operands.

        Rules (in order): const - const folds; ``a - 0`` gives ``a``;
        ``0 - b`` gives the negation of ``b``.
        """
        a = self.a.simplify()
        b = self.b.simplify()
        if isinstance(a, Const) and isinstance(b, Const):
            return Const(a.value - b.value)
        if isinstance(b, Const) and b.value == 0.0:
            return a
        if isinstance(a, Const) and a.value == 0.0:
            # b is already simplified and cannot be a Const here (const - const
            # folded above), so the only negation rule left is -(-x) -> x.
            # Apply it directly instead of Neg(b).simplify(), which would
            # re-walk and rebuild the whole subtree b.
            return b.child if isinstance(b, Neg) else Neg(b)
        return Sub(a, b)

    def vars(self):
        return self.a.vars() | self.b.vars()

    def to_code(self):
        return f"({self.a.to_code()} - {self.b.to_code()})"

    def __repr__(self):
        return f"Sub({self.a!r}, {self.b!r})"
|
|
|
|
|
|
class Mul(Expr):
    """Binary product ``a * b``."""

    __slots__ = ("a", "b")

    def __init__(self, a: Expr, b: Expr):
        self.a = a
        self.b = b

    def eval(self, env):
        return self.a.eval(env) * self.b.eval(env)

    def diff(self, var):
        # product rule: a'b + ab'
        return Add(Mul(self.a.diff(var), self.b), Mul(self.a, self.b.diff(var)))

    def simplify(self):
        """Simplify children, then fold constants and drop identity factors.

        Rules (in order): const * const folds; a factor of 0 gives ZERO (even
        if the other factor would evaluate to inf/NaN); a factor of 1 is
        dropped; a factor of -1 becomes a negation.
        """
        a = self.a.simplify()
        b = self.b.simplify()
        if isinstance(a, Const) and isinstance(b, Const):
            return Const(a.value * b.value)
        if isinstance(a, Const) and a.value == 0.0:
            return ZERO
        if isinstance(b, Const) and b.value == 0.0:
            return ZERO
        if isinstance(a, Const) and a.value == 1.0:
            return b
        if isinstance(b, Const) and b.value == 1.0:
            return a
        # For a -1 factor, the other operand is already simplified and is not a
        # Const (const * const folded above).  The only negation rule left is
        # -(-x) -> x, so apply it directly rather than calling
        # Neg(...).simplify(), which would re-walk the whole subtree.
        if isinstance(a, Const) and a.value == -1.0:
            return b.child if isinstance(b, Neg) else Neg(b)
        if isinstance(b, Const) and b.value == -1.0:
            return a.child if isinstance(a, Neg) else Neg(a)
        return Mul(a, b)

    def vars(self):
        return self.a.vars() | self.b.vars()

    def to_code(self):
        return f"({self.a.to_code()} * {self.b.to_code()})"

    def __repr__(self):
        return f"Mul({self.a!r}, {self.b!r})"
|
|
|
|
|
|
class Div(Expr):
    """Binary quotient ``a / b``; eval raises ZeroDivisionError when b is 0."""

    __slots__ = ("a", "b")

    def __init__(self, a: Expr, b: Expr):
        self.a = a
        self.b = b

    def eval(self, env):
        numerator = self.a.eval(env)
        denominator = self.b.eval(env)
        return numerator / denominator

    def diff(self, var):
        # Quotient rule: (a/b)' = (a'b - ab') / b^2, with b^2 written as b*b.
        numerator = Sub(
            Mul(self.a.diff(var), self.b),
            Mul(self.a, self.b.diff(var)),
        )
        return Div(numerator, Mul(self.b, self.b))

    def simplify(self):
        num = self.a.simplify()
        den = self.b.simplify()
        num_const = isinstance(num, Const)
        den_const = isinstance(den, Const)
        # Constant folding is skipped for a zero denominator so that the
        # division error surfaces at eval time, not during simplification.
        if num_const and den_const and den.value != 0.0:
            return Const(num.value / den.value)
        if num_const and num.value == 0.0:
            return ZERO
        if den_const and den.value == 1.0:
            return num
        return Div(num, den)

    def vars(self):
        return set().union(self.a.vars(), self.b.vars())

    def to_code(self):
        return "({} / {})".format(self.a.to_code(), self.b.to_code())

    def __repr__(self):
        return "Div({!r}, {!r})".format(self.a, self.b)
|
|
|
|
|
|
class Pow(Expr):
    """Power node ``base ** exp``."""

    __slots__ = ("base", "exp")

    def __init__(self, base: Expr, exp: Expr):
        self.base = base
        self.exp = exp

    def eval(self, env):
        # Python's float ** can return a complex number (negative base with a
        # non-integer exponent) or raise ZeroDivisionError / OverflowError.
        return self.base.eval(env) ** self.exp.eval(env)

    def diff(self, var):
        """Differentiate with respect to *var* using the power rule.

        ``d/dx f^n = n * f^(n-1) * f'`` holds whenever the exponent does not
        depend on *var*.  This includes unfolded constant expressions and
        exponents built from other variables.  An exponent that depends on
        *var* needs the general rule ``f^g * (g' * ln(f) + g * f'/f)``, which
        requires a logarithm node this module does not provide.

        Raises:
            NotImplementedError: if the exponent depends on *var*.
        """
        exp = self.exp
        if isinstance(exp, Const):
            n = exp.value
            return Mul(
                Mul(Const(n), Pow(self.base, Const(n - 1.0))), self.base.diff(var)
            )
        if var not in exp.vars():
            # Exponent is constant with respect to var: same rule, symbolically.
            return Mul(Mul(exp, Pow(self.base, Sub(exp, ONE))), self.base.diff(var))
        raise NotImplementedError("diff of Pow with non-constant exponent")

    def simplify(self):
        base = self.base.simplify()
        exp = self.exp.simplify()
        if isinstance(base, Const) and isinstance(exp, Const):
            # Fold only when Python yields a real float.  Otherwise keep the
            # node so that eval reports the problem exactly as it would
            # unsimplified, instead of crashing here (a complex result made
            # Const() raise TypeError; 0**-1 and overflow raise outright).
            try:
                folded = base.value**exp.value
            except (ZeroDivisionError, OverflowError):
                return Pow(base, exp)
            if isinstance(folded, float):
                return Const(folded)
            return Pow(base, exp)
        if isinstance(exp, Const):
            if exp.value == 0.0:
                return ONE
            if exp.value == 1.0:
                return base
            if exp.value == 2.0:
                # base is already simplified and is not a Const here, so
                # Mul.simplify would have nothing to fold: build base*base
                # directly (sharing the one simplified subtree).
                return Mul(base, base)
        return Pow(base, exp)

    def vars(self):
        return self.base.vars() | self.exp.vars()

    def to_code(self):
        # Parenthesize the base: Python's unary minus binds more loosely than
        # **, so a base emitted as a bare negative literal ("-2.0") would
        # otherwise be parsed as -(2.0 ** exp).
        return f"(({self.base.to_code()}) ** {self.exp.to_code()})"

    def __repr__(self):
        return f"Pow({self.base!r}, {self.exp!r})"
|
|
|
|
|
|
# -- sentinels ----------------------------------------------------------------
|
|
|
|
# Shared constant nodes reused by diff() and simplify() across the module
# (e.g. Const.diff returns ZERO, Var.diff returns ONE).  They are shared
# instances, so treat them as immutable.
ZERO, ONE = Const(0.0), Const(1.0)
|