"""Parameter table mapping named variables to Expr Var nodes and current values."""

from __future__ import annotations

from typing import Dict, List

import numpy as np

from .expr import Var


class ParamTable:
    """Central registry of solver variables.

    Each parameter has a name, a current numeric value, an associated
    :class:`Var` expression node, and a fixed/free flag.  Grounded body
    parameters are marked fixed so the pre-pass can substitute them as
    constants.

    Free parameters are kept in insertion order; that order defines the
    element layout of the arrays exchanged through :meth:`get_free_vector`
    and :meth:`set_free_vector`.
    """

    def __init__(self):
        self._vars: Dict[str, Var] = {}
        self._values: Dict[str, float] = {}
        self._fixed: set[str] = set()
        self._free_order: List[str] = []  # insertion-ordered free names

    def add(self, name: str, value: float = 0.0, fixed: bool = False) -> Var:
        """Create a parameter and return its Var node.

        Args:
            name: Unique parameter name.
            value: Initial numeric value.
            fixed: If true the parameter is registered as fixed; otherwise
                it is appended to the end of the free-parameter order.

        Returns:
            The newly created :class:`Var` node for ``name``.

        Raises:
            ValueError: If ``name`` is already registered.
        """
        if name in self._vars:
            raise ValueError(f"Duplicate parameter: {name}")
        v = Var(name)
        self._vars[name] = v
        self._values[name] = value
        if fixed:
            self._fixed.add(name)
        else:
            self._free_order.append(name)
        return v

    def get_var(self, name: str) -> Var:
        """Return the :class:`Var` node for ``name`` (KeyError if unknown)."""
        return self._vars[name]

    def is_fixed(self, name: str) -> bool:
        """Return True if ``name`` has been marked fixed."""
        return name in self._fixed

    def fix(self, name: str) -> None:
        """Mark a parameter as fixed and remove it from the free list.

        Fixing an already-fixed parameter is a no-op.

        Raises:
            KeyError: If ``name`` was never registered via :meth:`add`.
                Without this check an unknown name would be recorded as
                fixed, and a later ``add(name)`` would leave it both fixed
                and in the free order.
        """
        if name not in self._vars:
            raise KeyError(f"Unknown parameter: {name}")
        self._fixed.add(name)
        if name in self._free_order:
            self._free_order.remove(name)

    def get_env(self) -> Dict[str, float]:
        """Return a snapshot of all current values (for Expr.eval).

        The returned dict is a copy; mutating it does not affect the table.
        """
        return dict(self._values)

    def free_names(self) -> List[str]:
        """Return ordered list of free (non-fixed) parameter names (a copy)."""
        return list(self._free_order)

    def n_free(self) -> int:
        """Return the number of free (non-fixed) parameters."""
        return len(self._free_order)

    def get_value(self, name: str) -> float:
        """Return the current value of ``name`` (KeyError if unknown)."""
        return self._values[name]

    def set_value(self, name: str, value: float) -> None:
        """Set the current value of ``name``.

        NOTE(review): no existence check is performed, so an unknown name
        creates a value entry with no associated Var -- confirm whether
        callers rely on this before tightening it.
        """
        self._values[name] = value

    def get_free_vector(self) -> np.ndarray:
        """Current free-parameter values as a 1-D float64 array.

        Element ``i`` corresponds to ``free_names()[i]``.
        """
        return np.array(
            [self._values[n] for n in self._free_order], dtype=np.float64
        )

    def set_free_vector(self, vec: np.ndarray) -> None:
        """Bulk-update free parameters from a 1-D array.

        Args:
            vec: Array-like with exactly ``n_free()`` elements, ordered like
                :meth:`free_names`.

        Raises:
            ValueError: If ``vec`` does not contain exactly one value per
                free parameter.  The size is validated before any value is
                written, so a mismatched vector never leaves the table
                partially updated.
        """
        values = np.asarray(vec, dtype=np.float64).ravel()
        expected = len(self._free_order)
        if values.size != expected:
            raise ValueError(
                f"Expected {expected} free-parameter values, got {values.size}"
            )
        for name, value in zip(self._free_order, values):
            self._values[name] = float(value)