"""Graph decomposition for cluster-by-cluster constraint solving. Builds a constraint graph from the SolveContext, decomposes it into biconnected components (rigid clusters), orders them via a block-cut tree, and solves each cluster independently. Articulation-point bodies are temporarily fixed when solving adjacent clusters so their solved values propagate as boundary conditions. Requires: networkx """ from __future__ import annotations import importlib.util import logging import sys import types as stdlib_types from collections import deque from dataclasses import dataclass, field from pathlib import Path from typing import TYPE_CHECKING, List import networkx as nx from .bfgs import bfgs_solve from .newton import newton_solve from .prepass import substitution_pass if TYPE_CHECKING: from .constraints import ConstraintBase from .entities import RigidBody from .params import ParamTable log = logging.getLogger(__name__) # --------------------------------------------------------------------------- # DOF table: BaseJointKind → number of residuals (= DOF removed) # --------------------------------------------------------------------------- # Imported lazily to avoid hard kcsolve dependency in tests. # Use residual_count() accessor instead of this dict directly. 
_RESIDUAL_COUNT: dict[object, int] | None = None  # keyed by kcsolve.BaseJointKind


def _ensure_residual_count() -> dict[object, int]:
    """Build the kcsolve-keyed residual count table on first use.

    ``kcsolve`` is imported inside this function, not at module level, so
    that importing this module does not require ``kcsolve``. The table is
    cached in the module-level ``_RESIDUAL_COUNT`` after the first call.

    Returns:
        Mapping of ``kcsolve.BaseJointKind`` member -> number of residuals
        (= DOF removed) that constraint type produces.
    """
    global _RESIDUAL_COUNT
    if _RESIDUAL_COUNT is not None:
        return _RESIDUAL_COUNT

    import kcsolve

    _RESIDUAL_COUNT = {
        kcsolve.BaseJointKind.Fixed: 6,
        kcsolve.BaseJointKind.Coincident: 3,
        kcsolve.BaseJointKind.Ball: 3,
        kcsolve.BaseJointKind.Revolute: 5,
        kcsolve.BaseJointKind.Cylindrical: 4,
        kcsolve.BaseJointKind.Slider: 5,
        kcsolve.BaseJointKind.Screw: 5,
        kcsolve.BaseJointKind.Universal: 4,
        kcsolve.BaseJointKind.Parallel: 2,
        kcsolve.BaseJointKind.Perpendicular: 1,
        kcsolve.BaseJointKind.Angle: 1,
        kcsolve.BaseJointKind.Concentric: 4,
        kcsolve.BaseJointKind.Tangent: 1,
        kcsolve.BaseJointKind.Planar: 3,
        kcsolve.BaseJointKind.LineInPlane: 2,
        kcsolve.BaseJointKind.PointOnLine: 2,
        kcsolve.BaseJointKind.PointInPlane: 1,
        kcsolve.BaseJointKind.DistancePointPoint: 1,
        kcsolve.BaseJointKind.Gear: 1,
        kcsolve.BaseJointKind.RackPinion: 1,
        # Zero-residual stubs: excluded from the constraint graph.
        kcsolve.BaseJointKind.Cam: 0,
        kcsolve.BaseJointKind.Slot: 0,
        kcsolve.BaseJointKind.DistanceCylSph: 0,
    }
    return _RESIDUAL_COUNT


def residual_count(kind) -> int:
    """Number of residuals a constraint type produces.

    Args:
        kind: a ``kcsolve.BaseJointKind`` member.

    Returns:
        The residual count for *kind*, or 0 if *kind* is not in the table.
    """
    return _ensure_residual_count().get(kind, 0)


# ---------------------------------------------------------------------------
# Standalone residual-count table (no kcsolve dependency, string-keyed)
# Used by tests that don't have kcsolve available.
# ---------------------------------------------------------------------------

# Keep in sync with the kcsolve-keyed table built by _ensure_residual_count().
_RESIDUAL_COUNT_BY_NAME: dict[str, int] = {
    "Fixed": 6,
    "Coincident": 3,
    "Ball": 3,
    "Revolute": 5,
    "Cylindrical": 4,
    "Slider": 5,
    "Screw": 5,
    "Universal": 4,
    "Parallel": 2,
    "Perpendicular": 1,
    "Angle": 1,
    "Concentric": 4,
    "Tangent": 1,
    "Planar": 3,
    "LineInPlane": 2,
    "PointOnLine": 2,
    "PointInPlane": 1,
    "DistancePointPoint": 1,
    "Gear": 1,
    "RackPinion": 1,
    "Cam": 0,
    "Slot": 0,
    "DistanceCylSph": 0,
}


def residual_count_by_name(kind_name: str) -> int:
    """Number of residuals by constraint type name (no kcsolve needed).

    Unknown names map to 0.
    """
    return _RESIDUAL_COUNT_BY_NAME.get(kind_name, 0)


# ---------------------------------------------------------------------------
# Data structures
# ---------------------------------------------------------------------------


@dataclass
class SolveCluster:
    """A cluster of bodies to solve together."""

    bodies: set[str]  # Body IDs in this cluster
    constraint_indices: list[int]  # Indices into the constraint list
    boundary_bodies: set[str]  # Articulation points shared with other clusters
    has_ground: bool  # Whether any body in the cluster is grounded


# ---------------------------------------------------------------------------
# Graph construction
# ---------------------------------------------------------------------------


def build_constraint_graph(
    constraints: list,
    grounded_bodies: set[str],
) -> nx.MultiGraph:
    """Build a body-level constraint multigraph.

    Nodes: part_id strings, one per body referenced by an active constraint
    that produces at least one residual.
    Edges: one per such constraint, with attributes:
        - constraint_index: position in the constraints list
        - weight: number of residuals
        - kind_name: constraint type name (used by the pebble game)

    Grounded bodies are tagged with ``grounded=True``.
    Inactive constraints and constraints with 0 residuals (stubs) are
    excluded.
    """
    G = nx.MultiGraph()
    for idx, c in enumerate(constraints):
        if not c.activated:
            continue
        weight = residual_count(c.type)
        if weight == 0:
            continue
        part_i = c.part_i
        part_j = c.part_j
        # Ensure nodes exist
        if part_i not in G:
            G.add_node(part_i, grounded=(part_i in grounded_bodies))
        if part_j not in G:
            G.add_node(part_j, grounded=(part_j in grounded_bodies))
        # Store kind_name for pebble game integration
        kind_name = c.type.name if hasattr(c.type, "name") else str(c.type)
        G.add_edge(
            part_i, part_j, constraint_index=idx, weight=weight, kind_name=kind_name
        )
    return G


def build_constraint_graph_simple(
    edges: list[tuple[str, str, str, int]],
    grounded: set[str] | None = None,
) -> nx.MultiGraph:
    """Build a constraint graph from simple edge tuples (for testing).

    Each edge is ``(body_i, body_j, kind_name, constraint_index)``.
    Edges whose kind has 0 residuals are skipped, as in
    build_constraint_graph().
    """
    grounded = grounded or set()
    G = nx.MultiGraph()
    for body_i, body_j, kind_name, idx in edges:
        weight = residual_count_by_name(kind_name)
        if weight == 0:
            continue
        if body_i not in G:
            G.add_node(body_i, grounded=(body_i in grounded))
        if body_j not in G:
            G.add_node(body_j, grounded=(body_j in grounded))
        G.add_edge(
            body_i, body_j, constraint_index=idx, weight=weight, kind_name=kind_name
        )
    return G


# ---------------------------------------------------------------------------
# Decomposition
# ---------------------------------------------------------------------------


def find_clusters(
    G: nx.MultiGraph,
) -> tuple[list[set[str]], set[str]]:
    """Find biconnected components and articulation points.

    Returns:
        clusters: list of body-ID sets (one per biconnected component)
        articulation_points: body-IDs shared between clusters
    """
    # biconnected_components requires a simple Graph
    simple = nx.Graph(G)
    clusters = [set(c) for c in nx.biconnected_components(simple)]
    artic = set(nx.articulation_points(simple))
    return clusters, artic


def build_solve_order(
    G: nx.MultiGraph,
    clusters: list[set[str]],
    articulation_points: set[str],
    grounded_bodies: set[str],
) -> list[SolveCluster]:
    """Order clusters for solving via the block-cut tree.

    Builds the block-cut tree (bipartite graph of clusters and articulation
    points), roots it at a grounded cluster, and returns clusters in
    root-to-leaf order (grounded first, outward to leaves). This ensures
    boundary bodies are solved before clusters that depend on them.

    If the clusters do not form a single connected block-cut tree (e.g. the
    caller passed clusters of a disconnected graph), each remaining tree is
    traversed from its own root — preferring a grounded cluster — so no
    cluster is dropped from the result.
    """
    if not clusters:
        return []

    # Single cluster — no ordering needed
    if len(clusters) == 1:
        bodies = clusters[0]
        indices = _constraints_for_bodies(G, bodies)
        has_ground = bool(bodies & grounded_bodies)
        return [
            SolveCluster(
                bodies=bodies,
                constraint_indices=indices,
                boundary_bodies=set(),
                has_ground=has_ground,
            )
        ]

    # Build block-cut tree
    # Nodes: ("C", i) for cluster i, ("A", body_id) for articulation points
    bct = nx.Graph()
    for i, cluster in enumerate(clusters):
        bct.add_node(("C", i))
        for ap in articulation_points:
            if ap in cluster:
                bct.add_edge(("C", i), ("A", ap))

    # Root candidates: grounded clusters first, then the rest, each group in
    # index order (sorted() is stable). The first candidate is the root the
    # connected case uses; later candidates only start a new traversal for
    # clusters that the earlier traversals could not reach.
    root_candidates = sorted(
        range(len(clusters)), key=lambda i: not (clusters[i] & grounded_bodies)
    )

    # BFS from each root: grounded cluster first, outward to leaves
    visited: set = set()
    order: list[int] = []
    for start in root_candidates:
        root = ("C", start)
        if root in visited:
            continue
        visited.add(root)
        queue = deque([root])
        while queue:
            node = queue.popleft()
            if node[0] == "C":
                order.append(node[1])
            for neighbor in bct.neighbors(node):
                if neighbor not in visited:
                    visited.add(neighbor)
                    queue.append(neighbor)

    # Build SolveCluster objects
    solve_clusters = []
    for i in order:
        bodies = clusters[i]
        indices = _constraints_for_bodies(G, bodies)
        boundary = bodies & articulation_points
        has_ground = bool(bodies & grounded_bodies)
        solve_clusters.append(
            SolveCluster(
                bodies=bodies,
                constraint_indices=indices,
                boundary_bodies=boundary,
                has_ground=has_ground,
            )
        )

    return solve_clusters


def _constraints_for_bodies(G: nx.MultiGraph, bodies: set[str]) -> list[int]:
    """Collect constraint indices for edges where both endpoints are in bodies.

    Each index appears once; the result is sorted ascending.
    """
    indices = []
    seen = set()
    for u, v, data in G.edges(data=True):
        idx = data["constraint_index"]
        if idx in seen:
            continue
        if u in bodies and v in bodies:
            seen.add(idx)
            indices.append(idx)
    return sorted(indices)


# ---------------------------------------------------------------------------
# Top-level decompose entry point
# ---------------------------------------------------------------------------


def decompose(
    constraints: list,
    grounded_bodies: set[str],
) -> list[SolveCluster]:
    """Full decomposition pipeline: graph → clusters → solve order.

    Each connected sub-assembly is decomposed independently and its clusters
    are appended in solve order: root-to-leaf through the block-cut tree,
    starting from a grounded cluster when one exists (see
    build_solve_order). A sub-assembly that is a single cluster contributes
    one entry with no boundary bodies.

    If the system is a single cluster, returns a 1-element list.
    """
    G = build_constraint_graph(constraints, grounded_bodies)

    # Handle disconnected sub-assemblies
    all_clusters = []
    for component_nodes in nx.connected_components(G):
        sub = G.subgraph(component_nodes).copy()
        clusters, artic = find_clusters(sub)

        if len(clusters) <= 1:
            # Single cluster in this component (or none, e.g. a lone node
            # with only self-loop constraints: fall back to all its nodes).
            bodies = component_nodes if not clusters else clusters[0]
            indices = _constraints_for_bodies(sub, bodies)
            has_ground = bool(bodies & grounded_bodies)
            all_clusters.append(
                SolveCluster(
                    bodies=set(bodies),
                    constraint_indices=indices,
                    boundary_bodies=set(),
                    has_ground=has_ground,
                )
            )
        else:
            ordered = build_solve_order(sub, clusters, artic, grounded_bodies)
            all_clusters.extend(ordered)

    return all_clusters


# ---------------------------------------------------------------------------
# Cluster solver
# ---------------------------------------------------------------------------


def solve_decomposed(
    clusters: list[SolveCluster],
    bodies: dict[str, "RigidBody"],
    constraint_objs: list["ConstraintBase"],
    constraint_indices_map: list[int],
    params: "ParamTable",
) -> bool:
    """Solve clusters in order, fixing boundary bodies between solves.

    For each cluster, parameters of boundary bodies solved by an earlier
    cluster are temporarily fixed (so their values act as boundary
    conditions), the cluster is solved with Newton-Raphson (BFGS fallback),
    and the temporarily fixed parameters are unfixed again — also when a
    solver raises, so the shared ParamTable is never left with stray fixes.

    Args:
        clusters: SolveCluster list in solve order (from decompose()).
        bodies: part_id → RigidBody mapping.
        constraint_objs: constraint objects (parallel to constraint_indices_map).
        constraint_indices_map: for each constraint_obj, its index in ctx.constraints.
        params: shared ParamTable.

    Returns True if all clusters converged.
    """
    log.info(
        "solve_decomposed: %d clusters, %d bodies, %d constraints",
        len(clusters),
        len(bodies),
        len(constraint_objs),
    )

    # Reverse map: index in ctx.constraints → constraint object
    idx_to_obj: dict[int, "ConstraintBase"] = {}
    for pos, ci in enumerate(constraint_indices_map):
        idx_to_obj[ci] = constraint_objs[pos]

    solved_bodies: set[str] = set()
    all_converged = True

    for cluster_idx, cluster in enumerate(clusters):
        fixed_boundary_params: list[str] = []
        try:
            # 1. Fix boundary bodies that were already solved
            for body_id in cluster.boundary_bodies:
                if body_id in solved_bodies:
                    body = bodies[body_id]
                    for pname in body._param_names:
                        if not params.is_fixed(pname):
                            params.fix(pname)
                            fixed_boundary_params.append(pname)

            # 2. Collect residuals for this cluster
            cluster_residuals = []
            for ci in cluster.constraint_indices:
                obj = idx_to_obj.get(ci)
                if obj is not None:
                    cluster_residuals.extend(obj.residuals())

            # 3. Add quat norm residuals for free, non-grounded bodies in this cluster
            quat_groups = []
            for body_id in cluster.bodies:
                body = bodies[body_id]
                if body.grounded:
                    continue
                if body_id in cluster.boundary_bodies and body_id in solved_bodies:
                    continue  # Already fixed as boundary
                cluster_residuals.append(body.quat_norm_residual())
                quat_groups.append(body.quat_param_names())

            # 4. Substitution pass (compiles fixed boundary params to constants)
            cluster_residuals = substitution_pass(cluster_residuals, params)

            # 5. Newton solve (+ BFGS fallback)
            if cluster_residuals:
                log.debug(
                    "  cluster[%d]: %d bodies (%d boundary), %d constraints, %d residuals",
                    cluster_idx,
                    len(cluster.bodies),
                    len(cluster.boundary_bodies),
                    len(cluster.constraint_indices),
                    len(cluster_residuals),
                )

                converged = newton_solve(
                    cluster_residuals,
                    params,
                    quat_groups=quat_groups,
                    max_iter=100,
                    tol=1e-10,
                )
                if not converged:
                    log.info(
                        "  cluster[%d]: Newton-Raphson failed, trying BFGS", cluster_idx
                    )
                    converged = bfgs_solve(
                        cluster_residuals,
                        params,
                        quat_groups=quat_groups,
                        max_iter=200,
                        tol=1e-10,
                    )

                if not converged:
                    log.warning("  cluster[%d]: failed to converge", cluster_idx)
                    all_converged = False

            # 6. Mark this cluster's bodies as solved
            solved_bodies.update(cluster.bodies)
        finally:
            # 7. Unfix boundary params (runs even if a step above raised)
            for pname in fixed_boundary_params:
                params.unfix(pname)

    return all_converged


# ---------------------------------------------------------------------------
# Pebble game integration (rigidity classification)
# ---------------------------------------------------------------------------

_PEBBLE_MODULES_LOADED = False
_PebbleGame3D = None
_PebbleJointType = None
_PebbleJoint = None


def _exec_datagen_module(name: str, path: Path):
    """Load the Python file at *path* as module *name*.

    The module is registered in ``sys.modules`` before execution (so the
    package's own absolute imports resolve) and unregistered again if
    execution fails, so no half-initialised module is left behind.

    Raises:
        ImportError: if no module spec can be created for *path*.
        Any exception raised while executing the module.
    """
    spec = importlib.util.spec_from_file_location(name, str(path))
    if spec is None or spec.loader is None:
        raise ImportError(f"cannot create module spec for {path}")
    module = importlib.util.module_from_spec(spec)
    sys.modules[name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        sys.modules.pop(name, None)
        raise
    return module


def _load_pebble_modules():
    """Lazily load PebbleGame3D and related types from GNN/solver/datagen/.

    The GNN package has its own import structure
    (``from solver.datagen.types import ...``) that conflicts with the
    top-level module layout, so we register shim modules in ``sys.modules``
    to make it work.

    Runs at most once. If the directory is missing or the modules fail to
    load, a warning is logged and the pebble globals stay ``None`` (pebble
    game unavailable) instead of raising.
    """
    global _PEBBLE_MODULES_LOADED, _PebbleGame3D, _PebbleJointType, _PebbleJoint
    if _PEBBLE_MODULES_LOADED:
        return

    # Find GNN/solver/datagen relative to this package
    pkg_dir = Path(__file__).resolve().parent.parent  # mods/solver/
    datagen_dir = pkg_dir / "GNN" / "solver" / "datagen"
    if not datagen_dir.exists():
        log.warning("GNN/solver/datagen/ not found; pebble game unavailable")
        _PEBBLE_MODULES_LOADED = True
        return

    # Register shim modules so ``from solver.datagen.types import ...`` works
    if "solver" not in sys.modules:
        sys.modules["solver"] = stdlib_types.ModuleType("solver")
    if "solver.datagen" not in sys.modules:
        dg = stdlib_types.ModuleType("solver.datagen")
        sys.modules["solver.datagen"] = dg
        sys.modules["solver"].datagen = dg  # type: ignore[attr-defined]

    try:
        # types.py must be loaded first: pebble_game.py imports from it.
        types_mod = _exec_datagen_module(
            "solver.datagen.types", datagen_dir / "types.py"
        )
        pg_mod = _exec_datagen_module(
            "solver.datagen.pebble_game", datagen_dir / "pebble_game.py"
        )
        game_cls = pg_mod.PebbleGame3D
        joint_type_enum = types_mod.JointType
        joint_cls = types_mod.Joint
    except Exception:
        # Best-effort optional feature: degrade to "unavailable" but keep the
        # traceback in the log so a broken datagen package is diagnosable.
        log.warning(
            "failed to load pebble game modules from %s; pebble game unavailable",
            datagen_dir,
            exc_info=True,
        )
    else:
        _PebbleGame3D = game_cls
        _PebbleJointType = joint_type_enum
        _PebbleJoint = joint_cls

    _PEBBLE_MODULES_LOADED = True


# BaseJointKind name → PebbleGame JointType name.
# Types not listed here use manual edge insertion with the residual count.
_KIND_NAME_TO_PEBBLE_NAME: dict[str, str] = {
    "Fixed": "FIXED",
    "Coincident": "BALL",  # Same DOF count (3)
    "Ball": "BALL",
    "Revolute": "REVOLUTE",
    "Cylindrical": "CYLINDRICAL",
    "Slider": "SLIDER",
    "Screw": "SCREW",
    "Universal": "UNIVERSAL",
    "Planar": "PLANAR",
    "Perpendicular": "PERPENDICULAR",
    "DistancePointPoint": "DISTANCE",
}

# Parallel: pebble game uses 3 DOF, but our solver uses 2.
# We handle it with manual edge insertion.

# Types that need manual edge insertion (no direct JointType mapping,
# or DOF mismatch like Parallel).
_MANUAL_EDGE_TYPES: set[str] = {
    "Parallel",  # 2 residuals, but JointType.PARALLEL = 3
    "Angle",  # 1 residual, no JointType
    "Concentric",  # 4 residuals, no JointType
    "Tangent",  # 1 residual, no JointType
    "LineInPlane",  # 2 residuals, no JointType
    "PointOnLine",  # 2 residuals, no JointType
    "PointInPlane",  # 1 residual, no JointType
    "Gear",  # 1 residual, no JointType
    "RackPinion",  # 1 residual, no JointType
}

_GROUND_BODY_ID = -1


def classify_cluster_rigidity(
    cluster: SolveCluster,
    constraint_graph: nx.MultiGraph,
    grounded_bodies: set[str],
) -> str | None:
    """Run pebble game on a cluster and return rigidity classification.

    Returns one of: "well-constrained", "underconstrained",
    "overconstrained", "mixed", or None if pebble game unavailable
    (GNN/solver/datagen/ missing or failed to load).
    """
    import numpy as np

    _load_pebble_modules()
    if _PebbleGame3D is None:
        return None

    pg = _PebbleGame3D()

    # Map string body IDs → integer IDs for pebble game
    body_list = sorted(cluster.bodies)
    body_to_int: dict[str, int] = {b: i for i, b in enumerate(body_list)}

    for b in body_list:
        pg.add_body(body_to_int[b])

    # Add virtual ground body if cluster has grounded bodies
    has_ground = bool(cluster.bodies & grounded_bodies)
    if has_ground:
        pg.add_body(_GROUND_BODY_ID)
        for b in cluster.bodies & grounded_bodies:
            # NOTE(review): every ground joint reuses joint_id=-1; confirm
            # PebbleGame3D does not rely on joint ids being unique.
            ground_joint = _PebbleJoint(
                joint_id=-1,
                body_a=body_to_int[b],
                body_b=_GROUND_BODY_ID,
                joint_type=_PebbleJointType["FIXED"],
                anchor_a=np.zeros(3),
                anchor_b=np.zeros(3),
            )
            pg.add_joint(ground_joint)

    # Set built once for O(1) membership tests in the edge loop below.
    cluster_constraints = set(cluster.constraint_indices)

    # Add constraint edges
    joint_counter = 0
    zero = np.zeros(3)
    for u, v, data in constraint_graph.edges(data=True):
        if u not in cluster.bodies or v not in cluster.bodies:
            continue
        ci = data["constraint_index"]
        if ci not in cluster_constraints:
            continue

        # Determine the constraint kind name from the graph edge
        kind_name = data.get("kind_name", "")
        n_residuals = data.get("weight", 0)

        if not kind_name or n_residuals == 0:
            continue

        int_u = body_to_int[u]
        int_v = body_to_int[v]

        pebble_name = _KIND_NAME_TO_PEBBLE_NAME.get(kind_name)
        if pebble_name and kind_name not in _MANUAL_EDGE_TYPES:
            # Direct JointType mapping
            jt = _PebbleJointType[pebble_name]
            joint = _PebbleJoint(
                joint_id=joint_counter,
                body_a=int_u,
                body_b=int_v,
                joint_type=jt,
                anchor_a=zero,
                anchor_b=zero,
            )
            pg.add_joint(joint)
            joint_counter += 1
        else:
            # Manual edge insertion: one DISTANCE edge per residual
            for _ in range(n_residuals):
                joint = _PebbleJoint(
                    joint_id=joint_counter,
                    body_a=int_u,
                    body_b=int_v,
                    joint_type=_PebbleJointType["DISTANCE"],
                    anchor_a=zero,
                    anchor_b=zero,
                )
                pg.add_joint(joint)
                joint_counter += 1

    # Classify using raw pebble counts (adjusting for virtual ground)
    total_dof = pg.get_dof()
    redundant = pg.get_redundant_count()

    # The virtual ground body contributes 6 pebbles that are never consumed.
    # Subtract them to get the effective DOF.
    if has_ground:
        total_dof -= 6  # virtual ground's unconstrained pebbles
        baseline = 0
    else:
        baseline = 6  # trivial rigid-body motion

    if redundant > 0 and total_dof > baseline:
        return "mixed"
    elif redundant > 0:
        return "overconstrained"
    elif total_dof > baseline:
        return "underconstrained"
    elif total_dof == baseline:
        return "well-constrained"
    else:
        return "overconstrained"