Files
solver/kindred_solver/solver.py
forbes-0023 9dad25e947 feat(solver): assembly integration — diagnose, drag protocol, system extraction (phase 5)
- Extract _build_system() from solve() to enable reuse by diagnose()
- Add diagnose(ctx) method: runs find_overconstrained() unconditionally
- Add interactive drag protocol: pre_drag(), drag_step(), post_drag()
- Add _run_diagnostics() and _extract_placements() helpers
- Log warning when joint limits are present (not yet enforced)
- KindredSolver now implements all IKCSolver methods needed for
  full Assembly workbench integration
2026-02-20 23:32:51 -06:00

601 lines
18 KiB
Python

"""KindredSolver — IKCSolver implementation bridging KCSolve to the
expression-based Newton-Raphson solver."""
from __future__ import annotations
import kcsolve
from .bfgs import bfgs_solve
from .constraints import (
AngleConstraint,
BallConstraint,
CamConstraint,
CoincidentConstraint,
ConcentricConstraint,
ConstraintBase,
CylindricalConstraint,
DistanceCylSphConstraint,
DistancePointPointConstraint,
FixedConstraint,
GearConstraint,
LineInPlaneConstraint,
ParallelConstraint,
PerpendicularConstraint,
PlanarConstraint,
PointInPlaneConstraint,
PointOnLineConstraint,
RackPinionConstraint,
RevoluteConstraint,
ScrewConstraint,
SliderConstraint,
SlotConstraint,
TangentConstraint,
UniversalConstraint,
)
from .decompose import decompose, solve_decomposed
from .diagnostics import find_overconstrained
from .dof import count_dof
from .entities import RigidBody
from .newton import newton_solve
from .params import ParamTable
from .preference import (
apply_half_space_correction,
build_weight_vector,
compute_half_spaces,
)
from .prepass import single_equation_pass, substitution_pass
# Assemblies with fewer free (non-grounded) bodies than this are always
# solved monolithically; at or above it, solve() first tries to split the
# system into independent clusters.
_DECOMPOSE_THRESHOLD = 8

# Every BaseJointKind reported by supported_joints(), grouped by the phase
# that introduced it. Cam, Slot and DistanceCylSph are intentionally absent:
# _build_constraint() only accepts them as residual-free stubs.
_SUPPORTED = {
    getattr(kcsolve.BaseJointKind, name)
    for name in (
        # Phase 1
        "Coincident",
        "DistancePointPoint",
        "Fixed",
        # Phase 2: point constraints
        "PointOnLine",
        "PointInPlane",
        # Phase 2: orientation
        "Parallel",
        "Perpendicular",
        "Angle",
        # Phase 2: axis/surface
        "Concentric",
        "Tangent",
        "Planar",
        "LineInPlane",
        # Phase 2: kinematic joints
        "Ball",
        "Revolute",
        "Cylindrical",
        "Slider",
        "Screw",
        "Universal",
        # Phase 2: mechanical
        "Gear",
        "RackPinion",
    )
}
class KindredSolver(kcsolve.IKCSolver):
    """Expression-based Newton-Raphson constraint solver.

    Implements the ``kcsolve.IKCSolver`` entry points used by the Assembly
    workbench: static ``solve``, the interactive drag protocol
    (``pre_drag`` / ``drag_step`` / ``post_drag``) and ``diagnose``.
    """

    def __init__(self):
        super().__init__()
        # Context retained between pre_drag() and post_drag(); None otherwise.
        self._drag_ctx = None
        # IDs of the parts being dragged. Recorded by pre_drag() but not yet
        # consulted when solving.
        self._drag_parts = None
        # True once the "joint limits ignored" warning has been logged.
        self._limits_warned = False

    def name(self):
        """Return the human-readable solver name."""
        return "Kindred (Newton-Raphson)"

    def supported_joints(self):
        """Return the BaseJointKind values this solver handles."""
        return list(_SUPPORTED)

    # ── Static solve ────────────────────────────────────────────────

    def solve(self, ctx):
        """Solve *ctx* and return a ``kcsolve.SolveResult``.

        The parts' current placements are the initial guess. Assemblies with
        at least ``_DECOMPOSE_THRESHOLD`` free bodies that decompose into more
        than one cluster are solved cluster by cluster; everything else uses
        the monolithic Newton-Raphson / BFGS path. When the solve does not
        converge, the result also carries overconstraint diagnostics.
        """
        system = _build_system(ctx)
        self._warn_if_limits(ctx)

        # Solution preference: half-spaces, weight vector
        half_spaces = compute_half_spaces(
            system.constraint_objs,
            system.constraint_indices,
            system.params,
        )
        weight_vec = build_weight_vector(system.params)
        if half_spaces:
            post_step_fn = lambda p: apply_half_space_correction(p, half_spaces)
        else:
            post_step_fn = None

        # Pre-passes on the full system
        residuals = substitution_pass(system.all_residuals, system.params)
        residuals = single_equation_pass(residuals, system.params)

        # Decomposition is only attempted for large assemblies, and only
        # used when it actually yields more than one cluster.
        clusters = None
        use_decomposed = False
        n_free_bodies = sum(1 for b in system.bodies.values() if not b.grounded)
        if n_free_bodies >= _DECOMPOSE_THRESHOLD:
            grounded_ids = {pid for pid, b in system.bodies.items() if b.grounded}
            clusters = decompose(ctx.constraints, grounded_ids)
            use_decomposed = len(clusters) > 1

        if use_decomposed:
            converged = solve_decomposed(
                clusters,
                system.bodies,
                system.constraint_objs,
                system.constraint_indices,
                system.params,
            )
        else:
            converged = _monolithic_solve(
                residuals,
                system.params,
                system.quat_groups,
                post_step=post_step_fn,
                weight_vector=weight_vec,
            )

        result = kcsolve.SolveResult()
        result.status = (
            kcsolve.SolveStatus.Success if converged else kcsolve.SolveStatus.Failed
        )
        result.dof = count_dof(residuals, system.params)
        if not converged:
            result.diagnostics = _run_diagnostics(
                residuals,
                system.params,
                system.residual_ranges,
                ctx,
            )
        result.placements = _extract_placements(system.params, system.bodies)
        return result

    def _warn_if_limits(self, ctx):
        """Log, at most once per solver instance, that joint limits are ignored.

        Only the first constraint with limits is named; the flag is set only
        once a warning has actually been emitted.
        """
        if self._limits_warned:
            return
        limited = next((c for c in ctx.constraints if c.limits), None)
        if limited is None:
            return
        import logging

        logging.getLogger(__name__).warning(
            "Joint limits on '%s' ignored "
            "(not yet supported by Kindred solver)",
            limited.id,
        )
        self._limits_warned = True

    # ── Incremental update ──────────────────────────────────────────
    # The base class default (delegates to solve()) is correct here:
    # solve() uses current placements as initial guess, so small
    # parameter changes converge quickly without special handling.

    # ── Interactive drag ────────────────────────────────────────────

    def pre_drag(self, ctx, drag_parts):
        """Start a drag session on *ctx* and return the initial solve.

        *ctx* is retained until post_drag() and is updated in place: dragged
        placements are written into it by drag_step(), and successful solve
        results are written back so each step starts from the last solution.
        """
        self._drag_ctx = ctx
        self._drag_parts = set(drag_parts)
        return self._solve_drag_ctx()

    def drag_step(self, drag_placements):
        """Apply *drag_placements* to the retained context and re-solve.

        Returns an empty ``kcsolve.SolveResult`` if no drag is in progress.
        """
        ctx = self._drag_ctx
        if ctx is None:
            return kcsolve.SolveResult()
        self._apply_placements(ctx, drag_placements)
        return self._solve_drag_ctx()

    def post_drag(self):
        """End the drag session and release the retained context."""
        self._drag_ctx = None
        self._drag_parts = None

    def _solve_drag_ctx(self):
        """Solve the retained drag context, warm-starting the next step.

        On success the solved placements are copied back into the context's
        parts; without this, every drag step would restart the non-dragged
        parts from their pre-drag placements. Failed solves leave the
        context untouched so a diverged state is never fed forward.
        """
        ctx = self._drag_ctx
        result = self.solve(ctx)
        if result.status == kcsolve.SolveStatus.Success:
            self._apply_placements(ctx, result.placements)
        return result

    @staticmethod
    def _apply_placements(ctx, placements):
        """Copy each entry's ``placement`` onto the ctx part with that ``id``.

        Entries with no matching part are ignored. If part ids repeat, the
        first matching part is updated.
        """
        parts_by_id = {}
        for part in ctx.parts:
            parts_by_id.setdefault(part.id, part)
        for pr in placements:
            part = parts_by_id.get(pr.id)
            if part is not None:
                part.placement = pr.placement

    # ── Diagnostics ─────────────────────────────────────────────────

    def diagnose(self, ctx):
        """Run overconstraint detection on *ctx* without solving it."""
        system = _build_system(ctx)
        residuals = substitution_pass(system.all_residuals, system.params)
        return _run_diagnostics(
            residuals,
            system.params,
            system.residual_ranges,
            ctx,
        )

    def is_deterministic(self):
        """Report that identical inputs yield identical results."""
        return True
class _System:
    """Plain record bundling everything _build_system() produces.

    Fields are assigned after construction. ``__slots__`` fixes the field
    set, so assigning a misspelled attribute raises AttributeError instead
    of silently adding a new one.
    """

    __slots__ = tuple(
        "params bodies constraint_objs constraint_indices "
        "all_residuals residual_ranges quat_groups".split()
    )
def _build_system(ctx):
    """Build the solver's internal representation from a SolveContext.

    Creates one RigidBody per part (all sharing a fresh ParamTable), one
    constraint object per activated constraint whose two parts exist and
    whose kind _build_constraint() recognises, and a quaternion
    normalization residual for every non-grounded body.

    Returns:
        A _System. Each ``residual_ranges`` entry is
        ``(start_row, end_row, constraint_idx)``: rows
        ``all_residuals[start_row:end_row]`` come from
        ``ctx.constraints[constraint_idx]``. The quaternion normalization
        rows are appended after all constraint rows and belong to no range.
    """
    params = ParamTable()

    # 1. Build entities from parts
    bodies = {}  # part_id -> RigidBody
    for part in ctx.parts:
        bodies[part.id] = RigidBody(
            part.id,
            params,
            position=tuple(part.placement.position),
            quaternion=tuple(part.placement.quaternion),  # (w, x, y, z)
            grounded=part.grounded,
        )

    # 2. Build constraint residuals together with their row ranges.
    all_residuals = []
    constraint_objs = []
    constraint_indices = []  # parallel to constraint_objs: index in ctx.constraints
    residual_ranges = []  # (start_row, end_row, constraint_idx)
    for idx, c in enumerate(ctx.constraints):
        if not c.activated:
            continue
        body_i = bodies.get(c.part_i)
        body_j = bodies.get(c.part_j)
        if body_i is None or body_j is None:
            continue
        obj = _build_constraint(
            c.type,
            body_i,
            tuple(c.marker_i.position),
            body_j,
            tuple(c.marker_j.position),
            c.marker_i,
            c.marker_j,
            c.params,
        )
        if obj is None:
            continue
        # Call residuals() once: the same rows are both appended and used
        # to size this constraint's range.
        rows = obj.residuals()
        start = len(all_residuals)
        all_residuals.extend(rows)
        constraint_objs.append(obj)
        constraint_indices.append(idx)
        residual_ranges.append((start, len(all_residuals), idx))

    # 3. Quaternion normalization residuals for non-grounded bodies
    quat_groups = []
    for body in bodies.values():
        if not body.grounded:
            all_residuals.append(body.quat_norm_residual())
            quat_groups.append(body.quat_param_names())

    system = _System()
    system.params = params
    system.bodies = bodies
    system.constraint_objs = constraint_objs
    system.constraint_indices = constraint_indices
    system.all_residuals = all_residuals
    system.residual_ranges = residual_ranges
    system.quat_groups = quat_groups
    return system
def _run_diagnostics(residuals, params, residual_ranges, ctx):
    """Translate find_overconstrained() findings into kcsolve diagnostics.

    Returns an empty list when the kcsolve build lacks
    ``ConstraintDiagnostic``. A finding whose kind is ``"redundant"`` maps
    to ``DiagnosticKind.Redundant``; any other kind maps to ``Conflicting``.
    """
    if not hasattr(kcsolve, "ConstraintDiagnostic"):
        return []

    def _to_kcsolve(finding):
        cd = kcsolve.ConstraintDiagnostic()
        cd.constraint_id = ctx.constraints[finding.constraint_index].id
        if finding.kind == "redundant":
            cd.kind = kcsolve.DiagnosticKind.Redundant
        else:
            cd.kind = kcsolve.DiagnosticKind.Conflicting
        cd.detail = finding.detail
        return cd

    findings = find_overconstrained(residuals, params, residual_ranges)
    return [_to_kcsolve(finding) for finding in findings]
def _extract_placements(params, bodies):
    """Return a PartResult for each non-grounded body, in ``bodies`` order.

    Positions and quaternions are read from the parameter table's current
    environment and stored as lists.
    """
    env = params.get_env()

    def _part_result(body):
        pr = kcsolve.SolveResult.PartResult()
        pr.id = body.part_id
        pr.placement = kcsolve.Transform()
        pr.placement.position = list(body.extract_position(env))
        pr.placement.quaternion = list(body.extract_quaternion(env))
        return pr

    return [_part_result(body) for body in bodies.values() if not body.grounded]
def _monolithic_solve(
    all_residuals,
    params,
    quat_groups,
    post_step=None,
    weight_vector=None,
    *,
    tol=1e-10,
    newton_max_iter=100,
    bfgs_max_iter=200,
):
    """Solve the full system with Newton-Raphson, falling back to BFGS.

    Args:
        all_residuals: residual expressions to drive to zero.
        params: the ParamTable holding the unknowns.
        quat_groups: quaternion parameter-name groups (one per free body),
            forwarded to both solvers.
        post_step: optional callback, forwarded to newton_solve only.
        weight_vector: optional weights, forwarded to both solvers.
        tol: convergence tolerance passed to both solvers.
        newton_max_iter: iteration cap for the Newton-Raphson attempt.
        bfgs_max_iter: iteration cap for the BFGS fallback.

    Returns:
        Newton-Raphson's convergence flag if it converged, otherwise the
        BFGS fallback's flag.
    """
    converged = newton_solve(
        all_residuals,
        params,
        quat_groups=quat_groups,
        max_iter=newton_max_iter,
        tol=tol,
        post_step=post_step,
        weight_vector=weight_vector,
    )
    if converged:
        return converged
    # NOTE(review): BFGS presumably starts from whatever state newton_solve
    # left in params — confirm against newton_solve's contract.
    return bfgs_solve(
        all_residuals,
        params,
        quat_groups=quat_groups,
        max_iter=bfgs_max_iter,
        tol=tol,
        weight_vector=weight_vector,
    )
def _param_or(c_params, index, default):
    """Return ``c_params[index]``, or *default* when that entry is missing.

    ``None``, an empty sequence and a too-short sequence are all treated as
    "missing", so every joint kind defaults its parameters the same way.
    """
    if c_params and len(c_params) > index:
        return c_params[index]
    return default


def _build_constraint(
    kind,
    body_i,
    marker_i_pos,
    body_j,
    marker_j_pos,
    marker_i,
    marker_j,
    c_params,
) -> ConstraintBase | None:
    """Create the appropriate constraint object from a BaseJointKind.

    Args:
        kind: the joint's ``kcsolve.BaseJointKind``.
        body_i, body_j: the two RigidBody instances being constrained.
        marker_i_pos, marker_j_pos: marker positions as tuples.
        marker_i, marker_j: marker transforms; only their ``quaternion`` is
            read here.
        c_params: the joint's scalar parameters. Missing entries (including
            an empty or ``None`` list) fall back to per-kind defaults.

    Returns:
        A constraint object, or ``None`` if *kind* is not recognised.
    """
    K = kcsolve.BaseJointKind
    marker_i_quat = tuple(marker_i.quaternion)
    marker_j_quat = tuple(marker_j.quaternion)

    # Argument bundles shared by most constructors.
    points = (body_i, marker_i_pos, body_j, marker_j_pos)
    axes = (body_i, marker_i_quat, body_j, marker_j_quat)
    frames = (
        body_i,
        marker_i_pos,
        marker_i_quat,
        body_j,
        marker_j_pos,
        marker_j_quat,
    )

    # -- Phase 1 constraints --------------------------------------------------
    if kind == K.Coincident:
        return CoincidentConstraint(*points)
    if kind == K.DistancePointPoint:
        return DistancePointPointConstraint(*points, _param_or(c_params, 0, 0.0))
    if kind == K.Fixed:
        return FixedConstraint(*frames)

    # -- Phase 2: point constraints -------------------------------------------
    if kind == K.PointOnLine:
        return PointOnLineConstraint(*frames)
    if kind == K.PointInPlane:
        return PointInPlaneConstraint(*frames, offset=_param_or(c_params, 0, 0.0))

    # -- Phase 2: orientation constraints -------------------------------------
    if kind == K.Parallel:
        return ParallelConstraint(*axes)
    if kind == K.Perpendicular:
        return PerpendicularConstraint(*axes)
    if kind == K.Angle:
        return AngleConstraint(*axes, _param_or(c_params, 0, 0.0))

    # -- Phase 2: axis/surface constraints ------------------------------------
    if kind == K.Concentric:
        return ConcentricConstraint(*frames, distance=_param_or(c_params, 0, 0.0))
    if kind == K.Tangent:
        return TangentConstraint(*frames)
    if kind == K.Planar:
        return PlanarConstraint(*frames, offset=_param_or(c_params, 0, 0.0))
    if kind == K.LineInPlane:
        return LineInPlaneConstraint(*frames, offset=_param_or(c_params, 0, 0.0))

    # -- Phase 2: kinematic joints --------------------------------------------
    if kind == K.Ball:
        return BallConstraint(*points)
    if kind == K.Revolute:
        return RevoluteConstraint(*frames)
    if kind == K.Cylindrical:
        return CylindricalConstraint(*frames)
    if kind == K.Slider:
        return SliderConstraint(*frames)
    if kind == K.Screw:
        return ScrewConstraint(*frames, pitch=_param_or(c_params, 0, 1.0))
    if kind == K.Universal:
        return UniversalConstraint(*frames)

    # -- Phase 2: mechanical constraints --------------------------------------
    if kind == K.Gear:
        return GearConstraint(
            *axes,
            _param_or(c_params, 0, 1.0),  # radius_i
            _param_or(c_params, 1, 1.0),  # radius_j
        )
    if kind == K.RackPinion:
        return RackPinionConstraint(
            *frames,
            pitch_radius=_param_or(c_params, 0, 1.0),
        )

    # -- Stubs (accepted but produce no residuals) ----------------------------
    if kind == K.Cam:
        return CamConstraint()
    if kind == K.Slot:
        return SlotConstraint()
    if kind == K.DistanceCylSph:
        return DistanceCylSphConstraint()
    return None