feat(solver): graph decomposition for cluster-by-cluster solving (phase 3)
Add a Python decomposition layer using NetworkX that partitions the constraint graph into biconnected components (rigid clusters), orders them via a block-cut tree, and solves each cluster independently. Articulation-point bodies propagate as boundary conditions between clusters. New module kindred_solver/decompose.py: - DOF table mapping BaseJointKind to residual counts - Constraint graph construction (nx.MultiGraph) - Biconnected component detection + articulation points - Block-cut tree solve ordering (root-first from grounded cluster) - Cluster-by-cluster solver with boundary body fix/unfix cycling - Pebble game integration for per-cluster rigidity classification Changes to existing modules: - params.py: add unfix() for boundary body cycling - solver.py: extract _monolithic_solve(), add decomposition branch for assemblies with >= 8 free bodies Performance: for k clusters of ~n/k params each, total cost drops from O(n^3) to O(n^3/k^2). 220 tests passing (up from 207).
This commit is contained in:
661
kindred_solver/decompose.py
Normal file
661
kindred_solver/decompose.py
Normal file
@@ -0,0 +1,661 @@
|
||||
"""Graph decomposition for cluster-by-cluster constraint solving.
|
||||
|
||||
Builds a constraint graph from the SolveContext, decomposes it into
|
||||
biconnected components (rigid clusters), orders them via a block-cut
|
||||
tree, and solves each cluster independently. Articulation-point bodies
|
||||
are temporarily fixed when solving adjacent clusters so their solved
|
||||
values propagate as boundary conditions.
|
||||
|
||||
Requires: networkx
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib.util
|
||||
import logging
|
||||
import sys
|
||||
import types as stdlib_types
|
||||
from collections import deque
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, List
|
||||
|
||||
import networkx as nx
|
||||
|
||||
from .bfgs import bfgs_solve
|
||||
from .newton import newton_solve
|
||||
from .prepass import substitution_pass
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .constraints import ConstraintBase
|
||||
from .entities import RigidBody
|
||||
from .params import ParamTable
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# DOF table: BaseJointKind → number of residuals (= DOF removed)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Imported lazily to avoid hard kcsolve dependency in tests.
# Use residual_count() accessor instead of this dict directly.
# Keys are kcsolve.BaseJointKind members, values are residual counts.
_RESIDUAL_COUNT: dict | None = None


def _ensure_residual_count() -> dict:
    """Build the residual count table on first use.

    The enum-keyed table is derived from the string-keyed
    ``_RESIDUAL_COUNT_BY_NAME`` defined below, so the two tables share a
    single source of truth and cannot drift apart.  ``kcsolve`` is
    imported lazily here so this module stays importable in test
    environments that lack it.
    """
    global _RESIDUAL_COUNT
    if _RESIDUAL_COUNT is not None:
        return _RESIDUAL_COUNT

    import kcsolve

    # Mirror the name-keyed table onto the BaseJointKind enum members.
    _RESIDUAL_COUNT = {
        getattr(kcsolve.BaseJointKind, kind_name): count
        for kind_name, count in _RESIDUAL_COUNT_BY_NAME.items()
    }
    return _RESIDUAL_COUNT
|
||||
|
||||
|
||||
def residual_count(kind) -> int:
    """Number of residuals a constraint type produces (0 for unknown kinds)."""
    table = _ensure_residual_count()
    return table.get(kind, 0)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Standalone residual-count table (no kcsolve dependency, string-keyed)
|
||||
# Used by tests that don't have kcsolve available.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# String-keyed residual counts, usable without importing kcsolve.
# Each value is the number of scalar residual equations (= DOF removed)
# the constraint kind contributes; 0 marks unimplemented stubs.
_RESIDUAL_COUNT_BY_NAME: dict[str, int] = {
    "Fixed": 6,
    "Coincident": 3,
    "Ball": 3,
    "Revolute": 5,
    "Cylindrical": 4,
    "Slider": 5,
    "Screw": 5,
    "Universal": 4,
    "Parallel": 2,
    "Perpendicular": 1,
    "Angle": 1,
    "Concentric": 4,
    "Tangent": 1,
    "Planar": 3,
    "LineInPlane": 2,
    "PointOnLine": 2,
    "PointInPlane": 1,
    "DistancePointPoint": 1,
    "Gear": 1,
    "RackPinion": 1,
    "Cam": 0,  # stub: no residuals yet
    "Slot": 0,  # stub: no residuals yet
    "DistanceCylSph": 0,  # stub: no residuals yet
}
|
||||
|
||||
|
||||
def residual_count_by_name(kind_name: str) -> int:
    """Look up a residual count by constraint type name (no kcsolve needed)."""
    try:
        return _RESIDUAL_COUNT_BY_NAME[kind_name]
    except KeyError:
        return 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Data structures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
class SolveCluster:
    """A cluster of bodies to solve together.

    Produced by :func:`decompose` — one instance per biconnected
    component of the constraint graph, in solve order.
    """

    bodies: set[str]  # Body IDs in this cluster
    constraint_indices: list[int]  # Indices into the constraint list
    boundary_bodies: set[str]  # Articulation points shared with other clusters
    has_ground: bool  # Whether any body in the cluster is grounded
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Graph construction
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def build_constraint_graph(
    constraints: list,
    grounded_bodies: set[str],
) -> nx.MultiGraph:
    """Build a body-level constraint multigraph.

    Nodes are part_id strings (one per body referenced by a constraint).
    Each active constraint with a nonzero residual count contributes one
    edge carrying:

    - constraint_index: position in the ``constraints`` list
    - weight: number of residuals
    - kind_name: constraint kind name (for pebble game integration)

    Grounded bodies are tagged ``grounded=True``; zero-residual stub
    constraints are excluded.
    """
    graph = nx.MultiGraph()

    for index, constraint in enumerate(constraints):
        if not constraint.activated:
            continue
        n_residuals = residual_count(constraint.type)
        if not n_residuals:
            continue

        # Ensure both endpoint nodes exist with their grounded tag.
        for part in (constraint.part_i, constraint.part_j):
            if part not in graph:
                graph.add_node(part, grounded=(part in grounded_bodies))

        # kind_name is kept on the edge for pebble game integration.
        kind = constraint.type
        kind_name = kind.name if hasattr(kind, "name") else str(kind)
        graph.add_edge(
            constraint.part_i,
            constraint.part_j,
            constraint_index=index,
            weight=n_residuals,
            kind_name=kind_name,
        )

    return graph
|
||||
|
||||
|
||||
def build_constraint_graph_simple(
    edges: list[tuple[str, str, str, int]],
    grounded: set[str] | None = None,
) -> nx.MultiGraph:
    """Build a constraint graph from simple edge tuples (for testing).

    Each edge is ``(body_i, body_j, kind_name, constraint_index)``.
    """
    grounded_set = grounded or set()
    graph = nx.MultiGraph()
    for body_a, body_b, kind_name, index in edges:
        n_residuals = residual_count_by_name(kind_name)
        if not n_residuals:
            continue  # zero-residual stubs never enter the graph
        for body in (body_a, body_b):
            if body not in graph:
                graph.add_node(body, grounded=(body in grounded_set))
        graph.add_edge(
            body_a,
            body_b,
            constraint_index=index,
            weight=n_residuals,
            kind_name=kind_name,
        )
    return graph
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Decomposition
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def find_clusters(
    G: nx.MultiGraph,
) -> tuple[list[set[str]], set[str]]:
    """Find biconnected components and articulation points.

    Returns:
        clusters: one body-ID set per biconnected component
        articulation_points: body IDs shared between clusters
    """
    # NetworkX's biconnectivity routines require a simple Graph, so
    # collapse the multigraph's parallel edges first.
    flattened = nx.Graph(G)
    components = [set(component) for component in nx.biconnected_components(flattened)]
    cut_vertices = set(nx.articulation_points(flattened))
    return components, cut_vertices
|
||||
|
||||
|
||||
def build_solve_order(
    G: nx.MultiGraph,
    clusters: list[set[str]],
    articulation_points: set[str],
    grounded_bodies: set[str],
) -> list[SolveCluster]:
    """Order clusters for solving via the block-cut tree.

    Builds the block-cut tree (a bipartite graph of clusters and
    articulation points), roots it at a grounded cluster when one
    exists, and returns clusters in root-to-leaf order (grounded first,
    outward to the leaves).  This guarantees boundary bodies are solved
    before any cluster that depends on them.
    """
    if not clusters:
        return []

    def _make(bodies: set[str], boundary: set[str]) -> SolveCluster:
        # Shared constructor: gather the cluster's constraints and note
        # whether it touches ground.
        return SolveCluster(
            bodies=bodies,
            constraint_indices=_constraints_for_bodies(G, bodies),
            boundary_bodies=boundary,
            has_ground=bool(bodies & grounded_bodies),
        )

    # Single cluster — nothing to order.
    if len(clusters) == 1:
        return [_make(clusters[0], set())]

    # Block-cut tree: ("C", i) nodes for clusters, ("A", body_id) nodes
    # for articulation points.
    bct = nx.Graph()
    for i, cluster in enumerate(clusters):
        bct.add_node(("C", i))
        for ap in articulation_points:
            if ap in cluster:
                bct.add_edge(("C", i), ("A", ap))

    # Root at the first cluster containing a grounded body; floating
    # assemblies fall back to cluster 0.
    root = ("C", 0)
    for i, cluster in enumerate(clusters):
        if cluster & grounded_bodies:
            root = ("C", i)
            break

    # BFS outward from the root, recording cluster nodes in visit order.
    order: list[int] = []
    seen = {root}
    frontier = deque([root])
    while frontier:
        node = frontier.popleft()
        tag, payload = node
        if tag == "C":
            order.append(payload)
        for neighbor in bct.neighbors(node):
            if neighbor not in seen:
                seen.add(neighbor)
                frontier.append(neighbor)

    # Materialize SolveClusters in that order; boundary bodies are the
    # articulation points the cluster shares with its neighbors.
    return [
        _make(clusters[i], clusters[i] & articulation_points)
        for i in order
    ]
|
||||
|
||||
|
||||
def _constraints_for_bodies(G: nx.MultiGraph, bodies: set[str]) -> list[int]:
|
||||
"""Collect constraint indices for edges where both endpoints are in bodies."""
|
||||
indices = []
|
||||
seen = set()
|
||||
for u, v, data in G.edges(data=True):
|
||||
idx = data["constraint_index"]
|
||||
if idx in seen:
|
||||
continue
|
||||
if u in bodies and v in bodies:
|
||||
seen.add(idx)
|
||||
indices.append(idx)
|
||||
return sorted(indices)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Top-level decompose entry point
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def decompose(
    constraints: list,
    grounded_bodies: set[str],
) -> list[SolveCluster]:
    """Full decomposition pipeline: graph -> clusters -> solve order.

    Returns SolveClusters in solve order: within each connected
    component the grounded (root) cluster comes first, then clusters
    outward toward the leaves, matching build_solve_order().  (The
    previous docstring said "leaves first", which contradicted the
    actual root-first ordering.)  A single-cluster system yields a
    1-element list.
    """
    G = build_constraint_graph(constraints, grounded_bodies)

    # Disconnected sub-assemblies are decomposed and ordered
    # independently of each other.
    all_clusters: list[SolveCluster] = []
    for component_nodes in nx.connected_components(G):
        sub = G.subgraph(component_nodes).copy()
        clusters, artic = find_clusters(sub)

        if len(clusters) <= 1:
            # The whole component is a single rigid cluster.
            bodies = component_nodes if not clusters else clusters[0]
            all_clusters.append(
                SolveCluster(
                    bodies=set(bodies),
                    constraint_indices=_constraints_for_bodies(sub, bodies),
                    boundary_bodies=set(),
                    has_ground=bool(bodies & grounded_bodies),
                )
            )
        else:
            all_clusters.extend(
                build_solve_order(sub, clusters, artic, grounded_bodies)
            )

    return all_clusters
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Cluster solver
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def solve_decomposed(
    clusters: list[SolveCluster],
    bodies: dict[str, "RigidBody"],
    constraint_objs: list["ConstraintBase"],
    constraint_indices_map: list[int],
    params: "ParamTable",
) -> bool:
    """Solve clusters in order, fixing boundary bodies between solves.

    Boundary (articulation-point) bodies solved by an earlier cluster
    are temporarily fixed while an adjacent cluster is solved, so their
    values propagate as boundary conditions; they are released again
    afterward.

    Args:
        clusters: SolveCluster list in solve order (from decompose()).
        bodies: part_id → RigidBody mapping.
        constraint_objs: constraint objects (parallel to constraint_indices_map).
        constraint_indices_map: for each constraint_obj, its index in ctx.constraints.
        params: shared ParamTable.

    Returns True if all clusters converged.
    """
    # Build reverse map: constraint_index → position in constraint_objs list
    idx_to_obj: dict[int, "ConstraintBase"] = {}
    for pos, ci in enumerate(constraint_indices_map):
        idx_to_obj[ci] = constraint_objs[pos]

    solved_bodies: set[str] = set()
    all_converged = True

    for cluster in clusters:
        # 1. Fix boundary bodies that were already solved.  Record what
        #    we fixed so step 7 can release exactly those params (and
        #    never unfix params that were fixed for other reasons).
        fixed_boundary_params: list[str] = []
        for body_id in cluster.boundary_bodies:
            if body_id in solved_bodies:
                body = bodies[body_id]
                for pname in body._param_names:
                    if not params.is_fixed(pname):
                        params.fix(pname)
                        fixed_boundary_params.append(pname)

        # 2. Collect residuals for this cluster
        cluster_residuals = []
        for ci in cluster.constraint_indices:
            # idx_to_obj may lack indices whose constraints built no
            # object; those are simply skipped.
            obj = idx_to_obj.get(ci)
            if obj is not None:
                cluster_residuals.extend(obj.residuals())

        # 3. Add quat norm residuals for free, non-grounded bodies in this cluster
        quat_groups = []
        for body_id in cluster.bodies:
            body = bodies[body_id]
            if body.grounded:
                continue
            if body_id in cluster.boundary_bodies and body_id in solved_bodies:
                continue  # Already fixed as boundary
            cluster_residuals.append(body.quat_norm_residual())
            quat_groups.append(body.quat_param_names())

        # 4. Substitution pass (compiles fixed boundary params to constants)
        cluster_residuals = substitution_pass(cluster_residuals, params)

        # 5. Newton solve (+ BFGS fallback)
        if cluster_residuals:
            converged = newton_solve(
                cluster_residuals,
                params,
                quat_groups=quat_groups,
                max_iter=100,
                tol=1e-10,
            )
            if not converged:
                converged = bfgs_solve(
                    cluster_residuals,
                    params,
                    quat_groups=quat_groups,
                    max_iter=200,
                    tol=1e-10,
                )
            if not converged:
                # Keep going: later clusters may still converge on their
                # own; the aggregate result reflects the failure.
                all_converged = False

        # 6. Mark this cluster's bodies as solved
        solved_bodies.update(cluster.bodies)

        # 7. Unfix boundary params
        for pname in fixed_boundary_params:
            params.unfix(pname)

    return all_converged
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pebble game integration (rigidity classification)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Lazy-load state for the pebble game integration; populated by
# _load_pebble_modules() on first use.  _PebbleGame3D stays None when
# the GNN modules are unavailable, and callers degrade gracefully.
_PEBBLE_MODULES_LOADED = False
_PebbleGame3D = None
_PebbleJointType = None
_PebbleJoint = None
|
||||
|
||||
|
||||
def _load_pebble_modules():
    """Lazily load PebbleGame3D and related types from GNN/solver/datagen/.

    The GNN package has its own import structure (``from solver.datagen.types
    import ...``) that conflicts with the top-level module layout, so we
    register shim modules in ``sys.modules`` to make it work.  Idempotent:
    the _PEBBLE_MODULES_LOADED flag ensures the work (and the missing-dir
    warning) happens at most once per process.
    """
    global _PEBBLE_MODULES_LOADED, _PebbleGame3D, _PebbleJointType, _PebbleJoint
    if _PEBBLE_MODULES_LOADED:
        return

    # Find GNN/solver/datagen relative to this package
    pkg_dir = Path(__file__).resolve().parent.parent  # mods/solver/
    datagen_dir = pkg_dir / "GNN" / "solver" / "datagen"

    if not datagen_dir.exists():
        # Mark as loaded anyway so we only warn once; _PebbleGame3D
        # stays None and classify_cluster_rigidity() returns None.
        log.warning("GNN/solver/datagen/ not found; pebble game unavailable")
        _PEBBLE_MODULES_LOADED = True
        return

    # Register shim modules so ``from solver.datagen.types import ...`` works
    if "solver" not in sys.modules:
        sys.modules["solver"] = stdlib_types.ModuleType("solver")
    if "solver.datagen" not in sys.modules:
        dg = stdlib_types.ModuleType("solver.datagen")
        sys.modules["solver.datagen"] = dg
        sys.modules["solver"].datagen = dg  # type: ignore[attr-defined]

    # Load types.py (registered in sys.modules before exec so imports
    # inside the loaded modules can resolve it through the shim)
    types_path = datagen_dir / "types.py"
    spec_t = importlib.util.spec_from_file_location(
        "solver.datagen.types", str(types_path)
    )
    types_mod = importlib.util.module_from_spec(spec_t)
    sys.modules["solver.datagen.types"] = types_mod
    spec_t.loader.exec_module(types_mod)

    # Load pebble_game.py
    pg_path = datagen_dir / "pebble_game.py"
    spec_p = importlib.util.spec_from_file_location(
        "solver.datagen.pebble_game", str(pg_path)
    )
    pg_mod = importlib.util.module_from_spec(spec_p)
    sys.modules["solver.datagen.pebble_game"] = pg_mod
    spec_p.loader.exec_module(pg_mod)

    _PebbleGame3D = pg_mod.PebbleGame3D
    _PebbleJointType = types_mod.JointType
    _PebbleJoint = types_mod.Joint
    _PEBBLE_MODULES_LOADED = True
|
||||
|
||||
|
||||
# BaseJointKind name → PebbleGame JointType name.
# Types not listed here use manual edge insertion with the residual count.
_KIND_NAME_TO_PEBBLE_NAME: dict[str, str] = {
    "Fixed": "FIXED",
    "Coincident": "BALL",  # Same DOF count (3)
    "Ball": "BALL",
    "Revolute": "REVOLUTE",
    "Cylindrical": "CYLINDRICAL",
    "Slider": "SLIDER",
    "Screw": "SCREW",
    "Universal": "UNIVERSAL",
    "Planar": "PLANAR",
    "Perpendicular": "PERPENDICULAR",
    "DistancePointPoint": "DISTANCE",
}
# Parallel: pebble game uses 3 DOF, but our solver uses 2.
# We handle it with manual edge insertion.

# Types that need manual edge insertion (no direct JointType mapping,
# or DOF mismatch like Parallel).  Each constraint becomes one DISTANCE
# edge per residual in classify_cluster_rigidity().
_MANUAL_EDGE_TYPES: set[str] = {
    "Parallel",  # 2 residuals, but JointType.PARALLEL = 3
    "Angle",  # 1 residual, no JointType
    "Concentric",  # 4 residuals, no JointType
    "Tangent",  # 1 residual, no JointType
    "LineInPlane",  # 2 residuals, no JointType
    "PointOnLine",  # 2 residuals, no JointType
    "PointInPlane",  # 1 residual, no JointType
    "Gear",  # 1 residual, no JointType
    "RackPinion",  # 1 residual, no JointType
}

# Sentinel integer ID for the virtual ground body added in
# classify_cluster_rigidity(); real bodies are numbered from 0.
_GROUND_BODY_ID = -1
|
||||
|
||||
|
||||
def classify_cluster_rigidity(
    cluster: SolveCluster,
    constraint_graph: nx.MultiGraph,
    grounded_bodies: set[str],
) -> str | None:
    """Run the 3D pebble game on a cluster and classify its rigidity.

    Args:
        cluster: the cluster to classify.
        constraint_graph: full constraint multigraph (from
            build_constraint_graph); only edges inside the cluster are used.
        grounded_bodies: IDs of grounded bodies in the whole assembly.

    Returns one of "well-constrained", "underconstrained",
    "overconstrained", "mixed", or None if the pebble game modules
    are unavailable.
    """
    import numpy as np

    _load_pebble_modules()
    if _PebbleGame3D is None:
        return None

    pg = _PebbleGame3D()

    # Map string body IDs → integer IDs for the pebble game.  Sorting
    # makes the mapping deterministic across runs.
    body_list = sorted(cluster.bodies)
    body_to_int: dict[str, int] = {b: i for i, b in enumerate(body_list)}

    for b in body_list:
        pg.add_body(body_to_int[b])

    # Add a virtual ground body and a FIXED joint to each grounded body,
    # so grounding is visible to the pebble counts.
    has_ground = bool(cluster.bodies & grounded_bodies)
    if has_ground:
        pg.add_body(_GROUND_BODY_ID)
        for b in cluster.bodies & grounded_bodies:
            ground_joint = _PebbleJoint(
                joint_id=-1,
                body_a=body_to_int[b],
                body_b=_GROUND_BODY_ID,
                joint_type=_PebbleJointType["FIXED"],
                anchor_a=np.zeros(3),
                anchor_b=np.zeros(3),
            )
            pg.add_joint(ground_joint)

    # Add constraint edges.  Membership-test against a set: the original
    # `ci not in cluster.constraint_indices` list scan was O(len(list))
    # per graph edge, i.e. accidentally quadratic.
    cluster_constraints = set(cluster.constraint_indices)
    joint_counter = 0
    zero = np.zeros(3)
    for u, v, data in constraint_graph.edges(data=True):
        if u not in cluster.bodies or v not in cluster.bodies:
            continue
        ci = data["constraint_index"]
        if ci not in cluster_constraints:
            continue

        # Determine the constraint kind name from the graph edge
        kind_name = data.get("kind_name", "")
        n_residuals = data.get("weight", 0)

        if not kind_name or n_residuals == 0:
            continue

        int_u = body_to_int[u]
        int_v = body_to_int[v]

        pebble_name = _KIND_NAME_TO_PEBBLE_NAME.get(kind_name)
        if pebble_name and kind_name not in _MANUAL_EDGE_TYPES:
            # Direct JointType mapping
            jt = _PebbleJointType[pebble_name]
            joint = _PebbleJoint(
                joint_id=joint_counter,
                body_a=int_u,
                body_b=int_v,
                joint_type=jt,
                anchor_a=zero,
                anchor_b=zero,
            )
            pg.add_joint(joint)
            joint_counter += 1
        else:
            # Manual edge insertion: one DISTANCE edge per residual
            for _ in range(n_residuals):
                joint = _PebbleJoint(
                    joint_id=joint_counter,
                    body_a=int_u,
                    body_b=int_v,
                    joint_type=_PebbleJointType["DISTANCE"],
                    anchor_a=zero,
                    anchor_b=zero,
                )
                pg.add_joint(joint)
                joint_counter += 1

    # Classify using raw pebble counts (adjusting for virtual ground)
    total_dof = pg.get_dof()
    redundant = pg.get_redundant_count()

    # The virtual ground body contributes 6 pebbles that are never consumed.
    # Subtract them to get the effective DOF.
    if has_ground:
        total_dof -= 6  # virtual ground's unconstrained pebbles
        baseline = 0
    else:
        baseline = 6  # trivial rigid-body motion of a floating cluster

    if redundant > 0 and total_dof > baseline:
        return "mixed"
    elif redundant > 0:
        return "overconstrained"
    elif total_dof > baseline:
        return "underconstrained"
    elif total_dof == baseline:
        return "well-constrained"
    else:
        # total_dof below baseline: more constraints consumed than the
        # cluster has free motions — overconstrained.
        return "overconstrained"
|
||||
@@ -49,6 +49,13 @@ class ParamTable:
|
||||
if name in self._free_order:
|
||||
self._free_order.remove(name)
|
||||
|
||||
def unfix(self, name: str):
|
||||
"""Restore a fixed parameter to free status."""
|
||||
if name in self._fixed:
|
||||
self._fixed.discard(name)
|
||||
if name not in self._free_order:
|
||||
self._free_order.append(name)
|
||||
|
||||
def get_env(self) -> Dict[str, float]:
|
||||
"""Return a snapshot of all current values (for Expr.eval)."""
|
||||
return dict(self._values)
|
||||
|
||||
@@ -32,12 +32,16 @@ from .constraints import (
|
||||
TangentConstraint,
|
||||
UniversalConstraint,
|
||||
)
|
||||
from .decompose import decompose, solve_decomposed
|
||||
from .dof import count_dof
|
||||
from .entities import RigidBody
|
||||
from .newton import newton_solve
|
||||
from .params import ParamTable
|
||||
from .prepass import single_equation_pass, substitution_pass
|
||||
|
||||
# Assemblies with fewer free bodies than this use the monolithic path.
|
||||
_DECOMPOSE_THRESHOLD = 8
|
||||
|
||||
# All BaseJointKind values this solver can handle
|
||||
_SUPPORTED = {
|
||||
# Phase 1
|
||||
@@ -95,11 +99,12 @@ class KindredSolver(kcsolve.IKCSolver):
|
||||
)
|
||||
bodies[part.id] = body
|
||||
|
||||
# 2. Build constraint residuals
|
||||
# 2. Build constraint residuals (track index mapping for decomposition)
|
||||
all_residuals = []
|
||||
constraint_objs = []
|
||||
constraint_indices = [] # parallel to constraint_objs: index in ctx.constraints
|
||||
|
||||
for c in ctx.constraints:
|
||||
for idx, c in enumerate(ctx.constraints):
|
||||
if not c.activated:
|
||||
continue
|
||||
body_i = bodies.get(c.part_i)
|
||||
@@ -123,6 +128,7 @@ class KindredSolver(kcsolve.IKCSolver):
|
||||
if obj is None:
|
||||
continue
|
||||
constraint_objs.append(obj)
|
||||
constraint_indices.append(idx)
|
||||
all_residuals.extend(obj.residuals())
|
||||
|
||||
# 3. Add quaternion normalization residuals for non-grounded bodies
|
||||
@@ -132,26 +138,31 @@ class KindredSolver(kcsolve.IKCSolver):
|
||||
all_residuals.append(body.quat_norm_residual())
|
||||
quat_groups.append(body.quat_param_names())
|
||||
|
||||
# 4. Pre-passes
|
||||
# 4. Pre-passes on full system
|
||||
all_residuals = substitution_pass(all_residuals, params)
|
||||
all_residuals = single_equation_pass(all_residuals, params)
|
||||
|
||||
# 5. Newton-Raphson (with BFGS fallback)
|
||||
converged = newton_solve(
|
||||
all_residuals,
|
||||
params,
|
||||
quat_groups=quat_groups,
|
||||
max_iter=100,
|
||||
tol=1e-10,
|
||||
)
|
||||
if not converged:
|
||||
converged = bfgs_solve(
|
||||
all_residuals,
|
||||
params,
|
||||
quat_groups=quat_groups,
|
||||
max_iter=200,
|
||||
tol=1e-10,
|
||||
)
|
||||
# 5. Solve (decomposed for large assemblies, monolithic for small)
|
||||
n_free_bodies = sum(1 for b in bodies.values() if not b.grounded)
|
||||
if n_free_bodies >= _DECOMPOSE_THRESHOLD:
|
||||
grounded_ids = {pid for pid, b in bodies.items() if b.grounded}
|
||||
clusters = decompose(ctx.constraints, grounded_ids)
|
||||
if len(clusters) > 1:
|
||||
converged = solve_decomposed(
|
||||
clusters,
|
||||
bodies,
|
||||
constraint_objs,
|
||||
constraint_indices,
|
||||
params,
|
||||
)
|
||||
else:
|
||||
converged = _monolithic_solve(
|
||||
all_residuals,
|
||||
params,
|
||||
quat_groups,
|
||||
)
|
||||
else:
|
||||
converged = _monolithic_solve(all_residuals, params, quat_groups)
|
||||
|
||||
# 6. DOF
|
||||
dof = count_dof(all_residuals, params)
|
||||
@@ -182,6 +193,26 @@ class KindredSolver(kcsolve.IKCSolver):
|
||||
return True
|
||||
|
||||
|
||||
def _monolithic_solve(all_residuals, params, quat_groups):
    """Newton-Raphson solve on the full system, falling back to BFGS.

    Returns the convergence flag of whichever solver ran last.
    """
    newton_result = newton_solve(
        all_residuals,
        params,
        quat_groups=quat_groups,
        max_iter=100,
        tol=1e-10,
    )
    # `or` short-circuits: BFGS only runs when Newton failed, and the
    # original result object is returned unchanged when it succeeded.
    return newton_result or bfgs_solve(
        all_residuals,
        params,
        quat_groups=quat_groups,
        max_iter=200,
        tol=1e-10,
    )
|
||||
|
||||
|
||||
def _build_constraint(
|
||||
kind,
|
||||
body_i,
|
||||
|
||||
1052
tests/test_decompose.py
Normal file
1052
tests/test_decompose.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -65,3 +65,37 @@ class TestParamTable:
|
||||
pt.add("b", 0.0, fixed=True)
|
||||
pt.add("c", 0.0)
|
||||
assert pt.n_free() == 2
|
||||
|
||||
    def test_unfix(self):
        """unfix() clears the fixed flag and returns the param to the free set."""
        pt = ParamTable()
        pt.add("a", 1.0)
        pt.add("b", 2.0)
        pt.fix("a")
        assert pt.is_fixed("a")
        assert "a" not in pt.free_names()

        pt.unfix("a")
        assert not pt.is_fixed("a")
        assert "a" in pt.free_names()
        assert pt.n_free() == 2
|
||||
|
||||
    def test_fix_unfix_roundtrip(self):
        """Fix then unfix preserves value and makes param free again."""
        pt = ParamTable()
        pt.add("x", 5.0)
        pt.add("y", 3.0)
        pt.fix("x")
        # Value set while fixed must survive the round trip.
        pt.set_value("x", 10.0)
        pt.unfix("x")
        assert pt.get_value("x") == 10.0
        assert "x" in pt.free_names()
        # x moves to end of free list (unfix appends, not reinserts)
        assert pt.free_names() == ["y", "x"]
|
||||
|
||||
    def test_unfix_noop_if_already_free(self):
        """Unfixing a free parameter is a no-op (no duplicate in free order)."""
        pt = ParamTable()
        pt.add("a", 1.0)
        pt.unfix("a")
        assert pt.free_names() == ["a"]
        assert pt.n_free() == 1
|
||||
|
||||
Reference in New Issue
Block a user