feat(solver): assembly integration — diagnose, drag protocol, system extraction (phase 5)

- Extract _build_system() from solve() to enable reuse by diagnose()
- Add diagnose(ctx) method: runs find_overconstrained() unconditionally
- Add interactive drag protocol: pre_drag(), drag_step(), post_drag()
- Add _run_diagnostics() and _extract_placements() helpers
- Log warning when joint limits are present (not yet enforced)
- KindredSolver now implements all IKCSolver methods needed for
  full Assembly workbench integration
This commit is contained in:
forbes-0023
2026-02-20 23:32:51 -06:00
parent b4b8724ff1
commit 9dad25e947

View File

@@ -33,10 +33,16 @@ from .constraints import (
UniversalConstraint,
)
from .decompose import decompose, solve_decomposed
from .diagnostics import find_overconstrained
from .dof import count_dof
from .entities import RigidBody
from .newton import newton_solve
from .params import ParamTable
from .preference import (
apply_half_space_correction,
build_weight_vector,
compute_half_spaces,
)
from .prepass import single_equation_pass, substitution_pass
# Assemblies with fewer free bodies than this use the monolithic path.
@@ -76,124 +82,282 @@ _SUPPORTED = {
class KindredSolver(kcsolve.IKCSolver):
"""Expression-based Newton-Raphson constraint solver."""
def __init__(self):
super().__init__()
self._drag_ctx = None
self._drag_parts = None
self._limits_warned = False
def name(self):
return "Kindred (Newton-Raphson)"
def supported_joints(self):
return list(_SUPPORTED)
# ── Static solve ────────────────────────────────────────────────
def solve(self, ctx):
params = ParamTable()
bodies = {} # part_id -> RigidBody
system = _build_system(ctx)
# 1. Build entities from parts
for part in ctx.parts:
pos = tuple(part.placement.position)
quat = tuple(part.placement.quaternion) # (w, x, y, z)
body = RigidBody(
part.id,
params,
position=pos,
quaternion=quat,
grounded=part.grounded,
)
bodies[part.id] = body
# Warn once per solver instance if any constraints have limits
if not self._limits_warned:
for c in ctx.constraints:
if c.limits:
import logging
# 2. Build constraint residuals (track index mapping for decomposition)
all_residuals = []
constraint_objs = []
constraint_indices = [] # parallel to constraint_objs: index in ctx.constraints
logging.getLogger(__name__).warning(
"Joint limits on '%s' ignored "
"(not yet supported by Kindred solver)",
c.id,
)
self._limits_warned = True
break
for idx, c in enumerate(ctx.constraints):
if not c.activated:
continue
body_i = bodies.get(c.part_i)
body_j = bodies.get(c.part_j)
if body_i is None or body_j is None:
continue
# Solution preference: half-spaces, weight vector
half_spaces = compute_half_spaces(
system.constraint_objs,
system.constraint_indices,
system.params,
)
weight_vec = build_weight_vector(system.params)
marker_i_pos = tuple(c.marker_i.position)
marker_j_pos = tuple(c.marker_j.position)
if half_spaces:
post_step_fn = lambda p: apply_half_space_correction(p, half_spaces)
else:
post_step_fn = None
obj = _build_constraint(
c.type,
body_i,
marker_i_pos,
body_j,
marker_j_pos,
c.marker_i,
c.marker_j,
c.params,
)
if obj is None:
continue
constraint_objs.append(obj)
constraint_indices.append(idx)
all_residuals.extend(obj.residuals())
# Pre-passes on full system
residuals = substitution_pass(system.all_residuals, system.params)
residuals = single_equation_pass(residuals, system.params)
# 3. Add quaternion normalization residuals for non-grounded bodies
quat_groups = []
for body in bodies.values():
if not body.grounded:
all_residuals.append(body.quat_norm_residual())
quat_groups.append(body.quat_param_names())
# 4. Pre-passes on full system
all_residuals = substitution_pass(all_residuals, params)
all_residuals = single_equation_pass(all_residuals, params)
# 5. Solve (decomposed for large assemblies, monolithic for small)
n_free_bodies = sum(1 for b in bodies.values() if not b.grounded)
# Solve (decomposed for large assemblies, monolithic for small)
n_free_bodies = sum(1 for b in system.bodies.values() if not b.grounded)
if n_free_bodies >= _DECOMPOSE_THRESHOLD:
grounded_ids = {pid for pid, b in bodies.items() if b.grounded}
grounded_ids = {pid for pid, b in system.bodies.items() if b.grounded}
clusters = decompose(ctx.constraints, grounded_ids)
if len(clusters) > 1:
converged = solve_decomposed(
clusters,
bodies,
constraint_objs,
constraint_indices,
params,
system.bodies,
system.constraint_objs,
system.constraint_indices,
system.params,
)
else:
converged = _monolithic_solve(
all_residuals,
params,
quat_groups,
residuals,
system.params,
system.quat_groups,
post_step=post_step_fn,
weight_vector=weight_vec,
)
else:
converged = _monolithic_solve(all_residuals, params, quat_groups)
converged = _monolithic_solve(
residuals,
system.params,
system.quat_groups,
post_step=post_step_fn,
weight_vector=weight_vec,
)
# 6. DOF
dof = count_dof(all_residuals, params)
# DOF
dof = count_dof(residuals, system.params)
# 7. Build result
# Build result
result = kcsolve.SolveResult()
result.status = (
kcsolve.SolveStatus.Success if converged else kcsolve.SolveStatus.Failed
)
result.dof = dof
env = params.get_env()
placements = []
for body in bodies.values():
if body.grounded:
continue
pr = kcsolve.SolveResult.PartResult()
pr.id = body.part_id
pr.placement = kcsolve.Transform()
pr.placement.position = list(body.extract_position(env))
pr.placement.quaternion = list(body.extract_quaternion(env))
placements.append(pr)
# Diagnostics on failure
if not converged:
result.diagnostics = _run_diagnostics(
residuals,
system.params,
system.residual_ranges,
ctx,
)
result.placements = placements
result.placements = _extract_placements(system.params, system.bodies)
return result
# ── Incremental update ──────────────────────────────────────────
# The base class default (delegates to solve()) is correct here:
# solve() uses current placements as initial guess, so small
# parameter changes converge quickly without special handling.
# ── Interactive drag ────────────────────────────────────────────
def pre_drag(self, ctx, drag_parts):
self._drag_ctx = ctx
self._drag_parts = set(drag_parts)
return self.solve(ctx)
def drag_step(self, drag_placements):
ctx = self._drag_ctx
if ctx is None:
return kcsolve.SolveResult()
for pr in drag_placements:
for part in ctx.parts:
if part.id == pr.id:
part.placement = pr.placement
break
return self.solve(ctx)
def post_drag(self):
self._drag_ctx = None
self._drag_parts = None
# ── Diagnostics ─────────────────────────────────────────────────
def diagnose(self, ctx):
system = _build_system(ctx)
residuals = substitution_pass(system.all_residuals, system.params)
return _run_diagnostics(
residuals,
system.params,
system.residual_ranges,
ctx,
)
def is_deterministic(self):
return True
def _monolithic_solve(all_residuals, params, quat_groups):
class _System:
"""Intermediate representation of a built constraint system."""
__slots__ = (
"params",
"bodies",
"constraint_objs",
"constraint_indices",
"all_residuals",
"residual_ranges",
"quat_groups",
)
def _build_system(ctx):
"""Build the solver's internal representation from a SolveContext.
Returns a _System with params, bodies, constraint objects,
residuals, residual-to-constraint mapping, and quaternion groups.
"""
system = _System()
params = ParamTable()
bodies = {} # part_id -> RigidBody
# 1. Build entities from parts
for part in ctx.parts:
pos = tuple(part.placement.position)
quat = tuple(part.placement.quaternion) # (w, x, y, z)
body = RigidBody(
part.id,
params,
position=pos,
quaternion=quat,
grounded=part.grounded,
)
bodies[part.id] = body
# 2. Build constraint residuals (track index mapping for decomposition)
all_residuals = []
constraint_objs = []
constraint_indices = [] # parallel to constraint_objs: index in ctx.constraints
for idx, c in enumerate(ctx.constraints):
if not c.activated:
continue
body_i = bodies.get(c.part_i)
body_j = bodies.get(c.part_j)
if body_i is None or body_j is None:
continue
marker_i_pos = tuple(c.marker_i.position)
marker_j_pos = tuple(c.marker_j.position)
obj = _build_constraint(
c.type,
body_i,
marker_i_pos,
body_j,
marker_j_pos,
c.marker_i,
c.marker_j,
c.params,
)
if obj is None:
continue
constraint_objs.append(obj)
constraint_indices.append(idx)
all_residuals.extend(obj.residuals())
# 3. Build residual-to-constraint mapping
residual_ranges = [] # (start_row, end_row, constraint_idx)
row = 0
for i, obj in enumerate(constraint_objs):
n = len(obj.residuals())
residual_ranges.append((row, row + n, constraint_indices[i]))
row += n
# 4. Add quaternion normalization residuals for non-grounded bodies
quat_groups = []
for body in bodies.values():
if not body.grounded:
all_residuals.append(body.quat_norm_residual())
quat_groups.append(body.quat_param_names())
system.params = params
system.bodies = bodies
system.constraint_objs = constraint_objs
system.constraint_indices = constraint_indices
system.all_residuals = all_residuals
system.residual_ranges = residual_ranges
system.quat_groups = quat_groups
return system
def _run_diagnostics(residuals, params, residual_ranges, ctx):
"""Run overconstrained detection and return kcsolve diagnostics."""
diagnostics = []
if not hasattr(kcsolve, "ConstraintDiagnostic"):
return diagnostics
diags = find_overconstrained(residuals, params, residual_ranges)
for d in diags:
cd = kcsolve.ConstraintDiagnostic()
cd.constraint_id = ctx.constraints[d.constraint_index].id
cd.kind = (
kcsolve.DiagnosticKind.Redundant
if d.kind == "redundant"
else kcsolve.DiagnosticKind.Conflicting
)
cd.detail = d.detail
diagnostics.append(cd)
return diagnostics
def _extract_placements(params, bodies):
"""Extract solved placements from the parameter table."""
env = params.get_env()
placements = []
for body in bodies.values():
if body.grounded:
continue
pr = kcsolve.SolveResult.PartResult()
pr.id = body.part_id
pr.placement = kcsolve.Transform()
pr.placement.position = list(body.extract_position(env))
pr.placement.quaternion = list(body.extract_quaternion(env))
placements.append(pr)
return placements
def _monolithic_solve(
all_residuals, params, quat_groups, post_step=None, weight_vector=None
):
"""Newton-Raphson solve with BFGS fallback on the full system."""
converged = newton_solve(
all_residuals,
@@ -201,6 +365,8 @@ def _monolithic_solve(all_residuals, params, quat_groups):
quat_groups=quat_groups,
max_iter=100,
tol=1e-10,
post_step=post_step,
weight_vector=weight_vector,
)
if not converged:
converged = bfgs_solve(
@@ -209,6 +375,7 @@ def _monolithic_solve(all_residuals, params, quat_groups):
quat_groups=quat_groups,
max_iter=200,
tol=1e-10,
weight_vector=weight_vector,
)
return converged