"""Parameter table mapping named variables to Expr Var nodes and current values."""

from __future__ import annotations

from typing import Dict, List

import numpy as np

from .expr import Var


class ParamTable:
    """Central registry of solver variables.

    Each parameter has a name, a current numeric value, an associated
    :class:`Var` expression node, and a fixed/free flag.  Grounded body
    parameters are marked fixed so the pre-pass can substitute them as
    constants.

    Free parameters keep their insertion order; that order defines the
    layout of the vectors exchanged through :meth:`get_free_vector` and
    :meth:`set_free_vector`.
    """

    def __init__(self) -> None:
        self._vars: Dict[str, Var] = {}
        self._values: Dict[str, float] = {}
        self._fixed: set[str] = set()
        self._free_order: List[str] = []  # insertion-ordered free names

    def add(self, name: str, value: float = 0.0, fixed: bool = False) -> Var:
        """Create a parameter and return its Var node.

        Raises:
            ValueError: if a parameter called *name* already exists.
        """
        if name in self._vars:
            raise ValueError(f"Duplicate parameter: {name}")
        var = Var(name)
        self._vars[name] = var
        self._values[name] = value
        if fixed:
            self._fixed.add(name)
        else:
            self._free_order.append(name)
        return var

    def get_var(self, name: str) -> Var:
        """Return the Var node registered under *name* (KeyError if unknown)."""
        return self._vars[name]

    def is_fixed(self, name: str) -> bool:
        """Return True if *name* is currently marked fixed."""
        return name in self._fixed

    def fix(self, name: str) -> None:
        """Mark a parameter as fixed and remove it from the free list.

        Raises:
            KeyError: if *name* was never registered with :meth:`add`.
        """
        # Reject unknown names up front.  Recording them in ``_fixed`` would
        # let a later ``unfix`` push a name with no stored value into the
        # free list, which then breaks ``get_free_vector``.
        if name not in self._vars:
            raise KeyError(name)
        self._fixed.add(name)
        if name in self._free_order:
            self._free_order.remove(name)

    def unfix(self, name: str) -> None:
        """Restore a fixed parameter to free status.

        Does nothing when *name* is not currently fixed.  A re-freed
        parameter is appended at the end of the free ordering.
        """
        if name not in self._fixed:
            return
        self._fixed.discard(name)
        if name not in self._free_order:
            self._free_order.append(name)

    def get_env(self) -> Dict[str, float]:
        """Return a snapshot of all current values (for Expr.eval)."""
        return dict(self._values)

    def env_ref(self) -> Dict[str, float]:
        """Return a direct reference to the internal values dict.

        Faster than :meth:`get_env` (no copy).  Safe when the caller only
        reads during evaluation and mutates via :meth:`set_free_vector`.
        """
        return self._values

    def free_names(self) -> List[str]:
        """Return ordered list of free (non-fixed) parameter names."""
        return list(self._free_order)

    def n_free(self) -> int:
        """Return the number of free parameters."""
        return len(self._free_order)

    def get_value(self, name: str) -> float:
        """Return the current value of *name* (KeyError if unknown)."""
        return self._values[name]

    def set_value(self, name: str, value: float) -> None:
        """Store *value* as the current value of *name*."""
        self._values[name] = value

    def get_free_vector(self) -> np.ndarray:
        """Current free-parameter values as a 1-D float64 array."""
        return np.array(
            [self._values[name] for name in self._free_order],
            dtype=np.float64,
        )

    def set_free_vector(self, vec: np.ndarray) -> None:
        """Bulk-update free parameters from a 1-D array.

        Entry ``i`` of *vec* is assigned to the ``i``-th free parameter.
        Entries beyond the number of free parameters are ignored.

        Raises:
            IndexError: if *vec* has fewer entries than there are free
                parameters.  No value is modified in that case.
        """
        names = self._free_order
        # Convert every entry before touching the table, so a short vector
        # or a non-numeric entry cannot leave it half-updated.
        converted = [float(vec[i]) for i in range(len(names))]
        self._values.update(zip(names, converted))

    def snapshot(self) -> Dict[str, float]:
        """Capture current values as a checkpoint."""
        return dict(self._values)

    def restore(self, snap: Dict[str, float]) -> None:
        """Restore parameter values from a checkpoint.

        Names in *snap* that are not present in the table are skipped.
        Only values are restored; fixed/free flags are left untouched.
        """
        for name, value in snap.items():
            if name in self._values:
                self._values[name] = value

    def movement_cost(
        self,
        start: Dict[str, float],
        weights: Dict[str, float] | None = None,
    ) -> float:
        """Weighted sum of squared displacements from start.

        Only free parameters contribute.  A parameter missing from *start*
        contributes zero, and a parameter missing from *weights* (or any
        parameter when *weights* is empty or None) gets weight 1.0.
        """
        cost = 0.0
        for name in self._free_order:
            current = self._values[name]
            delta = current - start.get(name, current)
            weight = weights.get(name, 1.0) if weights else 1.0
            cost += delta * delta * weight
        return cost