"""KindredSolver — IKCSolver implementation bridging KCSolve to the expression-based Newton-Raphson solver.""" from __future__ import annotations import logging import math import time import kcsolve log = logging.getLogger(__name__) from .bfgs import bfgs_solve from .constraints import ( AngleConstraint, BallConstraint, CamConstraint, CoincidentConstraint, ConcentricConstraint, ConstraintBase, CylindricalConstraint, DistanceCylSphConstraint, DistancePointPointConstraint, FixedConstraint, GearConstraint, LineInPlaneConstraint, ParallelConstraint, PerpendicularConstraint, PlanarConstraint, PointInPlaneConstraint, PointOnLineConstraint, RackPinionConstraint, RevoluteConstraint, ScrewConstraint, SliderConstraint, SlotConstraint, TangentConstraint, UniversalConstraint, ) from .decompose import decompose, solve_decomposed from .diagnostics import find_overconstrained from .dof import count_dof from .entities import RigidBody from .newton import newton_solve from .params import ParamTable from .preference import ( apply_half_space_correction, build_weight_vector, compute_half_spaces, ) from .prepass import single_equation_pass, substitution_pass # Assemblies with fewer free bodies than this use the monolithic path. 
_DECOMPOSE_THRESHOLD = 8

# Iteration limits / tolerance shared by every Newton + BFGS solve below.
_NEWTON_MAX_ITER = 100
_BFGS_MAX_ITER = 200
_SOLVE_TOL = 1e-10

# Largest rotation (radians) a body may make in one drag step before
# _enforce_quat_continuity() resets it; matches the C++ 91 degree threshold.
_MAX_QUAT_JUMP = 91.0 * math.pi / 180.0

# All BaseJointKind values advertised via supported_joints().  Cam, Slot and
# DistanceCylSph are additionally accepted by _build_constraint() as stubs
# that produce no residuals, but are not advertised here.
_SUPPORTED = {
    # Phase 1
    kcsolve.BaseJointKind.Coincident,
    kcsolve.BaseJointKind.DistancePointPoint,
    kcsolve.BaseJointKind.Fixed,
    # Phase 2: point constraints
    kcsolve.BaseJointKind.PointOnLine,
    kcsolve.BaseJointKind.PointInPlane,
    # Phase 2: orientation
    kcsolve.BaseJointKind.Parallel,
    kcsolve.BaseJointKind.Perpendicular,
    kcsolve.BaseJointKind.Angle,
    # Phase 2: axis/surface
    kcsolve.BaseJointKind.Concentric,
    kcsolve.BaseJointKind.Tangent,
    kcsolve.BaseJointKind.Planar,
    kcsolve.BaseJointKind.LineInPlane,
    # Phase 2: kinematic joints
    kcsolve.BaseJointKind.Ball,
    kcsolve.BaseJointKind.Revolute,
    kcsolve.BaseJointKind.Cylindrical,
    kcsolve.BaseJointKind.Slider,
    kcsolve.BaseJointKind.Screw,
    kcsolve.BaseJointKind.Universal,
    # Phase 2: mechanical
    kcsolve.BaseJointKind.Gear,
    kcsolve.BaseJointKind.RackPinion,
}


class KindredSolver(kcsolve.IKCSolver):
    """Expression-based Newton-Raphson constraint solver.

    Implements the kcsolve.IKCSolver interface: a static solve(), the
    interactive drag protocol (pre_drag / drag_step / post_drag) and
    diagnose() for over-constraint reporting.
    """

    def __init__(self):
        super().__init__()
        # Interactive-drag state: set by pre_drag(), cleared by post_drag().
        self._drag_ctx = None
        self._drag_parts = None
        self._drag_cache = None
        self._drag_step_count = 0
        # Joint limits are not supported; warn about them once per instance.
        self._limits_warned = False

    def name(self):
        return "Kindred (Newton-Raphson)"

    def supported_joints(self):
        return list(_SUPPORTED)

    # ── Static solve ────────────────────────────────────────────────

    def solve(self, ctx):
        """Solve *ctx* statically and return a kcsolve.SolveResult.

        Current part placements are used as the initial guess.  Assemblies
        with at least ``_DECOMPOSE_THRESHOLD`` free bodies are decomposed
        into clusters; if that yields more than one cluster they are solved
        via solve_decomposed(), otherwise the whole system is solved at once.
        On failure, over-constraint diagnostics are attached to the result.
        """
        t0 = time.perf_counter()
        n_parts = len(ctx.parts)
        n_constraints = len(ctx.constraints)
        n_grounded = sum(1 for p in ctx.parts if p.grounded)
        log.info(
            "solve: %d parts (%d grounded), %d constraints",
            n_parts,
            n_grounded,
            n_constraints,
        )

        system = _build_system(ctx)

        n_free_bodies = sum(1 for b in system.bodies.values() if not b.grounded)
        n_residuals = len(system.all_residuals)
        n_free_params = len(system.params.free_names())
        log.info(
            "solve: system built — %d free bodies, %d residuals, %d free params",
            n_free_bodies,
            n_residuals,
            n_free_params,
        )

        # Warn once per solver instance if any constraints have limits
        if not self._limits_warned:
            for c in ctx.constraints:
                if c.limits:
                    log.warning(
                        "Joint limits on '%s' ignored (not yet supported by Kindred solver)",
                        c.id,
                    )
                    self._limits_warned = True
                    break

        # Solution preference: half-spaces, weight vector
        half_spaces = compute_half_spaces(
            system.constraint_objs,
            system.constraint_indices,
            system.params,
        )
        post_step_fn = _make_post_step(half_spaces)

        # Pre-passes on the full system.  The substitution-only list is kept
        # because its rows still line up with system.residual_ranges, while
        # single_equation_pass may remove rows.  It gets a copy so an
        # in-place implementation cannot disturb the aligned list.
        subst_residuals = substitution_pass(system.all_residuals, system.params)
        residuals = single_equation_pass(list(subst_residuals), system.params)

        # Build weight vector *after* pre-passes so its length matches the
        # remaining free parameters (single_equation_pass may fix some).
        weight_vec = build_weight_vector(system.params)

        # Decide between the decomposed and the monolithic path.
        clusters = None
        if n_free_bodies >= _DECOMPOSE_THRESHOLD:
            grounded_ids = {pid for pid, b in system.bodies.items() if b.grounded}
            clusters = decompose(ctx.constraints, grounded_ids)
            log.info(
                "solve: decomposed into %d cluster(s) (%d free bodies >= threshold %d)",
                len(clusters),
                n_free_bodies,
                _DECOMPOSE_THRESHOLD,
            )
        else:
            log.debug(
                "solve: monolithic path (%d free bodies < threshold %d)",
                n_free_bodies,
                _DECOMPOSE_THRESHOLD,
            )

        jac_exprs = None  # only the monolithic path produces a Jacobian
        if clusters is not None and len(clusters) > 1:
            converged = solve_decomposed(
                clusters,
                system.bodies,
                system.constraint_objs,
                system.constraint_indices,
                system.params,
            )
        else:
            converged, jac_exprs = _monolithic_solve(
                residuals,
                system.params,
                system.quat_groups,
                post_step=post_step_fn,
                weight_vector=weight_vec,
            )

        # DOF
        dof = count_dof(residuals, system.params, jac_exprs=jac_exprs)

        # Build result
        result = kcsolve.SolveResult()
        result.status = (
            kcsolve.SolveStatus.Success if converged else kcsolve.SolveStatus.Failed
        )
        result.dof = dof

        # Diagnostics on failure
        if not converged:
            if len(residuals) == len(subst_residuals):
                diag_residuals, diag_jac = residuals, jac_exprs
            else:
                # single_equation_pass removed rows, so row indices of
                # `residuals` no longer match system.residual_ranges.  Use
                # the aligned list; jac_exprs was built for the reduced list,
                # so let the diagnostics build their own Jacobian.
                diag_residuals, diag_jac = subst_residuals, None
            result.diagnostics = _run_diagnostics(
                diag_residuals,
                system.params,
                system.residual_ranges,
                ctx,
                jac_exprs=diag_jac,
            )

        result.placements = _extract_placements(system.params, system.bodies)

        elapsed = (time.perf_counter() - t0) * 1000
        log.info(
            "solve: %s in %.1f ms — dof=%d, %d placements",
            "converged" if converged else "FAILED",
            elapsed,
            dof,
            len(result.placements),
        )
        if not converged and result.diagnostics:
            for d in result.diagnostics:
                log.warning(
                    " diagnostic: [%s] %s — %s", d.kind, d.constraint_id, d.detail
                )

        return result

    # ── Incremental update ──────────────────────────────────────────
    # The base class default (delegates to solve()) is correct here:
    # solve() uses current placements as initial guess, so small
    # parameter changes converge quickly without special handling.

    # ── Interactive drag ────────────────────────────────────────────

    def pre_drag(self, ctx, drag_parts):
        """Start an interactive drag of *drag_parts* and solve once.

        The system, symbolic Jacobian and compiled evaluator are built a
        single time and cached in a _DragCache, so drag_step() only has to
        update the dragged parts' parameters and re-solve.
        """
        log.info("pre_drag: drag_parts=%s", drag_parts)
        self._drag_ctx = ctx
        self._drag_parts = set(drag_parts)
        self._drag_step_count = 0

        # Build the system once and cache everything for drag_step() reuse.
        t0 = time.perf_counter()
        system = _build_system(ctx)

        half_spaces = compute_half_spaces(
            system.constraint_objs,
            system.constraint_indices,
            system.params,
        )
        post_step_fn = _make_post_step(half_spaces)

        residuals = substitution_pass(system.all_residuals, system.params)
        # NOTE: single_equation_pass is intentionally skipped for drag.
        # It permanently fixes variables and removes residuals from the
        # list.  During drag the dragged part's parameters change each
        # frame, which can invalidate those analytic solutions and cause
        # constraints (e.g. Planar distance=0) to stop being enforced.
        # The substitution pass alone is safe because it only replaces
        # genuinely grounded parameters with constants.

        # Build the weight vector *after* the substitution pass so its
        # length matches the parameters that remain free.
        weight_vec = build_weight_vector(system.params)

        # Build symbolic Jacobian + compile once
        jac_exprs, compiled_eval = _compile_jacobian(residuals, system.params)

        # Initial solve
        converged = _newton_then_bfgs(
            residuals,
            system.params,
            system.quat_groups,
            post_step=post_step_fn,
            weight_vector=weight_vec,
            jac_exprs=jac_exprs,
            compiled_eval=compiled_eval,
        )

        # Cache for drag_step() reuse
        cache = _DragCache()
        cache.system = system
        cache.residuals = residuals
        cache.jac_exprs = jac_exprs
        cache.compiled_eval = compiled_eval
        cache.half_spaces = half_spaces
        cache.weight_vec = weight_vec
        cache.post_step_fn = post_step_fn

        # Snapshot solved quaternions for continuity enforcement in drag_step()
        env = system.params.get_env()
        cache.pre_step_quats = {
            body.part_id: body.extract_quaternion(env)
            for body in system.bodies.values()
            if not body.grounded
        }
        self._drag_cache = cache

        # Build result
        dof = count_dof(residuals, system.params, jac_exprs=jac_exprs)
        result = kcsolve.SolveResult()
        result.status = (
            kcsolve.SolveStatus.Success if converged else kcsolve.SolveStatus.Failed
        )
        result.dof = dof
        result.placements = _extract_placements(system.params, system.bodies)

        elapsed = (time.perf_counter() - t0) * 1000
        log.info(
            "pre_drag: initial solve %s in %.1f ms — dof=%d",
            "converged" if converged else "FAILED",
            elapsed,
            dof,
        )
        return result

    def drag_step(self, drag_placements):
        """Re-solve after the dragged parts moved to *drag_placements*.

        Reuses the artifacts cached by pre_drag(); without a cache it falls
        back to a full solve().  DOF is not computed here (reported as -1).
        """
        ctx = self._drag_ctx
        if ctx is None:
            log.warning("drag_step: no drag context (pre_drag not called?)")
            return kcsolve.SolveResult()

        self._drag_step_count += 1

        # Update dragged part placements in ctx (for caller consistency)
        for pr in drag_placements:
            for part in ctx.parts:
                if part.id == pr.id:
                    part.placement = pr.placement
                    break

        cache = self._drag_cache
        if cache is None:
            # Fallback: no cache, do a full solve
            log.debug(
                "drag_step #%d: no cache, falling back to full solve",
                self._drag_step_count,
            )
            return self.solve(ctx)

        t0 = time.perf_counter()
        params = cache.system.params

        # Update only the dragged parts' 7 parameter values
        for pr in drag_placements:
            pfx = pr.id + "/"
            pos = pr.placement.position
            quat = pr.placement.quaternion  # (w, x, y, z)
            params.set_value(pfx + "tx", pos[0])
            params.set_value(pfx + "ty", pos[1])
            params.set_value(pfx + "tz", pos[2])
            _set_quat(params, pfx, quat[0], quat[1], quat[2], quat[3])

        # Solve with cached artifacts — no rebuild
        converged = _newton_then_bfgs(
            cache.residuals,
            params,
            cache.system.quat_groups,
            post_step=cache.post_step_fn,
            weight_vector=cache.weight_vec,
            jac_exprs=cache.jac_exprs,
            compiled_eval=cache.compiled_eval,
        )

        # Quaternion continuity: ensure solved quaternions stay in the same
        # hemisphere as the previous step.  q and -q encode the same
        # rotation, but the C++ side measures the angle between the old and
        # new quaternion; in the -q branch that shows up as a near-360°
        # flip and gets rejected.
        _enforce_quat_continuity(
            params,
            cache.system.bodies,
            cache.pre_step_quats,
            self._drag_parts or set(),
        )

        # Update the stored quaternions for the next drag step
        env = params.get_env()
        for body in cache.system.bodies.values():
            if not body.grounded:
                cache.pre_step_quats[body.part_id] = body.extract_quaternion(env)

        result = kcsolve.SolveResult()
        result.status = (
            kcsolve.SolveStatus.Success if converged else kcsolve.SolveStatus.Failed
        )
        result.dof = -1  # skip DOF counting during drag for speed
        result.placements = _extract_placements(params, cache.system.bodies)

        elapsed = (time.perf_counter() - t0) * 1000
        if not converged:
            log.warning(
                "drag_step #%d: solve FAILED in %.1f ms",
                self._drag_step_count,
                elapsed,
            )
        else:
            log.debug(
                "drag_step #%d: ok in %.1f ms",
                self._drag_step_count,
                elapsed,
            )
        return result

    def post_drag(self):
        """End the interactive drag and drop all cached drag state."""
        log.info("post_drag: completed after %d drag steps", self._drag_step_count)
        self._drag_ctx = None
        self._drag_parts = None
        self._drag_step_count = 0
        self._drag_cache = None

    # ── Diagnostics ─────────────────────────────────────────────────

    def diagnose(self, ctx):
        """Return redundancy/conflict diagnostics for *ctx* without solving."""
        system = _build_system(ctx)
        residuals = substitution_pass(system.all_residuals, system.params)
        return _run_diagnostics(
            residuals,
            system.params,
            system.residual_ranges,
            ctx,
        )

    def is_deterministic(self):
        return True


class _DragCache:
    """Cached artifacts from pre_drag() reused across drag_step() calls.

    During interactive drag the constraint topology is invariant — only
    the dragged part's parameter values change.  Caching the built
    system, symbolic Jacobian, and compiled evaluator eliminates the
    expensive rebuild overhead (~150 ms) on every frame.
    """

    __slots__ = (
        "system",  # _System — owns ParamTable + Expr trees
        "residuals",  # list[Expr] — after substitution pass only (no single-equation pass)
        "jac_exprs",  # list[list[Expr]] — symbolic Jacobian
        "compiled_eval",  # Callable or None
        "half_spaces",  # list[HalfSpace]
        "weight_vec",  # ndarray or None
        "post_step_fn",  # Callable or None
        "pre_step_quats",  # dict[str, tuple] — last-accepted quaternions per body
    )


class _System:
    """Intermediate representation of a built constraint system."""

    __slots__ = (
        "params",
        "bodies",
        "constraint_objs",
        "constraint_indices",
        "all_residuals",
        "residual_ranges",
        "quat_groups",
    )


def _make_post_step(half_spaces):
    """Return a Newton post-step callback enforcing *half_spaces*, or None.

    A falsy *half_spaces* yields None so the solver skips the hook.
    """
    if not half_spaces:
        return None

    def post_step(p):
        return apply_half_space_correction(p, half_spaces)

    return post_step


def _compile_jacobian(residuals, params):
    """Build the symbolic Jacobian of *residuals* and try to compile it.

    Returns ``(jac_exprs, compiled_eval)`` where ``jac_exprs[i][k]`` is the
    simplified derivative of residual *i* with respect to the k-th name in
    ``params.free_names()``, and ``compiled_eval`` is the result of
    ``try_compile_system`` (may be None).
    """
    from .codegen import try_compile_system

    free = params.free_names()
    jac_exprs = [[r.diff(name).simplify() for name in free] for r in residuals]
    compiled_eval = try_compile_system(residuals, jac_exprs, len(residuals), len(free))
    return jac_exprs, compiled_eval


def _newton_then_bfgs(
    residuals,
    params,
    quat_groups,
    *,
    post_step,
    weight_vector,
    jac_exprs,
    compiled_eval,
):
    """Run Newton-Raphson, falling back to BFGS if it does not converge.

    The *post_step* hook is only passed to the Newton stage.  Returns the
    convergence flag of the last stage that ran.
    """
    converged = newton_solve(
        residuals,
        params,
        quat_groups=quat_groups,
        max_iter=_NEWTON_MAX_ITER,
        tol=_SOLVE_TOL,
        post_step=post_step,
        weight_vector=weight_vector,
        jac_exprs=jac_exprs,
        compiled_eval=compiled_eval,
    )
    if not converged:
        converged = bfgs_solve(
            residuals,
            params,
            quat_groups=quat_groups,
            max_iter=_BFGS_MAX_ITER,
            tol=_SOLVE_TOL,
            weight_vector=weight_vector,
            jac_exprs=jac_exprs,
            compiled_eval=compiled_eval,
        )
    return converged


def _set_quat(params, pfx, w, x, y, z):
    """Write a (w, x, y, z) quaternion into the ``<pfx>qw..qz`` parameters."""
    params.set_value(pfx + "qw", w)
    params.set_value(pfx + "qx", x)
    params.set_value(pfx + "qy", y)
    params.set_value(pfx + "qz", z)


def _enforce_quat_continuity(
    params: ParamTable,
    bodies: dict,
    pre_step_quats: dict,
    dragged_ids: set,
) -> None:
    """Keep solved quaternions close to the previous drag step's.

    Applied to every non-grounded body that has an entry in
    *pre_step_quats* (dragged parts included, since Newton re-solves their
    parameters too):

    1. **Hemisphere check**: if dot(q_prev, q_solved) < 0, q_solved is
       negated in place.  q and -q are the same rotation, so this only
       fixes the sign branch.
    2. **Rotation angle check**: the angle of q_solved * conj(q_prev) is
       computed as 2*acos(w), the same formula as the C++ validator.  If it
       exceeds ``_MAX_QUAT_JUMP`` (91°), the body's quaternion is reset to
       q_prev.

    NOTE(review): the reset happens after the solve has finished, so the
    reset orientation is returned as-is and is not re-checked against the
    constraints in this step; the caller's convergence status is not
    updated.  It does become the starting point for the next drag_step().

    *dragged_ids* is accepted for interface compatibility but is not
    consulted: all non-grounded bodies are treated the same way.
    """
    for body in bodies.values():
        if body.grounded:
            continue
        prev = pre_step_quats.get(body.part_id)
        if prev is None:
            continue

        pfx = body.part_id + "/"
        qw = params.get_value(pfx + "qw")
        qx = params.get_value(pfx + "qx")
        qy = params.get_value(pfx + "qy")
        qz = params.get_value(pfx + "qz")
        pw, px, py, pz = prev

        # Level 1: hemisphere check (standard SLERP short-arc correction)
        if pw * qw + px * qx + py * qy + pz * qz < 0.0:
            qw, qx, qy, qz = -qw, -qx, -qy, -qz
            _set_quat(params, pfx, qw, qx, qy, qz)

        # Level 2: rotation angle check (catches branch jumps beyond sign flip)
        # Relative quaternion: q_rel = q_new * conj(q_prev)
        rel_w = qw * pw + qx * px + qy * py + qz * pz
        rel_x = qx * pw - qw * px - qy * pz + qz * py
        rel_y = qy * pw - qw * py - qz * px + qx * pz
        rel_z = qz * pw - qw * pz - qx * py + qy * px

        rel_norm = math.sqrt(
            rel_w * rel_w + rel_x * rel_x + rel_y * rel_y + rel_z * rel_z
        )
        if rel_norm > 1e-15:
            rel_w /= rel_norm
        rel_w = max(-1.0, min(1.0, rel_w))

        # C++ evaluateVector: angle = 2 * acos(w); |w| == 1 means no rotation.
        angle = 2.0 * math.acos(rel_w) if -1.0 < rel_w < 1.0 else 0.0

        if abs(angle) > _MAX_QUAT_JUMP:
            # The solver jumped to a different constraint branch; fall back
            # to the previous step's orientation for this body.
            log.debug(
                "_enforce_quat_continuity: %s jumped %.1f deg, "
                "resetting to previous quaternion",
                body.part_id,
                math.degrees(angle),
            )
            _set_quat(params, pfx, pw, px, py, pz)


def _build_system(ctx):
    """Build the solver's internal representation from a SolveContext.

    Returns a _System with params, bodies, constraint objects, residuals,
    residual-to-constraint mapping, and quaternion groups.  Constraint
    residuals occupy the first rows of ``all_residuals`` (ranges recorded
    in ``residual_ranges`` as ``(start_row, end_row, ctx_constraint_index)``);
    one quaternion-normalization residual per free body follows them.
    """
    system = _System()
    params = ParamTable()
    bodies = {}  # part_id -> RigidBody

    # 1. Build entities from parts
    for part in ctx.parts:
        bodies[part.id] = RigidBody(
            part.id,
            params,
            position=tuple(part.placement.position),
            quaternion=tuple(part.placement.quaternion),  # (w, x, y, z)
            grounded=part.grounded,
        )

    # 2. Build constraint residuals, recording the rows each constraint
    #    occupies so diagnostics can map rows back to ctx.constraints.
    all_residuals = []
    constraint_objs = []
    constraint_indices = []  # parallel to constraint_objs: index in ctx.constraints
    residual_ranges = []  # (start_row, end_row, constraint_idx)
    skipped_inactive = 0
    skipped_missing_body = 0
    skipped_unsupported = 0

    for idx, c in enumerate(ctx.constraints):
        if not c.activated:
            skipped_inactive += 1
            continue

        body_i = bodies.get(c.part_i)
        body_j = bodies.get(c.part_j)
        if body_i is None or body_j is None:
            skipped_missing_body += 1
            log.debug(
                "_build_system: constraint[%d] %s skipped — missing body (%s or %s)",
                idx,
                c.id,
                c.part_i,
                c.part_j,
            )
            continue

        obj = _build_constraint(
            c.type,
            body_i,
            tuple(c.marker_i.position),
            body_j,
            tuple(c.marker_j.position),
            c.marker_i,
            c.marker_j,
            c.params,
        )
        if obj is None:
            skipped_unsupported += 1
            log.debug(
                "_build_system: constraint[%d] %s type=%s — unsupported, skipped",
                idx,
                c.id,
                c.type,
            )
            continue

        # Call residuals() once: it builds expression trees.
        start = len(all_residuals)
        all_residuals.extend(obj.residuals())
        constraint_objs.append(obj)
        constraint_indices.append(idx)
        residual_ranges.append((start, len(all_residuals), idx))

    if skipped_inactive or skipped_missing_body or skipped_unsupported:
        log.debug(
            "_build_system: skipped constraints — %d inactive, %d missing body, %d unsupported",
            skipped_inactive,
            skipped_missing_body,
            skipped_unsupported,
        )

    # 3. Add quaternion normalization residuals for non-grounded bodies
    quat_groups = []
    for body in bodies.values():
        if not body.grounded:
            all_residuals.append(body.quat_norm_residual())
            quat_groups.append(body.quat_param_names())

    system.params = params
    system.bodies = bodies
    system.constraint_objs = constraint_objs
    system.constraint_indices = constraint_indices
    system.all_residuals = all_residuals
    system.residual_ranges = residual_ranges
    system.quat_groups = quat_groups
    return system


def _run_diagnostics(residuals, params, residual_ranges, ctx, jac_exprs=None):
    """Run overconstrained detection and return kcsolve diagnostics.

    Returns an empty list when the kcsolve build lacks ConstraintDiagnostic.
    *residual_ranges* must index rows of *residuals*.
    """
    diagnostics = []
    if not hasattr(kcsolve, "ConstraintDiagnostic"):
        return diagnostics

    diags = find_overconstrained(
        residuals, params, residual_ranges, jac_exprs=jac_exprs
    )
    for d in diags:
        cd = kcsolve.ConstraintDiagnostic()
        cd.constraint_id = ctx.constraints[d.constraint_index].id
        cd.kind = (
            kcsolve.DiagnosticKind.Redundant
            if d.kind == "redundant"
            else kcsolve.DiagnosticKind.Conflicting
        )
        cd.detail = d.detail
        diagnostics.append(cd)
    return diagnostics


def _extract_placements(params, bodies):
    """Extract solved placements of non-grounded bodies from the parameter table."""
    env = params.get_env()
    placements = []
    for body in bodies.values():
        if body.grounded:
            continue
        pr = kcsolve.SolveResult.PartResult()
        pr.id = body.part_id
        pr.placement = kcsolve.Transform()
        pr.placement.position = list(body.extract_position(env))
        pr.placement.quaternion = list(body.extract_quaternion(env))
        placements.append(pr)
    return placements


def _monolithic_solve(
    all_residuals, params, quat_groups, post_step=None, weight_vector=None
):
    """Newton-Raphson solve with BFGS fallback on the full system.

    Returns ``(converged, jac_exprs)`` so the caller can reuse the
    symbolic Jacobian for DOF counting / diagnostics.
    """
    # Build symbolic Jacobian + compile once
    jac_exprs, compiled_eval = _compile_jacobian(all_residuals, params)

    t0 = time.perf_counter()
    converged = newton_solve(
        all_residuals,
        params,
        quat_groups=quat_groups,
        max_iter=_NEWTON_MAX_ITER,
        tol=_SOLVE_TOL,
        post_step=post_step,
        weight_vector=weight_vector,
        jac_exprs=jac_exprs,
        compiled_eval=compiled_eval,
    )
    nr_ms = (time.perf_counter() - t0) * 1000
    if converged:
        log.debug("_monolithic_solve: Newton-Raphson converged (%.1f ms)", nr_ms)
        return converged, jac_exprs

    log.info(
        "_monolithic_solve: Newton-Raphson failed (%.1f ms), trying BFGS", nr_ms
    )
    t1 = time.perf_counter()
    converged = bfgs_solve(
        all_residuals,
        params,
        quat_groups=quat_groups,
        max_iter=_BFGS_MAX_ITER,
        tol=_SOLVE_TOL,
        weight_vector=weight_vector,
        jac_exprs=jac_exprs,
        compiled_eval=compiled_eval,
    )
    bfgs_ms = (time.perf_counter() - t1) * 1000
    if converged:
        log.info("_monolithic_solve: BFGS converged (%.1f ms)", bfgs_ms)
    else:
        log.warning("_monolithic_solve: BFGS also failed (%.1f ms)", bfgs_ms)
    return converged, jac_exprs


def _param(c_params, index, default):
    """Return ``c_params[index]``, or *default* when absent (None or too short)."""
    if c_params is None or len(c_params) <= index:
        return default
    return c_params[index]


def _build_constraint(
    kind,
    body_i,
    marker_i_pos,
    body_j,
    marker_j_pos,
    marker_i,
    marker_j,
    c_params,
) -> ConstraintBase | None:
    """Create the appropriate constraint object from a BaseJointKind.

    Returns None for kinds this solver does not handle.  Missing entries in
    *c_params* (including ``c_params is None``) fall back to per-kind
    defaults.
    """
    marker_i_quat = tuple(marker_i.quaternion)
    marker_j_quat = tuple(marker_j.quaternion)

    # -- Phase 1 constraints --------------------------------------------------

    if kind == kcsolve.BaseJointKind.Coincident:
        return CoincidentConstraint(body_i, marker_i_pos, body_j, marker_j_pos)

    if kind == kcsolve.BaseJointKind.DistancePointPoint:
        distance = _param(c_params, 0, 0.0)
        if distance == 0.0:
            # Distance=0 is point coincidence.  Use CoincidentConstraint
            # (3 linear residuals) instead of the squared form which has
            # a degenerate Jacobian when the constraint is satisfied.
            return CoincidentConstraint(body_i, marker_i_pos, body_j, marker_j_pos)
        return DistancePointPointConstraint(
            body_i, marker_i_pos, body_j, marker_j_pos, distance
        )

    if kind == kcsolve.BaseJointKind.Fixed:
        return FixedConstraint(
            body_i, marker_i_pos, marker_i_quat, body_j, marker_j_pos, marker_j_quat
        )

    # -- Phase 2: point constraints -------------------------------------------

    if kind == kcsolve.BaseJointKind.PointOnLine:
        return PointOnLineConstraint(
            body_i, marker_i_pos, marker_i_quat, body_j, marker_j_pos, marker_j_quat
        )

    if kind == kcsolve.BaseJointKind.PointInPlane:
        return PointInPlaneConstraint(
            body_i,
            marker_i_pos,
            marker_i_quat,
            body_j,
            marker_j_pos,
            marker_j_quat,
            offset=_param(c_params, 0, 0.0),
        )

    # -- Phase 2: orientation constraints -------------------------------------

    if kind == kcsolve.BaseJointKind.Parallel:
        return ParallelConstraint(body_i, marker_i_quat, body_j, marker_j_quat)

    if kind == kcsolve.BaseJointKind.Perpendicular:
        return PerpendicularConstraint(body_i, marker_i_quat, body_j, marker_j_quat)

    if kind == kcsolve.BaseJointKind.Angle:
        return AngleConstraint(
            body_i, marker_i_quat, body_j, marker_j_quat, _param(c_params, 0, 0.0)
        )

    # -- Phase 2: axis/surface constraints ------------------------------------

    if kind == kcsolve.BaseJointKind.Concentric:
        return ConcentricConstraint(
            body_i,
            marker_i_pos,
            marker_i_quat,
            body_j,
            marker_j_pos,
            marker_j_quat,
            distance=_param(c_params, 0, 0.0),
        )

    if kind == kcsolve.BaseJointKind.Tangent:
        return TangentConstraint(
            body_i, marker_i_pos, marker_i_quat, body_j, marker_j_pos, marker_j_quat
        )

    if kind == kcsolve.BaseJointKind.Planar:
        return PlanarConstraint(
            body_i,
            marker_i_pos,
            marker_i_quat,
            body_j,
            marker_j_pos,
            marker_j_quat,
            offset=_param(c_params, 0, 0.0),
        )

    if kind == kcsolve.BaseJointKind.LineInPlane:
        return LineInPlaneConstraint(
            body_i,
            marker_i_pos,
            marker_i_quat,
            body_j,
            marker_j_pos,
            marker_j_quat,
            offset=_param(c_params, 0, 0.0),
        )

    # -- Phase 2: kinematic joints --------------------------------------------

    if kind == kcsolve.BaseJointKind.Ball:
        return BallConstraint(body_i, marker_i_pos, body_j, marker_j_pos)

    if kind == kcsolve.BaseJointKind.Revolute:
        return RevoluteConstraint(
            body_i, marker_i_pos, marker_i_quat, body_j, marker_j_pos, marker_j_quat
        )

    if kind == kcsolve.BaseJointKind.Cylindrical:
        return CylindricalConstraint(
            body_i, marker_i_pos, marker_i_quat, body_j, marker_j_pos, marker_j_quat
        )

    if kind == kcsolve.BaseJointKind.Slider:
        return SliderConstraint(
            body_i, marker_i_pos, marker_i_quat, body_j, marker_j_pos, marker_j_quat
        )

    if kind == kcsolve.BaseJointKind.Screw:
        return ScrewConstraint(
            body_i,
            marker_i_pos,
            marker_i_quat,
            body_j,
            marker_j_pos,
            marker_j_quat,
            pitch=_param(c_params, 0, 1.0),
        )

    if kind == kcsolve.BaseJointKind.Universal:
        return UniversalConstraint(
            body_i, marker_i_pos, marker_i_quat, body_j, marker_j_pos, marker_j_quat
        )

    # -- Phase 2: mechanical constraints --------------------------------------

    if kind == kcsolve.BaseJointKind.Gear:
        return GearConstraint(
            body_i,
            marker_i_quat,
            body_j,
            marker_j_quat,
            _param(c_params, 0, 1.0),
            _param(c_params, 1, 1.0),
        )

    if kind == kcsolve.BaseJointKind.RackPinion:
        return RackPinionConstraint(
            body_i,
            marker_i_pos,
            marker_i_quat,
            body_j,
            marker_j_pos,
            marker_j_quat,
            pitch_radius=_param(c_params, 0, 1.0),
        )

    # -- Stubs (accepted but produce no residuals) ----------------------------

    if kind == kcsolve.BaseJointKind.Cam:
        return CamConstraint()

    if kind == kcsolve.BaseJointKind.Slot:
        return SlotConstraint()

    if kind == kcsolve.BaseJointKind.DistanceCylSph:
        return DistanceCylSphConstraint()

    return None