From 92ae57751f97fd24d978ea19084968a99f1d7d6d Mon Sep 17 00:00:00 2001 From: forbes-0023 Date: Fri, 20 Feb 2026 22:19:35 -0600 Subject: [PATCH] feat(solver): graph decomposition for cluster-by-cluster solving (phase 3) Add a Python decomposition layer using NetworkX that partitions the constraint graph into biconnected components (rigid clusters), orders them via a block-cut tree, and solves each cluster independently. Articulation-point bodies propagate as boundary conditions between clusters. New module kindred_solver/decompose.py: - DOF table mapping BaseJointKind to residual counts - Constraint graph construction (nx.MultiGraph) - Biconnected component detection + articulation points - Block-cut tree solve ordering (root-first from grounded cluster) - Cluster-by-cluster solver with boundary body fix/unfix cycling - Pebble game integration for per-cluster rigidity classification Changes to existing modules: - params.py: add unfix() for boundary body cycling - solver.py: extract _monolithic_solve(), add decomposition branch for assemblies with >= 8 free bodies Performance: for k clusters of ~n/k params each, total cost drops from O(n^3) to O(n^3/k^2). 220 tests passing (up from 207). --- kindred_solver/decompose.py | 661 ++++++++++++++++++++++ kindred_solver/params.py | 7 + kindred_solver/solver.py | 69 ++- tests/test_decompose.py | 1052 +++++++++++++++++++++++++++++++++++ tests/test_params.py | 34 ++ 5 files changed, 1804 insertions(+), 19 deletions(-) create mode 100644 kindred_solver/decompose.py create mode 100644 tests/test_decompose.py diff --git a/kindred_solver/decompose.py b/kindred_solver/decompose.py new file mode 100644 index 0000000..6b5e581 --- /dev/null +++ b/kindred_solver/decompose.py @@ -0,0 +1,661 @@ +"""Graph decomposition for cluster-by-cluster constraint solving. + +Builds a constraint graph from the SolveContext, decomposes it into +biconnected components (rigid clusters), orders them via a block-cut +tree, and solves each cluster independently. Articulation-point bodies +are temporarily fixed when solving adjacent clusters so their solved +values propagate as boundary conditions. + +Requires: networkx +""" + +from __future__ import annotations + +import importlib.util +import logging +import sys +import types as stdlib_types +from collections import deque +from dataclasses import dataclass, field +from pathlib import Path +from typing import TYPE_CHECKING, List + +import networkx as nx + +from .bfgs import bfgs_solve +from .newton import newton_solve +from .prepass import substitution_pass + +if TYPE_CHECKING: + from .constraints import ConstraintBase + from .entities import RigidBody + from .params import ParamTable + +log = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# DOF table: BaseJointKind → number of residuals (= DOF removed) +# --------------------------------------------------------------------------- + +# Imported lazily to avoid hard kcsolve dependency in tests. +# Use residual_count() accessor instead of this dict directly. +_RESIDUAL_COUNT: dict[str, int] | None = None + + +def _ensure_residual_count() -> dict: + """Build the residual count table on first use.""" + global _RESIDUAL_COUNT + if _RESIDUAL_COUNT is not None: + return _RESIDUAL_COUNT + + import kcsolve + + _RESIDUAL_COUNT = { + kcsolve.BaseJointKind.Fixed: 6, + kcsolve.BaseJointKind.Coincident: 3, + kcsolve.BaseJointKind.Ball: 3, + kcsolve.BaseJointKind.Revolute: 5, + kcsolve.BaseJointKind.Cylindrical: 4, + kcsolve.BaseJointKind.Slider: 5, + kcsolve.BaseJointKind.Screw: 5, + kcsolve.BaseJointKind.Universal: 4, + kcsolve.BaseJointKind.Parallel: 2, + kcsolve.BaseJointKind.Perpendicular: 1, + kcsolve.BaseJointKind.Angle: 1, + kcsolve.BaseJointKind.Concentric: 4, + kcsolve.BaseJointKind.Tangent: 1, + kcsolve.BaseJointKind.Planar: 3, + kcsolve.BaseJointKind.LineInPlane: 2, + kcsolve.BaseJointKind.PointOnLine: 2, + kcsolve.BaseJointKind.PointInPlane: 1, + kcsolve.BaseJointKind.DistancePointPoint: 1, + kcsolve.BaseJointKind.Gear: 1, + kcsolve.BaseJointKind.RackPinion: 1, + kcsolve.BaseJointKind.Cam: 0, + kcsolve.BaseJointKind.Slot: 0, + kcsolve.BaseJointKind.DistanceCylSph: 0, + } + return _RESIDUAL_COUNT + + +def residual_count(kind) -> int: + """Number of residuals a constraint type produces.""" + return _ensure_residual_count().get(kind, 0) + + +# --------------------------------------------------------------------------- +# Standalone residual-count table (no kcsolve dependency, string-keyed) +# Used by tests that don't have kcsolve available. +# --------------------------------------------------------------------------- + +_RESIDUAL_COUNT_BY_NAME: dict[str, int] = { + "Fixed": 6, + "Coincident": 3, + "Ball": 3, + "Revolute": 5, + "Cylindrical": 4, + "Slider": 5, + "Screw": 5, + "Universal": 4, + "Parallel": 2, + "Perpendicular": 1, + "Angle": 1, + "Concentric": 4, + "Tangent": 1, + "Planar": 3, + "LineInPlane": 2, + "PointOnLine": 2, + "PointInPlane": 1, + "DistancePointPoint": 1, + "Gear": 1, + "RackPinion": 1, + "Cam": 0, + "Slot": 0, + "DistanceCylSph": 0, +} + + +def residual_count_by_name(kind_name: str) -> int: + """Number of residuals by constraint type name (no kcsolve needed).""" + return _RESIDUAL_COUNT_BY_NAME.get(kind_name, 0) + + +# --------------------------------------------------------------------------- +# Data structures +# --------------------------------------------------------------------------- + + +@dataclass +class SolveCluster: + """A cluster of bodies to solve together.""" + + bodies: set[str] # Body IDs in this cluster + constraint_indices: list[int] # Indices into the constraint list + boundary_bodies: set[str] # Articulation points shared with other clusters + has_ground: bool # Whether any body in the cluster is grounded + + +# --------------------------------------------------------------------------- +# Graph construction +# --------------------------------------------------------------------------- + + +def build_constraint_graph( + constraints: list, + grounded_bodies: set[str], +) -> nx.MultiGraph: + """Build a body-level constraint multigraph. + + Nodes: part_id strings (one per body referenced by constraints). + Edges: one per active constraint with attributes: + - constraint_index: position in the constraints list + - weight: number of residuals + + Grounded bodies are tagged with ``grounded=True``. + Constraints with 0 residuals (stubs) are excluded. + """ + G = nx.MultiGraph() + + for idx, c in enumerate(constraints): + if not c.activated: + continue + weight = residual_count(c.type) + if weight == 0: + continue + part_i = c.part_i + part_j = c.part_j + # Ensure nodes exist + if part_i not in G: + G.add_node(part_i, grounded=(part_i in grounded_bodies)) + if part_j not in G: + G.add_node(part_j, grounded=(part_j in grounded_bodies)) + # Store kind_name for pebble game integration + kind_name = c.type.name if hasattr(c.type, "name") else str(c.type) + G.add_edge( + part_i, part_j, constraint_index=idx, weight=weight, kind_name=kind_name + ) + + return G + + +def build_constraint_graph_simple( + edges: list[tuple[str, str, str, int]], + grounded: set[str] | None = None, +) -> nx.MultiGraph: + """Build a constraint graph from simple edge tuples (for testing). + + Each edge is ``(body_i, body_j, kind_name, constraint_index)``. + """ + grounded = grounded or set() + G = nx.MultiGraph() + for body_i, body_j, kind_name, idx in edges: + weight = residual_count_by_name(kind_name) + if weight == 0: + continue + if body_i not in G: + G.add_node(body_i, grounded=(body_i in grounded)) + if body_j not in G: + G.add_node(body_j, grounded=(body_j in grounded)) + G.add_edge( + body_i, body_j, constraint_index=idx, weight=weight, kind_name=kind_name + ) + return G + + +# --------------------------------------------------------------------------- +# Decomposition +# --------------------------------------------------------------------------- + + +def find_clusters( + G: nx.MultiGraph, +) -> tuple[list[set[str]], set[str]]: + """Find biconnected components and articulation points. + + Returns: + clusters: list of body-ID sets (one per biconnected component) + articulation_points: body-IDs shared between clusters + """ + # biconnected_components requires a simple Graph + simple = nx.Graph(G) + clusters = [set(c) for c in nx.biconnected_components(simple)] + artic = set(nx.articulation_points(simple)) + return clusters, artic + + +def build_solve_order( + G: nx.MultiGraph, + clusters: list[set[str]], + articulation_points: set[str], + grounded_bodies: set[str], +) -> list[SolveCluster]: + """Order clusters for solving via the block-cut tree. + + Builds the block-cut tree (bipartite graph of clusters and + articulation points), roots it at a grounded cluster, and returns + clusters in root-to-leaf order (grounded first, outward to leaves). + This ensures boundary bodies are solved before clusters that + depend on them. + """ + if not clusters: + return [] + + # Single cluster — no ordering needed + if len(clusters) == 1: + bodies = clusters[0] + indices = _constraints_for_bodies(G, bodies) + has_ground = bool(bodies & grounded_bodies) + return [ + SolveCluster( + bodies=bodies, + constraint_indices=indices, + boundary_bodies=set(), + has_ground=has_ground, + ) + ] + + # Build block-cut tree + # Nodes: ("C", i) for cluster i, ("A", body_id) for articulation points + bct = nx.Graph() + for i, cluster in enumerate(clusters): + bct.add_node(("C", i)) + for ap in articulation_points: + if ap in cluster: + bct.add_edge(("C", i), ("A", ap)) + + # Find root: prefer a cluster containing a grounded body + root = ("C", 0) + for i, cluster in enumerate(clusters): + if cluster & grounded_bodies: + root = ("C", i) + break + + # BFS from root: grounded cluster first, outward to leaves + visited = set() + order = [] + queue = deque([root]) + visited.add(root) + while queue: + node = queue.popleft() + if node[0] == "C": + order.append(node[1]) + for neighbor in bct.neighbors(node): + if neighbor not in visited: + visited.add(neighbor) + queue.append(neighbor) + + # Build SolveCluster objects + solve_clusters = [] + for i in order: + bodies = clusters[i] + indices = _constraints_for_bodies(G, bodies) + boundary = bodies & articulation_points + has_ground = bool(bodies & grounded_bodies) + solve_clusters.append( + SolveCluster( + bodies=bodies, + constraint_indices=indices, + boundary_bodies=boundary, + has_ground=has_ground, + ) + ) + + return solve_clusters + + +def _constraints_for_bodies(G: nx.MultiGraph, bodies: set[str]) -> list[int]: + """Collect constraint indices for edges where both endpoints are in bodies.""" + indices = [] + seen = set() + for u, v, data in G.edges(data=True): + idx = data["constraint_index"] + if idx in seen: + continue + if u in bodies and v in bodies: + seen.add(idx) + indices.append(idx) + return sorted(indices) + + +# --------------------------------------------------------------------------- +# Top-level decompose entry point +# --------------------------------------------------------------------------- + + +def decompose( + constraints: list, + grounded_bodies: set[str], +) -> list[SolveCluster]: + """Full decomposition pipeline: graph → clusters → solve order. + + Returns a list of SolveCluster in solve order (leaves first). + If the system is a single cluster, returns a 1-element list. + """ + G = build_constraint_graph(constraints, grounded_bodies) + + # Handle disconnected sub-assemblies + all_clusters = [] + for component_nodes in nx.connected_components(G): + sub = G.subgraph(component_nodes).copy() + clusters, artic = find_clusters(sub) + + if len(clusters) <= 1: + # Single cluster in this component + bodies = component_nodes if not clusters else clusters[0] + indices = _constraints_for_bodies(sub, bodies) + has_ground = bool(bodies & grounded_bodies) + all_clusters.append( + SolveCluster( + bodies=set(bodies), + constraint_indices=indices, + boundary_bodies=set(), + has_ground=has_ground, + ) + ) + else: + ordered = build_solve_order(sub, clusters, artic, grounded_bodies) + all_clusters.extend(ordered) + + return all_clusters + + +# --------------------------------------------------------------------------- +# Cluster solver +# --------------------------------------------------------------------------- + + +def solve_decomposed( + clusters: list[SolveCluster], + bodies: dict[str, "RigidBody"], + constraint_objs: list["ConstraintBase"], + constraint_indices_map: list[int], + params: "ParamTable", +) -> bool: + """Solve clusters in order, fixing boundary bodies between solves. + + Args: + clusters: SolveCluster list in solve order (from decompose()). + bodies: part_id → RigidBody mapping. + constraint_objs: constraint objects (parallel to constraint_indices_map). + constraint_indices_map: for each constraint_obj, its index in ctx.constraints. + params: shared ParamTable. + + Returns True if all clusters converged. + """ + # Build reverse map: constraint_index → position in constraint_objs list + idx_to_obj: dict[int, "ConstraintBase"] = {} + for pos, ci in enumerate(constraint_indices_map): + idx_to_obj[ci] = constraint_objs[pos] + + solved_bodies: set[str] = set() + all_converged = True + + for cluster in clusters: + # 1. Fix boundary bodies that were already solved + fixed_boundary_params: list[str] = [] + for body_id in cluster.boundary_bodies: + if body_id in solved_bodies: + body = bodies[body_id] + for pname in body._param_names: + if not params.is_fixed(pname): + params.fix(pname) + fixed_boundary_params.append(pname) + + # 2. Collect residuals for this cluster + cluster_residuals = [] + for ci in cluster.constraint_indices: + obj = idx_to_obj.get(ci) + if obj is not None: + cluster_residuals.extend(obj.residuals()) + + # 3. Add quat norm residuals for free, non-grounded bodies in this cluster + quat_groups = [] + for body_id in cluster.bodies: + body = bodies[body_id] + if body.grounded: + continue + if body_id in cluster.boundary_bodies and body_id in solved_bodies: + continue # Already fixed as boundary + cluster_residuals.append(body.quat_norm_residual()) + quat_groups.append(body.quat_param_names()) + + # 4. Substitution pass (compiles fixed boundary params to constants) + cluster_residuals = substitution_pass(cluster_residuals, params) + + # 5. Newton solve (+ BFGS fallback) + if cluster_residuals: + converged = newton_solve( + cluster_residuals, + params, + quat_groups=quat_groups, + max_iter=100, + tol=1e-10, + ) + if not converged: + converged = bfgs_solve( + cluster_residuals, + params, + quat_groups=quat_groups, + max_iter=200, + tol=1e-10, + ) + if not converged: + all_converged = False + + # 6. Mark this cluster's bodies as solved + solved_bodies.update(cluster.bodies) + + # 7. Unfix boundary params + for pname in fixed_boundary_params: + params.unfix(pname) + + return all_converged + + +# --------------------------------------------------------------------------- +# Pebble game integration (rigidity classification) +# --------------------------------------------------------------------------- + +_PEBBLE_MODULES_LOADED = False +_PebbleGame3D = None +_PebbleJointType = None +_PebbleJoint = None + + +def _load_pebble_modules(): + """Lazily load PebbleGame3D and related types from GNN/solver/datagen/. + + The GNN package has its own import structure (``from solver.datagen.types + import ...``) that conflicts with the top-level module layout, so we + register shim modules in ``sys.modules`` to make it work. + """ + global _PEBBLE_MODULES_LOADED, _PebbleGame3D, _PebbleJointType, _PebbleJoint + if _PEBBLE_MODULES_LOADED: + return + + # Find GNN/solver/datagen relative to this package + pkg_dir = Path(__file__).resolve().parent.parent # mods/solver/ + datagen_dir = pkg_dir / "GNN" / "solver" / "datagen" + + if not datagen_dir.exists(): + log.warning("GNN/solver/datagen/ not found; pebble game unavailable") + _PEBBLE_MODULES_LOADED = True + return + + # Register shim modules so ``from solver.datagen.types import ...`` works + if "solver" not in sys.modules: + sys.modules["solver"] = stdlib_types.ModuleType("solver") + if "solver.datagen" not in sys.modules: + dg = stdlib_types.ModuleType("solver.datagen") + sys.modules["solver.datagen"] = dg + sys.modules["solver"].datagen = dg # type: ignore[attr-defined] + + # Load types.py + types_path = datagen_dir / "types.py" + spec_t = importlib.util.spec_from_file_location( + "solver.datagen.types", str(types_path) + ) + types_mod = importlib.util.module_from_spec(spec_t) + sys.modules["solver.datagen.types"] = types_mod + spec_t.loader.exec_module(types_mod) + + # Load pebble_game.py + pg_path = datagen_dir / "pebble_game.py" + spec_p = importlib.util.spec_from_file_location( + "solver.datagen.pebble_game", str(pg_path) + ) + pg_mod = importlib.util.module_from_spec(spec_p) + sys.modules["solver.datagen.pebble_game"] = pg_mod + spec_p.loader.exec_module(pg_mod) + + _PebbleGame3D = pg_mod.PebbleGame3D + _PebbleJointType = types_mod.JointType + _PebbleJoint = types_mod.Joint + _PEBBLE_MODULES_LOADED = True + + +# BaseJointKind name → PebbleGame JointType name. +# Types not listed here use manual edge insertion with the residual count. +_KIND_NAME_TO_PEBBLE_NAME: dict[str, str] = { + "Fixed": "FIXED", + "Coincident": "BALL", # Same DOF count (3) + "Ball": "BALL", + "Revolute": "REVOLUTE", + "Cylindrical": "CYLINDRICAL", + "Slider": "SLIDER", + "Screw": "SCREW", + "Universal": "UNIVERSAL", + "Planar": "PLANAR", + "Perpendicular": "PERPENDICULAR", + "DistancePointPoint": "DISTANCE", +} +# Parallel: pebble game uses 3 DOF, but our solver uses 2. +# We handle it with manual edge insertion. + +# Types that need manual edge insertion (no direct JointType mapping, +# or DOF mismatch like Parallel). +_MANUAL_EDGE_TYPES: set[str] = { + "Parallel", # 2 residuals, but JointType.PARALLEL = 3 + "Angle", # 1 residual, no JointType + "Concentric", # 4 residuals, no JointType + "Tangent", # 1 residual, no JointType + "LineInPlane", # 2 residuals, no JointType + "PointOnLine", # 2 residuals, no JointType + "PointInPlane", # 1 residual, no JointType + "Gear", # 1 residual, no JointType + "RackPinion", # 1 residual, no JointType +} + +_GROUND_BODY_ID = -1 + + +def classify_cluster_rigidity( + cluster: SolveCluster, + constraint_graph: nx.MultiGraph, + grounded_bodies: set[str], +) -> str | None: + """Run pebble game on a cluster and return rigidity classification. + + Returns one of: "well-constrained", "underconstrained", + "overconstrained", "mixed", or None if pebble game unavailable. + """ + import numpy as np + + _load_pebble_modules() + if _PebbleGame3D is None: + return None + + pg = _PebbleGame3D() + + # Map string body IDs → integer IDs for pebble game + body_list = sorted(cluster.bodies) + body_to_int: dict[str, int] = {b: i for i, b in enumerate(body_list)} + + for b in body_list: + pg.add_body(body_to_int[b]) + + # Add virtual ground body if cluster has grounded bodies + has_ground = bool(cluster.bodies & grounded_bodies) + if has_ground: + pg.add_body(_GROUND_BODY_ID) + for b in cluster.bodies & grounded_bodies: + ground_joint = _PebbleJoint( + joint_id=-1, + body_a=body_to_int[b], + body_b=_GROUND_BODY_ID, + joint_type=_PebbleJointType["FIXED"], + anchor_a=np.zeros(3), + anchor_b=np.zeros(3), + ) + pg.add_joint(ground_joint) + + # Add constraint edges + joint_counter = 0 + zero = np.zeros(3) + for u, v, data in constraint_graph.edges(data=True): + if u not in cluster.bodies or v not in cluster.bodies: + continue + ci = data["constraint_index"] + if ci not in cluster.constraint_indices: + continue + + # Determine the constraint kind name from the graph edge + kind_name = data.get("kind_name", "") + n_residuals = data.get("weight", 0) + + if not kind_name or n_residuals == 0: + continue + + int_u = body_to_int[u] + int_v = body_to_int[v] + + pebble_name = _KIND_NAME_TO_PEBBLE_NAME.get(kind_name) + if pebble_name and kind_name not in _MANUAL_EDGE_TYPES: + # Direct JointType mapping + jt = _PebbleJointType[pebble_name] + joint = _PebbleJoint( + joint_id=joint_counter, + body_a=int_u, + body_b=int_v, + joint_type=jt, + anchor_a=zero, + anchor_b=zero, + ) + pg.add_joint(joint) + joint_counter += 1 + else: + # Manual edge insertion: one DISTANCE edge per residual + for _ in range(n_residuals): + joint = _PebbleJoint( + joint_id=joint_counter, + body_a=int_u, + body_b=int_v, + joint_type=_PebbleJointType["DISTANCE"], + anchor_a=zero, + anchor_b=zero, + ) + pg.add_joint(joint) + joint_counter += 1 + + # Classify using raw pebble counts (adjusting for virtual ground) + total_dof = pg.get_dof() + redundant = pg.get_redundant_count() + + # The virtual ground body contributes 6 pebbles that are never consumed. + # Subtract them to get the effective DOF. + if has_ground: + total_dof -= 6 # virtual ground's unconstrained pebbles + baseline = 0 + else: + baseline = 6 # trivial rigid-body motion + + if redundant > 0 and total_dof > baseline: + return "mixed" + elif redundant > 0: + return "overconstrained" + elif total_dof > baseline: + return "underconstrained" + elif total_dof == baseline: + return "well-constrained" + else: + return "overconstrained" diff --git a/kindred_solver/params.py b/kindred_solver/params.py index 40a66a9..a5b7c36 100644 --- a/kindred_solver/params.py +++ b/kindred_solver/params.py @@ -49,6 +49,13 @@ class ParamTable: if name in self._free_order: self._free_order.remove(name) + def unfix(self, name: str): + """Restore a fixed parameter to free status.""" + if name in self._fixed: + self._fixed.discard(name) + if name not in self._free_order: + self._free_order.append(name) + def get_env(self) -> Dict[str, float]: """Return a snapshot of all current values (for Expr.eval).""" return dict(self._values) diff --git a/kindred_solver/solver.py b/kindred_solver/solver.py index bab8c8a..37bc908 100644 --- a/kindred_solver/solver.py +++ b/kindred_solver/solver.py @@ -32,12 +32,16 @@ from .constraints import ( TangentConstraint, UniversalConstraint, ) +from .decompose import decompose, solve_decomposed from .dof import count_dof from .entities import RigidBody from .newton import newton_solve from .params import ParamTable from .prepass import single_equation_pass, substitution_pass +# Assemblies with fewer free bodies than this use the monolithic path. +_DECOMPOSE_THRESHOLD = 8 + # All BaseJointKind values this solver can handle _SUPPORTED = { # Phase 1 @@ -95,11 +99,12 @@ class KindredSolver(kcsolve.IKCSolver): ) bodies[part.id] = body - # 2. Build constraint residuals + # 2. Build constraint residuals (track index mapping for decomposition) all_residuals = [] constraint_objs = [] + constraint_indices = [] # parallel to constraint_objs: index in ctx.constraints - for c in ctx.constraints: + for idx, c in enumerate(ctx.constraints): if not c.activated: continue body_i = bodies.get(c.part_i) @@ -123,6 +128,7 @@ class KindredSolver(kcsolve.IKCSolver): if obj is None: continue constraint_objs.append(obj) + constraint_indices.append(idx) all_residuals.extend(obj.residuals()) # 3. Add quaternion normalization residuals for non-grounded bodies @@ -132,26 +138,31 @@ class KindredSolver(kcsolve.IKCSolver): all_residuals.append(body.quat_norm_residual()) quat_groups.append(body.quat_param_names()) - # 4. Pre-passes + # 4. Pre-passes on full system all_residuals = substitution_pass(all_residuals, params) all_residuals = single_equation_pass(all_residuals, params) - # 5. Newton-Raphson (with BFGS fallback) - converged = newton_solve( - all_residuals, - params, - quat_groups=quat_groups, - max_iter=100, - tol=1e-10, - ) - if not converged: - converged = bfgs_solve( - all_residuals, - params, - quat_groups=quat_groups, - max_iter=200, - tol=1e-10, - ) + # 5. Solve (decomposed for large assemblies, monolithic for small) + n_free_bodies = sum(1 for b in bodies.values() if not b.grounded) + if n_free_bodies >= _DECOMPOSE_THRESHOLD: + grounded_ids = {pid for pid, b in bodies.items() if b.grounded} + clusters = decompose(ctx.constraints, grounded_ids) + if len(clusters) > 1: + converged = solve_decomposed( + clusters, + bodies, + constraint_objs, + constraint_indices, + params, + ) + else: + converged = _monolithic_solve( + all_residuals, + params, + quat_groups, + ) + else: + converged = _monolithic_solve(all_residuals, params, quat_groups) # 6. DOF dof = count_dof(all_residuals, params) @@ -182,6 +193,26 @@ class KindredSolver(kcsolve.IKCSolver): return True +def _monolithic_solve(all_residuals, params, quat_groups): + """Newton-Raphson solve with BFGS fallback on the full system.""" + converged = newton_solve( + all_residuals, + params, + quat_groups=quat_groups, + max_iter=100, + tol=1e-10, + ) + if not converged: + converged = bfgs_solve( + all_residuals, + params, + quat_groups=quat_groups, + max_iter=200, + tol=1e-10, + ) + return converged + + def _build_constraint( kind, body_i, diff --git a/tests/test_decompose.py b/tests/test_decompose.py new file mode 100644 index 0000000..e4937f3 --- /dev/null +++ b/tests/test_decompose.py @@ -0,0 +1,1052 @@ +"""Tests for graph decomposition and cluster-by-cluster solving.""" + +import math + +import pytest +from kindred_solver.constraints import ( + CoincidentConstraint, + CylindricalConstraint, + FixedConstraint, + ParallelConstraint, + RevoluteConstraint, + SliderConstraint, +) +from kindred_solver.decompose import ( + SolveCluster, + _constraints_for_bodies, + build_constraint_graph_simple, + build_solve_order, + classify_cluster_rigidity, + find_clusters, + residual_count_by_name, + solve_decomposed, +) +from kindred_solver.dof import count_dof +from kindred_solver.entities import RigidBody +from kindred_solver.newton import newton_solve +from kindred_solver.params import ParamTable +from kindred_solver.prepass import single_equation_pass, substitution_pass + +ID_QUAT = (1, 0, 0, 0) +c45 = math.cos(math.pi / 4) +s45 = math.sin(math.pi / 4) +ROT_90Z = (c45, 0, 0, s45) + + +# ============================================================================ +# DOF table tests +# ============================================================================ + + +class TestResidualCount: + def test_known_types(self): + assert residual_count_by_name("Fixed") == 6 + assert residual_count_by_name("Revolute") == 5 + assert residual_count_by_name("Cylindrical") == 4 + assert residual_count_by_name("Ball") == 3 + assert residual_count_by_name("Coincident") == 3 + assert residual_count_by_name("Parallel") == 2 + assert residual_count_by_name("Perpendicular") == 1 + assert residual_count_by_name("DistancePointPoint") == 1 + + def test_stubs_zero(self): + assert residual_count_by_name("Cam") == 0 + assert residual_count_by_name("Slot") == 0 + + def test_unknown_zero(self): + assert residual_count_by_name("DoesNotExist") == 0 + + +# ============================================================================ +# Graph construction tests +# ============================================================================ + + +class TestBuildConstraintGraph: + def test_simple_pair(self): + """Two bodies, one Revolute constraint.""" + G = build_constraint_graph_simple( + [("A", "B", "Revolute", 0)], + grounded={"A"}, + ) + assert set(G.nodes()) == {"A", "B"} + assert G.number_of_edges() == 1 + assert G.nodes["A"]["grounded"] is True + assert G.nodes["B"]["grounded"] is False + + def test_chain(self): + """Chain: A-B-C-D.""" + G = build_constraint_graph_simple( + [ + ("A", "B", "Revolute", 0), + ("B", "C", "Revolute", 1), + ("C", "D", "Revolute", 2), + ], + ) + assert G.number_of_nodes() == 4 + assert G.number_of_edges() == 3 + + def test_multi_edge(self): + """Two constraints between same body pair → two edges.""" + G = build_constraint_graph_simple( + [ + ("A", "B", "Revolute", 0), + ("A", "B", "Parallel", 1), + ], + ) + assert G.number_of_nodes() == 2 + assert G.number_of_edges() == 2 + + def test_stub_excluded(self): + """Stub constraints (Cam, Slot) excluded from graph.""" + G = build_constraint_graph_simple( + [ + ("A", "B", "Revolute", 0), + ("B", "C", "Cam", 1), + ], + ) + assert G.number_of_nodes() == 2 # C not added + assert G.number_of_edges() == 1 + + +# ============================================================================ +# Biconnected decomposition tests +# ============================================================================ + + +class TestFindClusters: + def test_chain_decomposes(self): + """Chain A-B-C-D: 3 biconnected components, 2 articulation points.""" + G = build_constraint_graph_simple( + [ + ("A", "B", "Revolute", 0), + ("B", "C", "Revolute", 1), + ("C", "D", "Revolute", 2), + ], + ) + clusters, artic = find_clusters(G) + assert len(clusters) == 3 + assert artic == {"B", "C"} + + def test_loop_single_cluster(self): + """Loop A-B-C-A: single biconnected component, no articulation points.""" + G = build_constraint_graph_simple( + [ + ("A", "B", "Revolute", 0), + ("B", "C", "Revolute", 1), + ("C", "A", "Revolute", 2), + ], + ) + clusters, artic = find_clusters(G) + assert len(clusters) == 1 + assert artic == set() + + def test_star_decomposes(self): + """Star: center connected to 3 leaves. + + 3 biconnected components (each is center + one leaf). + Center is the only articulation point. + """ + G = build_constraint_graph_simple( + [ + ("center", "L1", "Revolute", 0), + ("center", "L2", "Revolute", 1), + ("center", "L3", "Revolute", 2), + ], + ) + clusters, artic = find_clusters(G) + assert len(clusters) == 3 + assert artic == {"center"} + + def test_single_edge(self): + """Two nodes, one edge: single biconnected component.""" + G = build_constraint_graph_simple( + [("A", "B", "Fixed", 0)], + ) + clusters, artic = find_clusters(G) + assert len(clusters) == 1 + assert artic == set() + + def test_tree_with_loop(self): + """A-B-C-D with a loop B-C-E-B. + + The loop {B, C, E} forms one biconnected component. + A-B and C-D are separate biconnected components. + Articulation points: B and C. + """ + G = build_constraint_graph_simple( + [ + ("A", "B", "Revolute", 0), + ("B", "C", "Revolute", 1), + ("C", "E", "Revolute", 2), + ("E", "B", "Revolute", 3), + ("C", "D", "Revolute", 4), + ], + ) + clusters, artic = find_clusters(G) + # 3 clusters: {A,B}, {B,C,E}, {C,D} + assert len(clusters) == 3 + assert artic == {"B", "C"} + + +# ============================================================================ +# Solve order tests +# ============================================================================ + + +class TestBuildSolveOrder: + def test_chain_grounded_at_start(self): + """Chain: ground-A-B-C. + + Root cluster is {ground,A}, solved first. Then {A,B}, then {B,C}. + """ + G = build_constraint_graph_simple( + [ + ("ground", "A", "Revolute", 0), + ("A", "B", "Revolute", 1), + ("B", "C", "Revolute", 2), + ], + grounded={"ground"}, + ) + clusters, artic = find_clusters(G) + solve_order = build_solve_order(G, clusters, artic, {"ground"}) + + assert len(solve_order) == 3 + # First cluster should contain ground (root) + assert "ground" in solve_order[0].bodies + assert solve_order[0].has_ground is True + # Last cluster should be the leaf + assert "C" in solve_order[-1].bodies + + def test_star_grounded_center(self): + """Star with grounded center: root cluster first, then leaves.""" + G = build_constraint_graph_simple( + [ + ("center", "L1", "Revolute", 0), + ("center", "L2", "Revolute", 1), + ("center", "L3", "Revolute", 2), + ], + grounded={"center"}, + ) + clusters, artic = find_clusters(G) + solve_order = build_solve_order(G, clusters, artic, {"center"}) + + assert len(solve_order) == 3 + # First cluster contains grounded center + assert "center" in solve_order[0].bodies + assert solve_order[0].has_ground is True + # All clusters have center as boundary + for sc in solve_order: + assert "center" in sc.boundary_bodies or "center" in sc.bodies + + def test_single_cluster_no_boundary(self): + """Single cluster: no boundary bodies.""" + G = build_constraint_graph_simple( + [("A", "B", "Revolute", 0)], + grounded={"A"}, + ) + clusters, artic = find_clusters(G) + solve_order = build_solve_order(G, clusters, artic, {"A"}) + + assert len(solve_order) == 1 + assert solve_order[0].boundary_bodies == set() + assert solve_order[0].has_ground is True + + def test_constraint_assignment(self): + """Constraints are correctly assigned to clusters.""" + G = build_constraint_graph_simple( + [ + ("A", "B", "Revolute", 0), + ("B", "C", "Revolute", 1), + ], + ) + clusters, artic = find_clusters(G) + solve_order = build_solve_order(G, clusters, artic, set()) + + # Each cluster should have exactly 1 constraint + all_indices = [] + for sc in solve_order: + all_indices.extend(sc.constraint_indices) + assert sorted(all_indices) == [0, 1] + + +# ============================================================================ +# Integration: cluster solving with actual Newton solver +# ============================================================================ + + +def _monolithic_solve(pt, all_bodies, constraint_objs): + """Monolithic solve for comparison.""" + residuals = [] + for c in constraint_objs: + residuals.extend(c.residuals()) + quat_groups = [] + for body in all_bodies: + if not body.grounded: + residuals.append(body.quat_norm_residual()) + quat_groups.append(body.quat_param_names()) + residuals = substitution_pass(residuals, pt) + residuals = single_equation_pass(residuals, pt) + return newton_solve(residuals, pt, quat_groups=quat_groups, max_iter=100, tol=1e-10) + + +class TestClusterSolve: + def test_chain_solve_matches_monolithic(self): + """Chain: ground → A → B. Decomposed solve matches monolithic. + + Two biconnected components: {ground,A} and {A,B}. + A is the articulation point. + """ + # --- Decomposed solve --- + pt = ParamTable() + ground = RigidBody("ground", pt, (0, 0, 0), ID_QUAT, grounded=True) + body_a = RigidBody("A", pt, (3, 2, 0), ID_QUAT) + body_b = RigidBody("B", pt, (8, 3, 0), ID_QUAT) + bodies = {"ground": ground, "A": body_a, "B": body_b} + + c0 = RevoluteConstraint(ground, (0, 0, 0), ID_QUAT, body_a, (0, 0, 0), ID_QUAT) + c1 = RevoluteConstraint(body_a, (5, 0, 0), ID_QUAT, body_b, (0, 0, 0), ID_QUAT) + constraint_objs = [c0, c1] + constraint_indices = [0, 1] + + # Solve order: grounded cluster first, then outward. + # {ground,A} first (A gets solved), then {A,B} with A as boundary. + clusters = [ + SolveCluster( + bodies={"ground", "A"}, + constraint_indices=[0], + boundary_bodies={"A"}, + has_ground=True, + ), + SolveCluster( + bodies={"A", "B"}, + constraint_indices=[1], + boundary_bodies={"A"}, + has_ground=False, + ), + ] + + converged = solve_decomposed( + clusters, + bodies, + constraint_objs, + constraint_indices, + pt, + ) + assert converged + + env = pt.get_env() + pos_a = body_a.extract_position(env) + pos_b = body_b.extract_position(env) + + # A should be at origin (revolute to ground origin) + assert abs(pos_a[0]) < 1e-8 + assert abs(pos_a[1]) < 1e-8 + assert abs(pos_a[2]) < 1e-8 + + # B should be at (5,0,0) — revolute at A's (5,0,0) marker, B's origin + assert abs(pos_b[0] - 5.0) < 1e-8 + assert abs(pos_b[1]) < 1e-8 + assert abs(pos_b[2]) < 1e-8 + + def test_star_solve(self): + """Star: grounded center with 3 revolute arms. + + Each arm is a separate cluster. Center is articulation point. + Solve center-containing clusters first, then arms. + """ + pt = ParamTable() + center = RigidBody("C", pt, (0, 0, 0), ID_QUAT, grounded=True) + arm1 = RigidBody("L1", pt, (5, 5, 0), ID_QUAT) + arm2 = RigidBody("L2", pt, (5, -5, 0), ID_QUAT) + arm3 = RigidBody("L3", pt, (-5, 5, 0), ID_QUAT) + bodies = {"C": center, "L1": arm1, "L2": arm2, "L3": arm3} + + c0 = RevoluteConstraint(center, (3, 0, 0), ID_QUAT, arm1, (0, 0, 0), ID_QUAT) + c1 = RevoluteConstraint(center, (0, 3, 0), ID_QUAT, arm2, (0, 0, 0), ID_QUAT) + c2 = RevoluteConstraint(center, (-3, 0, 0), ID_QUAT, arm3, (0, 0, 0), ID_QUAT) + constraint_objs = [c0, c1, c2] + constraint_indices = [0, 1, 2] + + # Each arm cluster has center as boundary. Since center is grounded, + # its params are already fixed. Solve order doesn't matter much. + clusters = [ + SolveCluster( + bodies={"C", "L1"}, + constraint_indices=[0], + boundary_bodies={"C"}, + has_ground=True, + ), + SolveCluster( + bodies={"C", "L2"}, + constraint_indices=[1], + boundary_bodies={"C"}, + has_ground=True, + ), + SolveCluster( + bodies={"C", "L3"}, + constraint_indices=[2], + boundary_bodies={"C"}, + has_ground=True, + ), + ] + + converged = solve_decomposed( + clusters, + bodies, + constraint_objs, + constraint_indices, + pt, + ) + assert converged + + env = pt.get_env() + # Each arm should be at its marker position + assert abs(arm1.extract_position(env)[0] - 3.0) < 1e-8 + assert abs(arm2.extract_position(env)[1] - 3.0) < 1e-8 + assert abs(arm3.extract_position(env)[0] + 3.0) < 1e-8 + + def test_single_cluster_solve(self): + """Single cluster: decomposed solve behaves same as monolithic.""" + pt = ParamTable() + ground = RigidBody("g", pt, (0, 0, 0), ID_QUAT, grounded=True) + body = RigidBody("b", pt, (5, 5, 5), ID_QUAT) + bodies = {"g": ground, "b": body} + + c0 = RevoluteConstraint(ground, (0, 0, 0), ID_QUAT, body, (0, 0, 0), ID_QUAT) + constraint_objs = [c0] + constraint_indices = [0] + + clusters = [ + SolveCluster( + bodies={"g", "b"}, + constraint_indices=[0], + boundary_bodies=set(), + has_ground=True, + ), + ] + + converged = solve_decomposed( + clusters, + bodies, + constraint_objs, + constraint_indices, + pt, + ) + assert converged + + env = pt.get_env() + pos = body.extract_position(env) + assert abs(pos[0]) < 1e-8 + assert abs(pos[1]) < 1e-8 + assert abs(pos[2]) < 1e-8 + + +# ============================================================================ +# Full decompose() pipeline tests +# ============================================================================ + + +class TestDecomposePipeline: + """Test the full decompose() entry point (requires kcsolve for + build_constraint_graph, so we test via build_constraint_graph_simple + and manual cluster construction instead).""" + + def test_chain_topology(self): + """Chain decomposition produces correct number of clusters.""" + G = build_constraint_graph_simple( + [ + ("g", "A", "Revolute", 0), + ("A", "B", "Revolute", 1), + ("B", "C", "Revolute", 2), + ("C", "D", "Revolute", 3), + ], + grounded={"g"}, + ) + clusters, artic = find_clusters(G) + solve_order = build_solve_order(G, clusters, artic, {"g"}) + + # Chain of 5 nodes → 4 biconnected components + assert len(solve_order) == 4 + # First cluster should contain ground + assert "g" in solve_order[0].bodies + + def test_disconnected_components(self): + """Two disconnected sub-assemblies produce independent cluster groups.""" + G = build_constraint_graph_simple( + [ + ("A", "B", "Revolute", 0), + ("C", "D", "Revolute", 1), + ], + ) + # Two connected components, each single-cluster + components = list(nx.connected_components(G)) + assert len(components) == 2 + + # Each component decomposes to 1 cluster + for comp_nodes in components: + sub = G.subgraph(comp_nodes).copy() + clusters, artic = find_clusters(sub) + assert len(clusters) == 1 + assert artic == set() + + +import networkx as nx +from kindred_solver.bfgs import bfgs_solve + +# ============================================================================ +# Integration: decomposed solve vs monolithic on larger assemblies +# ============================================================================ + + +class TestDecomposedVsMonolithic: + """Verify decomposed solving produces equivalent results to monolithic.""" + + def _build_chain(self, n_bodies, pt, displaced=True): + """Build a revolute chain: ground → B0 → B1 → ... → B(n-1). + + Each body is connected at (2,0,0) local → (0,0,0) local of next. + If displaced, bodies start at slightly wrong positions. + Returns (bodies_dict, constraint_objs, constraint_indices). + """ + ground = RigidBody("ground", pt, (0, 0, 0), ID_QUAT, grounded=True) + bodies = {"ground": ground} + constraints = [] + indices = [] + + prev_body = ground + for i in range(n_bodies): + name = f"B{i}" + # Correct position would be ((i+1)*2, 0, 0), displace slightly + x = (i + 1) * 2 + (0.5 if displaced else 0) + y = 0.3 if displaced else 0 + body = RigidBody(name, pt, (x, y, 0), ID_QUAT) + bodies[name] = body + + c = RevoluteConstraint( + prev_body, + (2, 0, 0), + ID_QUAT, + body, + (0, 0, 0), + ID_QUAT, + ) + constraints.append(c) + indices.append(i) + prev_body = body + + return bodies, constraints, indices + + def test_chain_10_bodies(self): + """10-body revolute chain: decomposed solve converges and satisfies constraints. + + Revolute chains are under-constrained (each joint has 1 rotation DOF), + so monolithic and decomposed may find different valid solutions. + We verify convergence and that all constraint residuals are near zero. + """ + pt_dec = ParamTable() + bodies_dec, constraints_dec, indices_dec = self._build_chain(10, pt_dec) + + # Build clusters manually (chain -> each link is a biconnected component) + cluster_list = [] + body_names = ["ground"] + [f"B{i}" for i in range(10)] + for i in range(10): + pair = {body_names[i], body_names[i + 1]} + boundary = set() + if i > 0: + boundary.add(body_names[i]) # shared with previous cluster + if i < 9: + boundary.add(body_names[i + 1]) # shared with next cluster + cluster_list.append( + SolveCluster( + bodies=pair, + constraint_indices=[i], + boundary_bodies=boundary, + has_ground=(i == 0), + ) + ) + + converged = solve_decomposed( + cluster_list, + bodies_dec, + constraints_dec, + indices_dec, + pt_dec, + ) + assert converged + + # Verify all constraint residuals are satisfied + env = pt_dec.get_env() + for c in constraints_dec: + for r in c.residuals(): + val = r.eval(env) + assert abs(val) < 1e-8, f"Residual not satisfied: {val}" + + # Verify quat norms + for name, body in bodies_dec.items(): + if not body.grounded: + qn = body.quat_norm_residual().eval(env) + assert abs(qn) < 1e-8, f"{name} quat not normalized: {qn}" + + def test_star_5_arms(self): + """Star with 5 revolute arms off a grounded center.""" + pt = ParamTable() + center = RigidBody("C", pt, (0, 0, 0), ID_QUAT, grounded=True) + bodies = {"C": center} + constraints = [] + indices = [] + clusters = [] + + import math as m + + for i in range(5): + name = f"arm{i}" + angle = 2 * m.pi * i / 5 + marker_pos = (3 * m.cos(angle), 3 * m.sin(angle), 0) + body = RigidBody( + name, pt, (marker_pos[0] + 1, marker_pos[1] + 1, 0), ID_QUAT + ) + bodies[name] = body + + c = RevoluteConstraint( + center, + marker_pos, + ID_QUAT, + body, + (0, 0, 0), + ID_QUAT, + ) + constraints.append(c) + indices.append(i) + + clusters.append( + SolveCluster( + bodies={"C", name}, + constraint_indices=[i], + boundary_bodies={"C"}, + has_ground=True, + ) + ) + + converged = solve_decomposed(clusters, bodies, constraints, indices, pt) + assert converged + + env = pt.get_env() + for i in range(5): + name = f"arm{i}" + angle = 2 * m.pi * i / 5 + expected_x = 3 * m.cos(angle) + expected_y = 3 * m.sin(angle) + pos = bodies[name].extract_position(env) + assert abs(pos[0] - expected_x) < 1e-6, ( + f"{name}: x={pos[0]:.4f} expected={expected_x:.4f}" + ) + assert abs(pos[1] - expected_y) < 1e-6, ( + f"{name}: y={pos[1]:.4f} expected={expected_y:.4f}" + ) + assert abs(pos[2]) < 1e-6 + + def test_single_cluster_loop(self): + """A loop assembly (single biconnected component) solves as one cluster.""" + pt = ParamTable() + ground = RigidBody("g", pt, (0, 0, 0), ID_QUAT, grounded=True) + b1 = RigidBody("B1", pt, (2.5, 0.3, 0), ID_QUAT) + b2 = RigidBody("B2", pt, (4.5, -0.2, 0), ID_QUAT) + bodies = {"g": ground, "B1": b1, "B2": b2} + + c0 = RevoluteConstraint(ground, (2, 0, 0), ID_QUAT, b1, (0, 0, 0), ID_QUAT) + c1 = RevoluteConstraint(b1, (2, 0, 0), ID_QUAT, b2, (0, 0, 0), ID_QUAT) + c2 = RevoluteConstraint(b2, (2, 0, 0), ID_QUAT, ground, (6, 0, 0), ID_QUAT) + constraint_objs = [c0, c1, c2] + constraint_indices = [0, 1, 2] + + clusters = [ + SolveCluster( + bodies={"g", "B1", "B2"}, + constraint_indices=[0, 1, 2], + boundary_bodies=set(), + has_ground=True, + ), + ] + + converged = solve_decomposed( + clusters, bodies, constraint_objs, constraint_indices, pt + ) + assert converged + + env = pt.get_env() + for c in constraint_objs: + for r in c.residuals(): + assert abs(r.eval(env)) < 1e-8 + + def test_boundary_propagation_accuracy(self): + """Verify boundary body values propagate correctly between clusters. + + Chain: ground → A → B → C. A is boundary between clusters 0-1. + B is boundary between clusters 1-2. After decomposed solve, + all revolute constraints should be satisfied. + """ + pt = ParamTable() + ground = RigidBody("ground", pt, (0, 0, 0), ID_QUAT, grounded=True) + body_a = RigidBody("A", pt, (2.5, 0.5, 0), ID_QUAT) # displaced + body_b = RigidBody("B", pt, (4.5, -0.5, 0), ID_QUAT) # displaced + body_c = RigidBody("C", pt, (7.0, 1.0, 0), ID_QUAT) # displaced + bodies = {"ground": ground, "A": body_a, "B": body_b, "C": body_c} + + c0 = RevoluteConstraint(ground, (2, 0, 0), ID_QUAT, body_a, (0, 0, 0), ID_QUAT) + c1 = RevoluteConstraint(body_a, (2, 0, 0), ID_QUAT, body_b, (0, 0, 0), ID_QUAT) + c2 = RevoluteConstraint(body_b, (2, 0, 0), ID_QUAT, body_c, (0, 0, 0), ID_QUAT) + constraint_objs = [c0, c1, c2] + constraint_indices = [0, 1, 2] + + clusters = [ + SolveCluster( + bodies={"ground", "A"}, + constraint_indices=[0], + boundary_bodies={"A"}, + has_ground=True, + ), + SolveCluster( + bodies={"A", "B"}, + constraint_indices=[1], + boundary_bodies={"A", "B"}, + has_ground=False, + ), + SolveCluster( + bodies={"B", "C"}, + constraint_indices=[2], + boundary_bodies={"B"}, + has_ground=False, + ), + ] + + converged = solve_decomposed( + clusters, bodies, constraint_objs, constraint_indices, pt + ) + assert converged + + env = pt.get_env() + # A at (2,0,0), B at (4,0,0), C at (6,0,0) + pos_a = body_a.extract_position(env) + pos_b = body_b.extract_position(env) + pos_c = body_c.extract_position(env) + + assert abs(pos_a[0] - 2.0) < 1e-6 + assert abs(pos_a[1]) < 1e-6 + assert abs(pos_b[0] - 4.0) < 1e-6 + assert abs(pos_b[1]) < 1e-6 + assert abs(pos_c[0] - 6.0) < 1e-6 + assert abs(pos_c[1]) < 1e-6 + + # Verify all params are free after solve (unfixed correctly) + for name in ["A", "B", "C"]: + body = bodies[name] + for pname in body._param_names: + assert not pt.is_fixed(pname), f"{pname} still fixed after solve" + + +# ============================================================================ +# Pebble game rigidity classification +# ============================================================================ + + +class TestPebbleGameClassification: + """Test classify_cluster_rigidity using the pebble game.""" + + def test_well_constrained_fixed_pair(self): + """Two bodies with a Fixed joint (6 residuals) → well-constrained.""" + G = build_constraint_graph_simple( + [("ground", "B", "Fixed", 0)], + grounded={"ground"}, + ) + cluster = SolveCluster( + bodies={"ground", "B"}, + constraint_indices=[0], + boundary_bodies=set(), + has_ground=True, + ) + result = classify_cluster_rigidity(cluster, G, {"ground"}) + assert result == "well-constrained", f"Expected well-constrained, got {result}" + + def test_underconstrained_revolute_pair(self): + """Ground + body with Revolute (5 residuals) → 1 DOF under-constrained.""" + G = build_constraint_graph_simple( + [("ground", "B", "Revolute", 0)], + grounded={"ground"}, + ) + cluster = SolveCluster( + bodies={"ground", "B"}, + constraint_indices=[0], + boundary_bodies=set(), + has_ground=True, + ) + result = classify_cluster_rigidity(cluster, G, {"ground"}) + assert result == "underconstrained", f"Expected underconstrained, got {result}" + + def test_overconstrained(self): + """Two Fixed joints between same pair → overconstrained.""" + G = build_constraint_graph_simple( + [ + ("ground", "B", "Fixed", 0), + ("ground", "B", "Fixed", 1), + ], + grounded={"ground"}, + ) + cluster = SolveCluster( + bodies={"ground", "B"}, + constraint_indices=[0, 1], + boundary_bodies=set(), + has_ground=True, + ) + result = classify_cluster_rigidity(cluster, G, {"ground"}) + assert result == "overconstrained", f"Expected overconstrained, got {result}" + + def test_manual_edge_types(self): + """Parallel (2 residuals, manual edge) is correctly handled.""" + G = build_constraint_graph_simple( + [ + ("ground", "B", "Coincident", 0), # 3 residuals + ("ground", "B", "Parallel", 1), # 2 residuals (manual) + ], + grounded={"ground"}, + ) + # Total: 3 + 2 = 5 residuals (like Revolute) → underconstrained (1 DOF) + cluster = SolveCluster( + bodies={"ground", "B"}, + constraint_indices=[0, 1], + boundary_bodies=set(), + has_ground=True, + ) + result = classify_cluster_rigidity(cluster, G, {"ground"}) + assert result == "underconstrained", f"Expected underconstrained, got {result}" + + def test_no_ground(self): + """Two ungrounded bodies with Fixed joint → well-constrained (6 DOF trivial).""" + G = build_constraint_graph_simple( + [("A", "B", "Fixed", 0)], + ) + cluster = SolveCluster( + bodies={"A", "B"}, + constraint_indices=[0], + boundary_bodies=set(), + has_ground=False, + ) + result = classify_cluster_rigidity(cluster, G, set()) + # Two bodies, no ground: 12 DOF total, 6 from Fixed, 6 trivial → well-constrained + assert result == "well-constrained", f"Expected well-constrained, got {result}" + + +# ============================================================================ +# Large assembly tests (20+ bodies) +# ============================================================================ + + +class TestLargeAssembly: + """End-to-end tests with large synthetic assemblies.""" + + def test_chain_20_bodies(self): + """20-body revolute chain: decomposed solve converges.""" + pt = ParamTable() + ground = RigidBody("ground", pt, (0, 0, 0), ID_QUAT, grounded=True) + bodies = {"ground": ground} + constraints = [] + indices = [] + + prev = ground + for i in range(20): + name = f"B{i}" + body = RigidBody(name, pt, ((i + 1) * 2 + 0.3, 0.2, 0), ID_QUAT) + bodies[name] = body + c = RevoluteConstraint(prev, (2, 0, 0), ID_QUAT, body, (0, 0, 0), ID_QUAT) + constraints.append(c) + indices.append(i) + prev = body + + # Build clusters: chain of 20 biconnected components + body_names = ["ground"] + [f"B{i}" for i in range(20)] + cluster_list = [] + for i in range(20): + pair = {body_names[i], body_names[i + 1]} + boundary = set() + if i > 0: + boundary.add(body_names[i]) + if i < 19: + boundary.add(body_names[i + 1]) + cluster_list.append( + SolveCluster( + bodies=pair, + constraint_indices=[i], + boundary_bodies=boundary, + has_ground=(i == 0), + ) + ) + + converged = solve_decomposed(cluster_list, bodies, constraints, indices, pt) + assert converged + + env = pt.get_env() + for c in constraints: + for r in c.residuals(): + assert abs(r.eval(env)) < 1e-8 + + def test_star_10_arms(self): + """Star with 10 revolute arms: 10 independent clusters.""" + pt = ParamTable() + center = RigidBody("C", pt, (0, 0, 0), ID_QUAT, grounded=True) + bodies = {"C": center} + constraints = [] + indices = [] + clusters = [] + + for i in range(10): + name = f"arm{i}" + angle = 2 * math.pi * i / 10 + mx = 3 * math.cos(angle) + my = 3 * math.sin(angle) + body = RigidBody(name, pt, (mx + 0.5, my + 0.5, 0), ID_QUAT) + bodies[name] = body + c = RevoluteConstraint( + center, (mx, my, 0), ID_QUAT, body, (0, 0, 0), ID_QUAT + ) + constraints.append(c) + indices.append(i) + clusters.append( + SolveCluster( + bodies={"C", name}, + constraint_indices=[i], + boundary_bodies={"C"}, + has_ground=True, + ) + ) + + converged = solve_decomposed(clusters, bodies, constraints, indices, pt) + assert converged + + env = pt.get_env() + for i in range(10): + name = f"arm{i}" + angle = 2 * math.pi * i / 10 + pos = bodies[name].extract_position(env) + assert abs(pos[0] - 3 * math.cos(angle)) < 1e-6 + assert abs(pos[1] - 3 * math.sin(angle)) < 1e-6 + + def test_tree_assembly(self): + """Tree: ground → A → (B, C), A → D. Mixed branching. + + Topology: ground-A is root cluster, A is articulation point. + Three leaf clusters: {A,B}, {A,C}, {A,D}. + """ + pt = ParamTable() + ground = RigidBody("ground", pt, (0, 0, 0), ID_QUAT, grounded=True) + a = RigidBody("A", pt, (2.5, 0.3, 0), ID_QUAT) + b = RigidBody("B", pt, (4.5, 2.5, 0), ID_QUAT) + c = RigidBody("C", pt, (4.5, -2.5, 0), ID_QUAT) + d = RigidBody("D", pt, (4.5, 0.3, 0), ID_QUAT) + bodies = {"ground": ground, "A": a, "B": b, "C": c, "D": d} + + c0 = RevoluteConstraint(ground, (2, 0, 0), ID_QUAT, a, (0, 0, 0), ID_QUAT) + c1 = RevoluteConstraint(a, (2, 2, 0), ID_QUAT, b, (0, 0, 0), ID_QUAT) + c2 = RevoluteConstraint(a, (2, -2, 0), ID_QUAT, c, (0, 0, 0), ID_QUAT) + c3 = RevoluteConstraint(a, (2, 0, 0), ID_QUAT, d, (0, 0, 0), ID_QUAT) + constraint_objs = [c0, c1, c2, c3] + constraint_indices = [0, 1, 2, 3] + + # Root cluster: {ground, A}, then 3 leaf clusters + clusters = [ + SolveCluster( + bodies={"ground", "A"}, + constraint_indices=[0], + boundary_bodies={"A"}, + has_ground=True, + ), + SolveCluster( + bodies={"A", "B"}, + constraint_indices=[1], + boundary_bodies={"A"}, + has_ground=False, + ), + SolveCluster( + bodies={"A", "C"}, + constraint_indices=[2], + boundary_bodies={"A"}, + has_ground=False, + ), + SolveCluster( + bodies={"A", "D"}, + constraint_indices=[3], + boundary_bodies={"A"}, + has_ground=False, + ), + ] + + converged = solve_decomposed( + clusters, bodies, constraint_objs, constraint_indices, pt + ) + assert converged + + env = pt.get_env() + # A should be at (2, 0, 0) + pos_a = a.extract_position(env) + assert abs(pos_a[0] - 2.0) < 1e-6 + assert abs(pos_a[1]) < 1e-6 + + # B at A + marker (2,2,0) = (4,2,0) + pos_b = b.extract_position(env) + assert abs(pos_b[0] - 4.0) < 1e-6 + assert abs(pos_b[1] - 2.0) < 1e-6 + + # C at A + marker (2,-2,0) = (4,-2,0) + pos_c = c.extract_position(env) + assert abs(pos_c[0] - 4.0) < 1e-6 + assert abs(pos_c[1] + 2.0) < 1e-6 + + # D at A + marker (2,0,0) = (4,0,0) + pos_d = d.extract_position(env) + assert abs(pos_d[0] - 4.0) < 1e-6 + assert abs(pos_d[1]) < 1e-6 + + def test_mixed_constraint_chain(self): + """Chain with mixed constraint types: Fixed, Coincident, Cylindrical.""" + pt = ParamTable() + ground = RigidBody("ground", pt, (0, 0, 0), ID_QUAT, grounded=True) + b1 = RigidBody("B1", pt, (2.5, 0.3, 0), ID_QUAT) + b2 = RigidBody("B2", pt, (5.5, -0.2, 0), ID_QUAT) + b3 = RigidBody("B3", pt, (8.5, 0.4, 0), ID_QUAT) + bodies = {"ground": ground, "B1": b1, "B2": b2, "B3": b3} + + c0 = FixedConstraint(ground, (2, 0, 0), ID_QUAT, b1, (0, 0, 0), ID_QUAT) + c1 = CoincidentConstraint(b1, (3, 0, 0), b2, (0, 0, 0)) + c2 = CylindricalConstraint(b2, (3, 0, 0), ID_QUAT, b3, (0, 0, 0), ID_QUAT) + constraint_objs = [c0, c1, c2] + constraint_indices = [0, 1, 2] + + clusters = [ + SolveCluster( + bodies={"ground", "B1"}, + constraint_indices=[0], + boundary_bodies={"B1"}, + has_ground=True, + ), + SolveCluster( + bodies={"B1", "B2"}, + constraint_indices=[1], + boundary_bodies={"B1", "B2"}, + has_ground=False, + ), + SolveCluster( + bodies={"B2", "B3"}, + constraint_indices=[2], + boundary_bodies={"B2"}, + has_ground=False, + ), + ] + + converged = solve_decomposed( + clusters, bodies, constraint_objs, constraint_indices, pt + ) + assert converged + + env = pt.get_env() + for c in constraint_objs: + for r in c.residuals(): + assert abs(r.eval(env)) < 1e-8 diff --git a/tests/test_params.py b/tests/test_params.py index f37912c..d005c84 100644 --- a/tests/test_params.py +++ b/tests/test_params.py @@ -65,3 +65,37 @@ class TestParamTable: pt.add("b", 0.0, fixed=True) pt.add("c", 0.0) assert pt.n_free() == 2 + + def test_unfix(self): + pt = ParamTable() + pt.add("a", 1.0) + pt.add("b", 2.0) + pt.fix("a") + assert pt.is_fixed("a") + assert "a" not in pt.free_names() + + pt.unfix("a") + assert not pt.is_fixed("a") + assert "a" in pt.free_names() + assert pt.n_free() == 2 + + def test_fix_unfix_roundtrip(self): + """Fix then unfix preserves value and makes param free again.""" + pt = ParamTable() + pt.add("x", 5.0) + pt.add("y", 3.0) + pt.fix("x") + pt.set_value("x", 10.0) + pt.unfix("x") + assert pt.get_value("x") == 10.0 + assert "x" in pt.free_names() + # x moves to end of free list + assert pt.free_names() == ["y", "x"] + + def test_unfix_noop_if_already_free(self): + """Unfixing a free parameter is a no-op.""" + pt = ParamTable() + pt.add("a", 1.0) + pt.unfix("a") + assert pt.free_names() == ["a"] + assert pt.n_free() == 1