"""Pre-solve passes to reduce the system before Newton-Raphson.

1. Substitution pass — replace fixed-parameter Var nodes with Const values.
2. Single-equation pass — if a residual mentions exactly one free variable,
   solve it analytically (when possible) and fix that variable.
"""

from __future__ import annotations

import math
from typing import List

from .expr import (
    ZERO,
    Add,
    Const,
    Cos,
    Div,
    Expr,
    Mul,
    Neg,
    Pow,
    Sin,
    Sqrt,
    Sub,
    Var,
)
from .params import ParamTable

# Derivative magnitudes below this are treated as zero: the residual is
# considered not to depend usably on the variable, so no linear solve is done.
_LINEAR_COEFF_EPS = 1e-15


def substitution_pass(residuals: List[Expr], params: ParamTable) -> List[Expr]:
    """Replace fixed Var nodes with their constant values.

    Args:
        residuals: Residual expressions (each is meant to equal zero).
        params: Parameter table supplying current values and fixed flags.

    Returns:
        A new list of residuals.  When at least one parameter is fixed, each
        residual has the fixed parameters replaced by ``Const`` nodes and is
        simplified.  When nothing is fixed, the residuals are returned
        unchanged (in a fresh list).
    """
    env = params.get_env()
    fixed = {name for name in env if params.is_fixed(name)}
    if not fixed:
        # Copy so the caller never aliases its input list (docstring promise).
        return list(residuals)
    return [_substitute(r, env, fixed).simplify() for r in residuals]


def _substitute(expr: Expr, env: dict[str, float], fixed: set[str]) -> Expr:
    """Recursively replace Var nodes whose name is in *fixed* with Const values.

    Args:
        expr: Expression tree to rewrite.
        env: Mapping from variable name to current value.
        fixed: Names to replace; each must be a key of *env*.

    Returns:
        A rebuilt expression.  Node types not handled below are returned
        unchanged, so fixed Vars nested inside them are NOT substituted.
    """
    if isinstance(expr, Const):
        return expr
    if isinstance(expr, Var):
        if expr.name in fixed:
            return Const(env[expr.name])
        return expr
    if isinstance(expr, Neg):
        return Neg(_substitute(expr.child, env, fixed))
    if isinstance(expr, Add):
        return Add(_substitute(expr.a, env, fixed), _substitute(expr.b, env, fixed))
    if isinstance(expr, Sub):
        return Sub(_substitute(expr.a, env, fixed), _substitute(expr.b, env, fixed))
    if isinstance(expr, Mul):
        return Mul(_substitute(expr.a, env, fixed), _substitute(expr.b, env, fixed))
    if isinstance(expr, Div):
        return Div(_substitute(expr.a, env, fixed), _substitute(expr.b, env, fixed))
    if isinstance(expr, Pow):
        return Pow(
            _substitute(expr.base, env, fixed), _substitute(expr.exp, env, fixed)
        )
    if isinstance(expr, Sin):
        return Sin(_substitute(expr.child, env, fixed))
    if isinstance(expr, Cos):
        return Cos(_substitute(expr.child, env, fixed))
    if isinstance(expr, Sqrt):
        return Sqrt(_substitute(expr.child, env, fixed))
    # NOTE(review): any other node type is passed through untouched; if new
    # node types are added to .expr they need a branch here.
    return expr


def single_equation_pass(residuals: List[Expr], params: ParamTable) -> List[Expr]:
    """Solve residuals that depend on a single free variable.

    Handles linear cases: a*x + b = 0 → x = -b/a.  Each solved variable is
    given its solution value and marked fixed in *params*, and the residual
    that determined it is dropped.  Repeats until a full pass solves nothing.

    Args:
        residuals: Residual expressions (each is meant to equal zero).
        params: Parameter table; mutated in place (values set, names fixed).

    Returns:
        The remaining unsolved residuals, with every variable fixed by this
        pass substituted as a constant.
    """
    remaining = list(residuals)
    changed = True
    while changed:
        changed = False
        solved_this_pass: set[str] = set()
        kept: List[Expr] = []
        free = set(params.free_names())
        for r in remaining:
            free_vars = r.vars() & free
            if len(free_vars) == 1:
                name = next(iter(free_vars))
                if _try_solve_linear(r, name, params):
                    params.fix(name)
                    solved_this_pass.add(name)
                    # Fixing changes the free set; later residuals in this
                    # pass must see the updated snapshot.
                    free = set(params.free_names())
                    changed = True
                    # r is satisfied by construction, so it is dropped.
                    continue
            kept.append(r)
        if solved_this_pass:
            # Substitute into ALL kept residuals, both those seen before and
            # after the solve point, so no stale Var nodes of fixed
            # variables survive in the output.
            env = params.get_env()
            kept = [_substitute(r, env, solved_this_pass).simplify() for r in kept]
        remaining = kept
    return remaining


def _try_solve_linear(expr: Expr, var_name: str, params: ParamTable) -> bool:
    """Try to solve expr == 0 for var_name assuming linear dependence.

    If expr = a*var + b where a, b are constants, sets var = -b/a (computed as
    a Newton step from the current value, which is exact for linear expr).

    Args:
        expr: Residual expression.
        var_name: The single free variable to solve for.
        params: Parameter table; on success the variable's value is updated
            (the caller is responsible for marking it fixed).

    Returns:
        True if a finite solution was found and written to *params*;
        False otherwise (the variable's value is left untouched).
    """
    env = params.get_env()
    # The derivative w.r.t. var must be constant for expr to be linear in var.
    deriv = expr.diff(var_name).simplify()
    # Any remaining free variable (including var itself) means non-constant.
    if deriv.vars() & set(params.free_names()):
        return False
    a = deriv.eval(env)
    # Reject NaN/inf slopes as well as (near-)zero ones; NaN would otherwise
    # slip past the magnitude test since NaN comparisons are always False.
    if not math.isfinite(a) or abs(a) < _LINEAR_COEFF_EPS:
        return False
    f = expr.eval(env)
    # Linear: f = a*x_cur + b, so x_cur - f/a == -b/a.
    x_new = params.get_value(var_name) - f / a
    if not math.isfinite(x_new):
        # Don't pin the variable to garbage; leave it free for Newton.
        return False
    params.set_value(var_name, x_new)
    return True