feat(solver): graph decomposition for cluster-by-cluster solving (phase 3)

Add a Python decomposition layer using NetworkX that partitions the
constraint graph into biconnected components (rigid clusters), orders
them via a block-cut tree, and solves each cluster independently.
Articulation-point bodies propagate as boundary conditions between
clusters.

New module kindred_solver/decompose.py:
- DOF table mapping BaseJointKind to residual counts
- Constraint graph construction (nx.MultiGraph)
- Biconnected component detection + articulation points
- Block-cut tree solve ordering (root-first from grounded cluster)
- Cluster-by-cluster solver with boundary body fix/unfix cycling
- Pebble game integration for per-cluster rigidity classification

Changes to existing modules:
- params.py: add unfix() for boundary body cycling
- solver.py: extract _monolithic_solve(), add decomposition branch
  for assemblies with >= 8 free bodies

Performance: for k clusters of ~n/k params each, total cost drops
from O(n^3) to O(n^3/k^2).

220 tests passing (up from 207).
This commit is contained in:
forbes-0023
2026-02-20 22:19:35 -06:00
parent 533ca91774
commit 92ae57751f
5 changed files with 1804 additions and 19 deletions

661
kindred_solver/decompose.py Normal file
View File

@@ -0,0 +1,661 @@
"""Graph decomposition for cluster-by-cluster constraint solving.
Builds a constraint graph from the SolveContext, decomposes it into
biconnected components (rigid clusters), orders them via a block-cut
tree, and solves each cluster independently. Articulation-point bodies
are temporarily fixed when solving adjacent clusters so their solved
values propagate as boundary conditions.
Requires: networkx
"""
from __future__ import annotations
import importlib.util
import logging
import sys
import types as stdlib_types
from collections import deque
from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING, List
import networkx as nx
from .bfgs import bfgs_solve
from .newton import newton_solve
from .prepass import substitution_pass
if TYPE_CHECKING:
from .constraints import ConstraintBase
from .entities import RigidBody
from .params import ParamTable
log = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# DOF table: BaseJointKind → number of residuals (= DOF removed)
# ---------------------------------------------------------------------------
# Imported lazily to avoid hard kcsolve dependency in tests.
# Use residual_count() accessor instead of this dict directly.
_RESIDUAL_COUNT: dict[str, int] | None = None
def _ensure_residual_count() -> dict:
"""Build the residual count table on first use."""
global _RESIDUAL_COUNT
if _RESIDUAL_COUNT is not None:
return _RESIDUAL_COUNT
import kcsolve
_RESIDUAL_COUNT = {
kcsolve.BaseJointKind.Fixed: 6,
kcsolve.BaseJointKind.Coincident: 3,
kcsolve.BaseJointKind.Ball: 3,
kcsolve.BaseJointKind.Revolute: 5,
kcsolve.BaseJointKind.Cylindrical: 4,
kcsolve.BaseJointKind.Slider: 5,
kcsolve.BaseJointKind.Screw: 5,
kcsolve.BaseJointKind.Universal: 4,
kcsolve.BaseJointKind.Parallel: 2,
kcsolve.BaseJointKind.Perpendicular: 1,
kcsolve.BaseJointKind.Angle: 1,
kcsolve.BaseJointKind.Concentric: 4,
kcsolve.BaseJointKind.Tangent: 1,
kcsolve.BaseJointKind.Planar: 3,
kcsolve.BaseJointKind.LineInPlane: 2,
kcsolve.BaseJointKind.PointOnLine: 2,
kcsolve.BaseJointKind.PointInPlane: 1,
kcsolve.BaseJointKind.DistancePointPoint: 1,
kcsolve.BaseJointKind.Gear: 1,
kcsolve.BaseJointKind.RackPinion: 1,
kcsolve.BaseJointKind.Cam: 0,
kcsolve.BaseJointKind.Slot: 0,
kcsolve.BaseJointKind.DistanceCylSph: 0,
}
return _RESIDUAL_COUNT
def residual_count(kind) -> int:
"""Number of residuals a constraint type produces."""
return _ensure_residual_count().get(kind, 0)
# ---------------------------------------------------------------------------
# Standalone residual-count table (no kcsolve dependency, string-keyed)
# Used by tests that don't have kcsolve available.
# ---------------------------------------------------------------------------
_RESIDUAL_COUNT_BY_NAME: dict[str, int] = {
"Fixed": 6,
"Coincident": 3,
"Ball": 3,
"Revolute": 5,
"Cylindrical": 4,
"Slider": 5,
"Screw": 5,
"Universal": 4,
"Parallel": 2,
"Perpendicular": 1,
"Angle": 1,
"Concentric": 4,
"Tangent": 1,
"Planar": 3,
"LineInPlane": 2,
"PointOnLine": 2,
"PointInPlane": 1,
"DistancePointPoint": 1,
"Gear": 1,
"RackPinion": 1,
"Cam": 0,
"Slot": 0,
"DistanceCylSph": 0,
}
def residual_count_by_name(kind_name: str) -> int:
"""Number of residuals by constraint type name (no kcsolve needed)."""
return _RESIDUAL_COUNT_BY_NAME.get(kind_name, 0)
# ---------------------------------------------------------------------------
# Data structures
# ---------------------------------------------------------------------------
@dataclass
class SolveCluster:
"""A cluster of bodies to solve together."""
bodies: set[str] # Body IDs in this cluster
constraint_indices: list[int] # Indices into the constraint list
boundary_bodies: set[str] # Articulation points shared with other clusters
has_ground: bool # Whether any body in the cluster is grounded
# ---------------------------------------------------------------------------
# Graph construction
# ---------------------------------------------------------------------------
def build_constraint_graph(
constraints: list,
grounded_bodies: set[str],
) -> nx.MultiGraph:
"""Build a body-level constraint multigraph.
Nodes: part_id strings (one per body referenced by constraints).
Edges: one per active constraint with attributes:
- constraint_index: position in the constraints list
- weight: number of residuals
Grounded bodies are tagged with ``grounded=True``.
Constraints with 0 residuals (stubs) are excluded.
"""
G = nx.MultiGraph()
for idx, c in enumerate(constraints):
if not c.activated:
continue
weight = residual_count(c.type)
if weight == 0:
continue
part_i = c.part_i
part_j = c.part_j
# Ensure nodes exist
if part_i not in G:
G.add_node(part_i, grounded=(part_i in grounded_bodies))
if part_j not in G:
G.add_node(part_j, grounded=(part_j in grounded_bodies))
# Store kind_name for pebble game integration
kind_name = c.type.name if hasattr(c.type, "name") else str(c.type)
G.add_edge(
part_i, part_j, constraint_index=idx, weight=weight, kind_name=kind_name
)
return G
def build_constraint_graph_simple(
edges: list[tuple[str, str, str, int]],
grounded: set[str] | None = None,
) -> nx.MultiGraph:
"""Build a constraint graph from simple edge tuples (for testing).
Each edge is ``(body_i, body_j, kind_name, constraint_index)``.
"""
grounded = grounded or set()
G = nx.MultiGraph()
for body_i, body_j, kind_name, idx in edges:
weight = residual_count_by_name(kind_name)
if weight == 0:
continue
if body_i not in G:
G.add_node(body_i, grounded=(body_i in grounded))
if body_j not in G:
G.add_node(body_j, grounded=(body_j in grounded))
G.add_edge(
body_i, body_j, constraint_index=idx, weight=weight, kind_name=kind_name
)
return G
# ---------------------------------------------------------------------------
# Decomposition
# ---------------------------------------------------------------------------
def find_clusters(
G: nx.MultiGraph,
) -> tuple[list[set[str]], set[str]]:
"""Find biconnected components and articulation points.
Returns:
clusters: list of body-ID sets (one per biconnected component)
articulation_points: body-IDs shared between clusters
"""
# biconnected_components requires a simple Graph
simple = nx.Graph(G)
clusters = [set(c) for c in nx.biconnected_components(simple)]
artic = set(nx.articulation_points(simple))
return clusters, artic
def build_solve_order(
G: nx.MultiGraph,
clusters: list[set[str]],
articulation_points: set[str],
grounded_bodies: set[str],
) -> list[SolveCluster]:
"""Order clusters for solving via the block-cut tree.
Builds the block-cut tree (bipartite graph of clusters and
articulation points), roots it at a grounded cluster, and returns
clusters in root-to-leaf order (grounded first, outward to leaves).
This ensures boundary bodies are solved before clusters that
depend on them.
"""
if not clusters:
return []
# Single cluster — no ordering needed
if len(clusters) == 1:
bodies = clusters[0]
indices = _constraints_for_bodies(G, bodies)
has_ground = bool(bodies & grounded_bodies)
return [
SolveCluster(
bodies=bodies,
constraint_indices=indices,
boundary_bodies=set(),
has_ground=has_ground,
)
]
# Build block-cut tree
# Nodes: ("C", i) for cluster i, ("A", body_id) for articulation points
bct = nx.Graph()
for i, cluster in enumerate(clusters):
bct.add_node(("C", i))
for ap in articulation_points:
if ap in cluster:
bct.add_edge(("C", i), ("A", ap))
# Find root: prefer a cluster containing a grounded body
root = ("C", 0)
for i, cluster in enumerate(clusters):
if cluster & grounded_bodies:
root = ("C", i)
break
# BFS from root: grounded cluster first, outward to leaves
visited = set()
order = []
queue = deque([root])
visited.add(root)
while queue:
node = queue.popleft()
if node[0] == "C":
order.append(node[1])
for neighbor in bct.neighbors(node):
if neighbor not in visited:
visited.add(neighbor)
queue.append(neighbor)
# Build SolveCluster objects
solve_clusters = []
for i in order:
bodies = clusters[i]
indices = _constraints_for_bodies(G, bodies)
boundary = bodies & articulation_points
has_ground = bool(bodies & grounded_bodies)
solve_clusters.append(
SolveCluster(
bodies=bodies,
constraint_indices=indices,
boundary_bodies=boundary,
has_ground=has_ground,
)
)
return solve_clusters
def _constraints_for_bodies(G: nx.MultiGraph, bodies: set[str]) -> list[int]:
"""Collect constraint indices for edges where both endpoints are in bodies."""
indices = []
seen = set()
for u, v, data in G.edges(data=True):
idx = data["constraint_index"]
if idx in seen:
continue
if u in bodies and v in bodies:
seen.add(idx)
indices.append(idx)
return sorted(indices)
# ---------------------------------------------------------------------------
# Top-level decompose entry point
# ---------------------------------------------------------------------------
def decompose(
constraints: list,
grounded_bodies: set[str],
) -> list[SolveCluster]:
"""Full decomposition pipeline: graph → clusters → solve order.
Returns a list of SolveCluster in solve order (leaves first).
If the system is a single cluster, returns a 1-element list.
"""
G = build_constraint_graph(constraints, grounded_bodies)
# Handle disconnected sub-assemblies
all_clusters = []
for component_nodes in nx.connected_components(G):
sub = G.subgraph(component_nodes).copy()
clusters, artic = find_clusters(sub)
if len(clusters) <= 1:
# Single cluster in this component
bodies = component_nodes if not clusters else clusters[0]
indices = _constraints_for_bodies(sub, bodies)
has_ground = bool(bodies & grounded_bodies)
all_clusters.append(
SolveCluster(
bodies=set(bodies),
constraint_indices=indices,
boundary_bodies=set(),
has_ground=has_ground,
)
)
else:
ordered = build_solve_order(sub, clusters, artic, grounded_bodies)
all_clusters.extend(ordered)
return all_clusters
# ---------------------------------------------------------------------------
# Cluster solver
# ---------------------------------------------------------------------------
def solve_decomposed(
clusters: list[SolveCluster],
bodies: dict[str, "RigidBody"],
constraint_objs: list["ConstraintBase"],
constraint_indices_map: list[int],
params: "ParamTable",
) -> bool:
"""Solve clusters in order, fixing boundary bodies between solves.
Args:
clusters: SolveCluster list in solve order (from decompose()).
bodies: part_id → RigidBody mapping.
constraint_objs: constraint objects (parallel to constraint_indices_map).
constraint_indices_map: for each constraint_obj, its index in ctx.constraints.
params: shared ParamTable.
Returns True if all clusters converged.
"""
# Build reverse map: constraint_index → position in constraint_objs list
idx_to_obj: dict[int, "ConstraintBase"] = {}
for pos, ci in enumerate(constraint_indices_map):
idx_to_obj[ci] = constraint_objs[pos]
solved_bodies: set[str] = set()
all_converged = True
for cluster in clusters:
# 1. Fix boundary bodies that were already solved
fixed_boundary_params: list[str] = []
for body_id in cluster.boundary_bodies:
if body_id in solved_bodies:
body = bodies[body_id]
for pname in body._param_names:
if not params.is_fixed(pname):
params.fix(pname)
fixed_boundary_params.append(pname)
# 2. Collect residuals for this cluster
cluster_residuals = []
for ci in cluster.constraint_indices:
obj = idx_to_obj.get(ci)
if obj is not None:
cluster_residuals.extend(obj.residuals())
# 3. Add quat norm residuals for free, non-grounded bodies in this cluster
quat_groups = []
for body_id in cluster.bodies:
body = bodies[body_id]
if body.grounded:
continue
if body_id in cluster.boundary_bodies and body_id in solved_bodies:
continue # Already fixed as boundary
cluster_residuals.append(body.quat_norm_residual())
quat_groups.append(body.quat_param_names())
# 4. Substitution pass (compiles fixed boundary params to constants)
cluster_residuals = substitution_pass(cluster_residuals, params)
# 5. Newton solve (+ BFGS fallback)
if cluster_residuals:
converged = newton_solve(
cluster_residuals,
params,
quat_groups=quat_groups,
max_iter=100,
tol=1e-10,
)
if not converged:
converged = bfgs_solve(
cluster_residuals,
params,
quat_groups=quat_groups,
max_iter=200,
tol=1e-10,
)
if not converged:
all_converged = False
# 6. Mark this cluster's bodies as solved
solved_bodies.update(cluster.bodies)
# 7. Unfix boundary params
for pname in fixed_boundary_params:
params.unfix(pname)
return all_converged
# ---------------------------------------------------------------------------
# Pebble game integration (rigidity classification)
# ---------------------------------------------------------------------------
_PEBBLE_MODULES_LOADED = False
_PebbleGame3D = None
_PebbleJointType = None
_PebbleJoint = None
def _load_pebble_modules():
"""Lazily load PebbleGame3D and related types from GNN/solver/datagen/.
The GNN package has its own import structure (``from solver.datagen.types
import ...``) that conflicts with the top-level module layout, so we
register shim modules in ``sys.modules`` to make it work.
"""
global _PEBBLE_MODULES_LOADED, _PebbleGame3D, _PebbleJointType, _PebbleJoint
if _PEBBLE_MODULES_LOADED:
return
# Find GNN/solver/datagen relative to this package
pkg_dir = Path(__file__).resolve().parent.parent # mods/solver/
datagen_dir = pkg_dir / "GNN" / "solver" / "datagen"
if not datagen_dir.exists():
log.warning("GNN/solver/datagen/ not found; pebble game unavailable")
_PEBBLE_MODULES_LOADED = True
return
# Register shim modules so ``from solver.datagen.types import ...`` works
if "solver" not in sys.modules:
sys.modules["solver"] = stdlib_types.ModuleType("solver")
if "solver.datagen" not in sys.modules:
dg = stdlib_types.ModuleType("solver.datagen")
sys.modules["solver.datagen"] = dg
sys.modules["solver"].datagen = dg # type: ignore[attr-defined]
# Load types.py
types_path = datagen_dir / "types.py"
spec_t = importlib.util.spec_from_file_location(
"solver.datagen.types", str(types_path)
)
types_mod = importlib.util.module_from_spec(spec_t)
sys.modules["solver.datagen.types"] = types_mod
spec_t.loader.exec_module(types_mod)
# Load pebble_game.py
pg_path = datagen_dir / "pebble_game.py"
spec_p = importlib.util.spec_from_file_location(
"solver.datagen.pebble_game", str(pg_path)
)
pg_mod = importlib.util.module_from_spec(spec_p)
sys.modules["solver.datagen.pebble_game"] = pg_mod
spec_p.loader.exec_module(pg_mod)
_PebbleGame3D = pg_mod.PebbleGame3D
_PebbleJointType = types_mod.JointType
_PebbleJoint = types_mod.Joint
_PEBBLE_MODULES_LOADED = True
# BaseJointKind name → PebbleGame JointType name.
# Types not listed here use manual edge insertion with the residual count.
_KIND_NAME_TO_PEBBLE_NAME: dict[str, str] = {
"Fixed": "FIXED",
"Coincident": "BALL", # Same DOF count (3)
"Ball": "BALL",
"Revolute": "REVOLUTE",
"Cylindrical": "CYLINDRICAL",
"Slider": "SLIDER",
"Screw": "SCREW",
"Universal": "UNIVERSAL",
"Planar": "PLANAR",
"Perpendicular": "PERPENDICULAR",
"DistancePointPoint": "DISTANCE",
}
# Parallel: pebble game uses 3 DOF, but our solver uses 2.
# We handle it with manual edge insertion.
# Types that need manual edge insertion (no direct JointType mapping,
# or DOF mismatch like Parallel).
_MANUAL_EDGE_TYPES: set[str] = {
"Parallel", # 2 residuals, but JointType.PARALLEL = 3
"Angle", # 1 residual, no JointType
"Concentric", # 4 residuals, no JointType
"Tangent", # 1 residual, no JointType
"LineInPlane", # 2 residuals, no JointType
"PointOnLine", # 2 residuals, no JointType
"PointInPlane", # 1 residual, no JointType
"Gear", # 1 residual, no JointType
"RackPinion", # 1 residual, no JointType
}
_GROUND_BODY_ID = -1
def classify_cluster_rigidity(
cluster: SolveCluster,
constraint_graph: nx.MultiGraph,
grounded_bodies: set[str],
) -> str | None:
"""Run pebble game on a cluster and return rigidity classification.
Returns one of: "well-constrained", "underconstrained",
"overconstrained", "mixed", or None if pebble game unavailable.
"""
import numpy as np
_load_pebble_modules()
if _PebbleGame3D is None:
return None
pg = _PebbleGame3D()
# Map string body IDs → integer IDs for pebble game
body_list = sorted(cluster.bodies)
body_to_int: dict[str, int] = {b: i for i, b in enumerate(body_list)}
for b in body_list:
pg.add_body(body_to_int[b])
# Add virtual ground body if cluster has grounded bodies
has_ground = bool(cluster.bodies & grounded_bodies)
if has_ground:
pg.add_body(_GROUND_BODY_ID)
for b in cluster.bodies & grounded_bodies:
ground_joint = _PebbleJoint(
joint_id=-1,
body_a=body_to_int[b],
body_b=_GROUND_BODY_ID,
joint_type=_PebbleJointType["FIXED"],
anchor_a=np.zeros(3),
anchor_b=np.zeros(3),
)
pg.add_joint(ground_joint)
# Add constraint edges
joint_counter = 0
zero = np.zeros(3)
for u, v, data in constraint_graph.edges(data=True):
if u not in cluster.bodies or v not in cluster.bodies:
continue
ci = data["constraint_index"]
if ci not in cluster.constraint_indices:
continue
# Determine the constraint kind name from the graph edge
kind_name = data.get("kind_name", "")
n_residuals = data.get("weight", 0)
if not kind_name or n_residuals == 0:
continue
int_u = body_to_int[u]
int_v = body_to_int[v]
pebble_name = _KIND_NAME_TO_PEBBLE_NAME.get(kind_name)
if pebble_name and kind_name not in _MANUAL_EDGE_TYPES:
# Direct JointType mapping
jt = _PebbleJointType[pebble_name]
joint = _PebbleJoint(
joint_id=joint_counter,
body_a=int_u,
body_b=int_v,
joint_type=jt,
anchor_a=zero,
anchor_b=zero,
)
pg.add_joint(joint)
joint_counter += 1
else:
# Manual edge insertion: one DISTANCE edge per residual
for _ in range(n_residuals):
joint = _PebbleJoint(
joint_id=joint_counter,
body_a=int_u,
body_b=int_v,
joint_type=_PebbleJointType["DISTANCE"],
anchor_a=zero,
anchor_b=zero,
)
pg.add_joint(joint)
joint_counter += 1
# Classify using raw pebble counts (adjusting for virtual ground)
total_dof = pg.get_dof()
redundant = pg.get_redundant_count()
# The virtual ground body contributes 6 pebbles that are never consumed.
# Subtract them to get the effective DOF.
if has_ground:
total_dof -= 6 # virtual ground's unconstrained pebbles
baseline = 0
else:
baseline = 6 # trivial rigid-body motion
if redundant > 0 and total_dof > baseline:
return "mixed"
elif redundant > 0:
return "overconstrained"
elif total_dof > baseline:
return "underconstrained"
elif total_dof == baseline:
return "well-constrained"
else:
return "overconstrained"

View File

@@ -49,6 +49,13 @@ class ParamTable:
if name in self._free_order:
self._free_order.remove(name)
def unfix(self, name: str):
"""Restore a fixed parameter to free status."""
if name in self._fixed:
self._fixed.discard(name)
if name not in self._free_order:
self._free_order.append(name)
def get_env(self) -> Dict[str, float]:
"""Return a snapshot of all current values (for Expr.eval)."""
return dict(self._values)

View File

@@ -32,12 +32,16 @@ from .constraints import (
TangentConstraint,
UniversalConstraint,
)
from .decompose import decompose, solve_decomposed
from .dof import count_dof
from .entities import RigidBody
from .newton import newton_solve
from .params import ParamTable
from .prepass import single_equation_pass, substitution_pass
# Assemblies with fewer free bodies than this use the monolithic path.
_DECOMPOSE_THRESHOLD = 8
# All BaseJointKind values this solver can handle
_SUPPORTED = {
# Phase 1
@@ -95,11 +99,12 @@ class KindredSolver(kcsolve.IKCSolver):
)
bodies[part.id] = body
# 2. Build constraint residuals
# 2. Build constraint residuals (track index mapping for decomposition)
all_residuals = []
constraint_objs = []
constraint_indices = [] # parallel to constraint_objs: index in ctx.constraints
for c in ctx.constraints:
for idx, c in enumerate(ctx.constraints):
if not c.activated:
continue
body_i = bodies.get(c.part_i)
@@ -123,6 +128,7 @@ class KindredSolver(kcsolve.IKCSolver):
if obj is None:
continue
constraint_objs.append(obj)
constraint_indices.append(idx)
all_residuals.extend(obj.residuals())
# 3. Add quaternion normalization residuals for non-grounded bodies
@@ -132,26 +138,31 @@ class KindredSolver(kcsolve.IKCSolver):
all_residuals.append(body.quat_norm_residual())
quat_groups.append(body.quat_param_names())
# 4. Pre-passes
# 4. Pre-passes on full system
all_residuals = substitution_pass(all_residuals, params)
all_residuals = single_equation_pass(all_residuals, params)
# 5. Newton-Raphson (with BFGS fallback)
converged = newton_solve(
# 5. Solve (decomposed for large assemblies, monolithic for small)
n_free_bodies = sum(1 for b in bodies.values() if not b.grounded)
if n_free_bodies >= _DECOMPOSE_THRESHOLD:
grounded_ids = {pid for pid, b in bodies.items() if b.grounded}
clusters = decompose(ctx.constraints, grounded_ids)
if len(clusters) > 1:
converged = solve_decomposed(
clusters,
bodies,
constraint_objs,
constraint_indices,
params,
)
else:
converged = _monolithic_solve(
all_residuals,
params,
quat_groups=quat_groups,
max_iter=100,
tol=1e-10,
)
if not converged:
converged = bfgs_solve(
all_residuals,
params,
quat_groups=quat_groups,
max_iter=200,
tol=1e-10,
quat_groups,
)
else:
converged = _monolithic_solve(all_residuals, params, quat_groups)
# 6. DOF
dof = count_dof(all_residuals, params)
@@ -182,6 +193,26 @@ class KindredSolver(kcsolve.IKCSolver):
return True
def _monolithic_solve(all_residuals, params, quat_groups):
"""Newton-Raphson solve with BFGS fallback on the full system."""
converged = newton_solve(
all_residuals,
params,
quat_groups=quat_groups,
max_iter=100,
tol=1e-10,
)
if not converged:
converged = bfgs_solve(
all_residuals,
params,
quat_groups=quat_groups,
max_iter=200,
tol=1e-10,
)
return converged
def _build_constraint(
kind,
body_i,

1052
tests/test_decompose.py Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -65,3 +65,37 @@ class TestParamTable:
pt.add("b", 0.0, fixed=True)
pt.add("c", 0.0)
assert pt.n_free() == 2
def test_unfix(self):
pt = ParamTable()
pt.add("a", 1.0)
pt.add("b", 2.0)
pt.fix("a")
assert pt.is_fixed("a")
assert "a" not in pt.free_names()
pt.unfix("a")
assert not pt.is_fixed("a")
assert "a" in pt.free_names()
assert pt.n_free() == 2
def test_fix_unfix_roundtrip(self):
"""Fix then unfix preserves value and makes param free again."""
pt = ParamTable()
pt.add("x", 5.0)
pt.add("y", 3.0)
pt.fix("x")
pt.set_value("x", 10.0)
pt.unfix("x")
assert pt.get_value("x") == 10.0
assert "x" in pt.free_names()
# x moves to end of free list
assert pt.free_names() == ["y", "x"]
def test_unfix_noop_if_already_free(self):
"""Unfixing a free parameter is a no-op."""
pt = ParamTable()
pt.add("a", 1.0)
pt.unfix("a")
assert pt.free_names() == ["a"]
assert pt.n_free() == 1