"""Immutable expression DAG with eval, symbolic differentiation, and simplification.

No method in this module mutates a node after construction: ``diff``,
``simplify`` and the arithmetic operators always build new nodes, so
subtrees (and the ``ZERO``/``ONE`` sentinels) can be shared freely.
"""

from __future__ import annotations

import math
from typing import Dict


class Expr:
    """Base class for all expression nodes.

    Subclasses implement ``eval``, ``diff`` and ``vars``; ``simplify``
    defaults to returning the node unchanged.  Python arithmetic operators
    are overloaded so expressions can be written naturally, e.g.
    ``Var("x") * 2 + 1``; plain ``int``/``float`` operands are wrapped in
    ``Const`` by ``_wrap``.
    """

    __slots__ = ()

    def eval(self, env: Dict[str, float]) -> float:
        """Evaluate numerically; *env* maps variable names to values."""
        raise NotImplementedError

    def diff(self, var: str) -> Expr:
        """Return the (unsimplified) partial derivative w.r.t. *var*."""
        raise NotImplementedError

    def simplify(self) -> Expr:
        """Return an equivalent, usually smaller, expression."""
        return self

    def vars(self) -> set[str]:
        """Return the set of variable names in this expression."""
        raise NotImplementedError

    # -- operator overloads --------------------------------------------------
    def __add__(self, other):
        return Add(self, _wrap(other))

    def __radd__(self, other):
        return Add(_wrap(other), self)

    def __sub__(self, other):
        return Sub(self, _wrap(other))

    def __rsub__(self, other):
        return Sub(_wrap(other), self)

    def __mul__(self, other):
        return Mul(self, _wrap(other))

    def __rmul__(self, other):
        return Mul(_wrap(other), self)

    def __truediv__(self, other):
        return Div(self, _wrap(other))

    def __rtruediv__(self, other):
        return Div(_wrap(other), self)

    def __neg__(self):
        return Neg(self)

    def __pow__(self, other):
        return Pow(self, _wrap(other))

    def __rpow__(self, other):
        return Pow(_wrap(other), self)


def _wrap(x) -> Expr:
    """Coerce a number to Const.

    Expressions are returned unchanged; ``int``/``float`` (including
    ``bool``, an ``int`` subclass) become ``Const``.

    Raises:
        TypeError: for any other type.
    """
    if isinstance(x, Expr):
        return x
    if isinstance(x, (int, float)):
        return Const(float(x))
    raise TypeError(f"Cannot coerce {type(x).__name__} to Expr")


# -- leaf nodes ---------------------------------------------------------------


class Const(Expr):
    """A numeric constant; the value is stored as ``float``."""

    __slots__ = ("value",)

    def __init__(self, value: float):
        self.value = float(value)

    def eval(self, env):
        return self.value

    def diff(self, var):
        return ZERO

    def simplify(self):
        return self

    def vars(self):
        return set()

    def __repr__(self):
        return f"Const({self.value})"

    def __eq__(self, other):
        return isinstance(other, Const) and self.value == other.value

    def __hash__(self):
        return hash(("Const", self.value))


class Var(Expr):
    """A named variable; ``eval`` looks it up in the environment."""

    __slots__ = ("name",)

    def __init__(self, name: str):
        self.name = name

    def eval(self, env):
        # Raises KeyError if the variable is not bound in *env*.
        return env[self.name]

    def diff(self, var):
        return ONE if var == self.name else ZERO

    def simplify(self):
        return self

    def vars(self):
        return {self.name}

    def __repr__(self):
        return f"Var({self.name!r})"

    def __eq__(self, other):
        return isinstance(other, Var) and self.name == other.name

    def __hash__(self):
        return hash(("Var", self.name))


# -- unary nodes --------------------------------------------------------------


class Neg(Expr):
    """Arithmetic negation ``-child``."""

    __slots__ = ("child",)

    def __init__(self, child: Expr):
        self.child = child

    def eval(self, env):
        return -self.child.eval(env)

    def diff(self, var):
        return Neg(self.child.diff(var))

    def simplify(self):
        """Fold constants and cancel double negation ``--f -> f``."""
        c = self.child.simplify()
        if isinstance(c, Const):
            return Const(-c.value)
        if isinstance(c, Neg):
            return c.child
        return Neg(c)

    def vars(self):
        return self.child.vars()

    def __repr__(self):
        return f"Neg({self.child!r})"


class Sin(Expr):
    """``sin(child)`` (radians)."""

    __slots__ = ("child",)

    def __init__(self, child: Expr):
        self.child = child

    def eval(self, env):
        return math.sin(self.child.eval(env))

    def diff(self, var):
        # d/dx sin(f) = cos(f) * f'
        return Mul(Cos(self.child), self.child.diff(var))

    def simplify(self):
        c = self.child.simplify()
        if isinstance(c, Const):
            return Const(math.sin(c.value))
        return Sin(c)

    def vars(self):
        return self.child.vars()

    def __repr__(self):
        return f"Sin({self.child!r})"


class Cos(Expr):
    """``cos(child)`` (radians)."""

    __slots__ = ("child",)

    def __init__(self, child: Expr):
        self.child = child

    def eval(self, env):
        return math.cos(self.child.eval(env))

    def diff(self, var):
        # d/dx cos(f) = -sin(f) * f'
        return Mul(Neg(Sin(self.child)), self.child.diff(var))

    def simplify(self):
        c = self.child.simplify()
        if isinstance(c, Const):
            return Const(math.cos(c.value))
        return Cos(c)

    def vars(self):
        return self.child.vars()

    def __repr__(self):
        return f"Cos({self.child!r})"


class Sqrt(Expr):
    """Real square root; ``eval`` raises ValueError for negative input."""

    __slots__ = ("child",)

    def __init__(self, child: Expr):
        self.child = child

    def eval(self, env):
        return math.sqrt(self.child.eval(env))

    def diff(self, var):
        # d/dx sqrt(f) = f' / (2 * sqrt(f))
        return Div(self.child.diff(var), Mul(Const(2.0), Sqrt(self.child)))

    def simplify(self):
        # Only non-negative constants are folded; a negative constant is
        # left in place so the domain error surfaces at eval time.
        c = self.child.simplify()
        if isinstance(c, Const) and c.value >= 0:
            return Const(math.sqrt(c.value))
        return Sqrt(c)

    def vars(self):
        return self.child.vars()

    def __repr__(self):
        return f"Sqrt({self.child!r})"


# -- binary nodes -------------------------------------------------------------


class Add(Expr):
    """Sum ``a + b``."""

    __slots__ = ("a", "b")

    def __init__(self, a: Expr, b: Expr):
        self.a = a
        self.b = b

    def eval(self, env):
        return self.a.eval(env) + self.b.eval(env)

    def diff(self, var):
        return Add(self.a.diff(var), self.b.diff(var))

    def simplify(self):
        """Fold constants and drop additive zeros."""
        a = self.a.simplify()
        b = self.b.simplify()
        if isinstance(a, Const) and isinstance(b, Const):
            return Const(a.value + b.value)
        if isinstance(a, Const) and a.value == 0.0:
            return b
        if isinstance(b, Const) and b.value == 0.0:
            return a
        return Add(a, b)

    def vars(self):
        return self.a.vars() | self.b.vars()

    def __repr__(self):
        return f"Add({self.a!r}, {self.b!r})"


class Sub(Expr):
    """Difference ``a - b``."""

    __slots__ = ("a", "b")

    def __init__(self, a: Expr, b: Expr):
        self.a = a
        self.b = b

    def eval(self, env):
        return self.a.eval(env) - self.b.eval(env)

    def diff(self, var):
        return Sub(self.a.diff(var), self.b.diff(var))

    def simplify(self):
        """Fold constants; ``f - 0 -> f`` and ``0 - f -> -f``."""
        a = self.a.simplify()
        b = self.b.simplify()
        if isinstance(a, Const) and isinstance(b, Const):
            return Const(a.value - b.value)
        if isinstance(b, Const) and b.value == 0.0:
            return a
        if isinstance(a, Const) and a.value == 0.0:
            return Neg(b).simplify()
        return Sub(a, b)

    def vars(self):
        return self.a.vars() | self.b.vars()

    def __repr__(self):
        return f"Sub({self.a!r}, {self.b!r})"


class Mul(Expr):
    """Product ``a * b``."""

    __slots__ = ("a", "b")

    def __init__(self, a: Expr, b: Expr):
        self.a = a
        self.b = b

    def eval(self, env):
        return self.a.eval(env) * self.b.eval(env)

    def diff(self, var):
        # product rule: a'b + ab'
        return Add(Mul(self.a.diff(var), self.b), Mul(self.a, self.b.diff(var)))

    def simplify(self):
        """Fold constants and apply the 0, 1 and -1 identities."""
        a = self.a.simplify()
        b = self.b.simplify()
        if isinstance(a, Const) and isinstance(b, Const):
            return Const(a.value * b.value)
        if isinstance(a, Const) and a.value == 0.0:
            return ZERO
        if isinstance(b, Const) and b.value == 0.0:
            return ZERO
        if isinstance(a, Const) and a.value == 1.0:
            return b
        if isinstance(b, Const) and b.value == 1.0:
            return a
        if isinstance(a, Const) and a.value == -1.0:
            return Neg(b).simplify()
        if isinstance(b, Const) and b.value == -1.0:
            return Neg(a).simplify()
        return Mul(a, b)

    def vars(self):
        return self.a.vars() | self.b.vars()

    def __repr__(self):
        return f"Mul({self.a!r}, {self.b!r})"


class Div(Expr):
    """Quotient ``a / b``; ``eval`` raises ZeroDivisionError when b == 0."""

    __slots__ = ("a", "b")

    def __init__(self, a: Expr, b: Expr):
        self.a = a
        self.b = b

    def eval(self, env):
        return self.a.eval(env) / self.b.eval(env)

    def diff(self, var):
        # quotient rule: (a'b - ab') / b^2
        return Div(
            Sub(Mul(self.a.diff(var), self.b), Mul(self.a, self.b.diff(var))),
            Mul(self.b, self.b),
        )

    def simplify(self):
        """Fold constants; ``0 / f -> 0`` and ``f / 1 -> f``.

        Division by a literal zero is never simplified away: the node is
        kept so ``eval`` still reports it (previously ``0 / 0`` folded to 0).
        """
        a = self.a.simplify()
        b = self.b.simplify()
        if isinstance(b, Const) and b.value == 0.0:
            return Div(a, b)
        if isinstance(a, Const) and isinstance(b, Const):
            return Const(a.value / b.value)
        if isinstance(a, Const) and a.value == 0.0:
            return ZERO
        if isinstance(b, Const) and b.value == 1.0:
            return a
        return Div(a, b)

    def vars(self):
        return self.a.vars() | self.b.vars()

    def __repr__(self):
        return f"Div({self.a!r}, {self.b!r})"


class Pow(Expr):
    """Power ``base ** exp`` over the reals."""

    __slots__ = ("base", "exp")

    def __init__(self, base: Expr, exp: Expr):
        self.base = base
        self.exp = exp

    def eval(self, env):
        """Evaluate ``base ** exp``.

        Raises:
            ValueError: if the result is not real (negative base with a
                fractional exponent), matching ``Sqrt``'s domain error.
            ZeroDivisionError: for ``0 ** negative`` (from Python's ``**``).
        """
        result = self.base.eval(env) ** self.exp.eval(env)
        if isinstance(result, complex):
            raise ValueError("math domain error")
        return result

    def diff(self, var):
        """Differentiate with respect to *var*.

        Supported cases:

        * constant exponent ``n``: ``n * f**(n-1) * f'``; for ``n == 0`` the
          result is ``ZERO`` directly, avoiding a ``f**-1`` factor that would
          fail to evaluate at ``f == 0``;
        * exponent ``g`` independent of *var*: ``g * f**(g-1) * f'``;
        * positive constant base ``c``: ``c**g * ln(c) * g'``.

        Raises:
            NotImplementedError: when the exponent depends on *var* and the
                base is not a positive constant (the general rule
                ``f**g * (g' ln f + g f'/f)`` needs a logarithm node).
        """
        base, exp = self.base, self.exp
        if isinstance(exp, Const):
            n = exp.value
            if n == 0.0:
                return ZERO
            return Mul(Mul(Const(n), Pow(base, Const(n - 1.0))), base.diff(var))
        if var not in exp.vars():
            return Mul(Mul(exp, Pow(base, Sub(exp, ONE))), base.diff(var))
        if isinstance(base, Const) and base.value > 0.0:
            return Mul(Mul(self, Const(math.log(base.value))), exp.diff(var))
        raise NotImplementedError("diff of Pow with non-constant exponent")

    def simplify(self):
        """Fold real constants and apply ``f**0``, ``f**1``, ``f**2`` rules.

        A constant power whose value is not a real float (negative base with a
        fractional exponent, ``0 ** negative``, overflow) is left unfolded so
        that ``simplify`` never raises and ``eval`` behaves as before.
        """
        base = self.base.simplify()
        exp = self.exp.simplify()
        if isinstance(base, Const) and isinstance(exp, Const):
            try:
                value = base.value**exp.value
            except (ZeroDivisionError, OverflowError):
                value = None
            # float ** float is either float or complex; only fold reals.
            if isinstance(value, float):
                return Const(value)
            return Pow(base, exp)
        if isinstance(exp, Const):
            if exp.value == 0.0:
                return ONE
            if exp.value == 1.0:
                return base
            if exp.value == 2.0:
                return Mul(base, base).simplify()
        return Pow(base, exp)

    def vars(self):
        return self.base.vars() | self.exp.vars()

    def __repr__(self):
        return f"Pow({self.base!r}, {self.exp!r})"


# -- sentinels ----------------------------------------------------------------

ZERO = Const(0.0)
ONE = Const(1.0)