From b4b8724ff17a9a53191600e021ad201f75c3662a Mon Sep 17 00:00:00 2001 From: forbes-0023 Date: Fri, 20 Feb 2026 23:32:45 -0600 Subject: [PATCH] feat(solver): diagnostics, half-space preference, and weight vectors (phase 4) - Add per-entity DOF analysis via Jacobian SVD (diagnostics.py) - Add overconstrained detection: redundant vs conflicting constraints - Add half-space tracking to preserve configuration branch (preference.py) - Add minimum-movement weighting for least-squares solve - Extend BFGS fallback with weight vector and quaternion renormalization - Add snapshot/restore and env accessor to ParamTable - Fix DistancePointPointConstraint sign for half-space tracking --- kindred_solver/bfgs.py | 46 +++- kindred_solver/constraints.py | 10 +- kindred_solver/diagnostics.py | 299 ++++++++++++++++++++++++++ kindred_solver/newton.py | 23 +- kindred_solver/params.py | 23 ++ kindred_solver/preference.py | 325 ++++++++++++++++++++++++++++ tests/test_diagnostics.py | 296 ++++++++++++++++++++++++++ tests/test_params.py | 47 +++++ tests/test_preference.py | 384 ++++++++++++++++++++++++++++++++++ 9 files changed, 1444 insertions(+), 9 deletions(-) create mode 100644 kindred_solver/diagnostics.py create mode 100644 kindred_solver/preference.py create mode 100644 tests/test_diagnostics.py create mode 100644 tests/test_preference.py diff --git a/kindred_solver/bfgs.py b/kindred_solver/bfgs.py index 44e65ce..c4f0f78 100644 --- a/kindred_solver/bfgs.py +++ b/kindred_solver/bfgs.py @@ -28,11 +28,16 @@ def bfgs_solve( quat_groups: List[tuple[str, str, str, str]] | None = None, max_iter: int = 200, tol: float = 1e-10, + weight_vector: "np.ndarray | None" = None, ) -> bool: """Solve ``residuals == 0`` by minimizing sum of squared residuals. Falls back gracefully to False if scipy is not available. + When *weight_vector* is provided, residuals are scaled by + ``sqrt(w)`` so that the objective becomes + ``0.5 * sum(w_i * r_i^2)`` — equivalent to weighted least-squares. + Returns True if converged (||r|| < tol). """ if not _HAS_SCIPY: @@ -53,7 +58,21 @@ def bfgs_solve( row.append(r.diff(name).simplify()) jac_exprs.append(row) - def objective_and_grad(x_vec): + # Pre-compute scaling for weighted minimum-norm + if weight_vector is not None: + w_sqrt = np.sqrt(weight_vector) + w_inv_sqrt = 1.0 / w_sqrt + else: + w_sqrt = None + w_inv_sqrt = None + + def objective_and_grad(y_vec): + # Transform back from scaled space if weighted + if w_inv_sqrt is not None: + x_vec = y_vec * w_inv_sqrt + else: + x_vec = y_vec + # Update params params.set_free_vector(x_vec) if quat_groups: @@ -71,23 +90,38 @@ def bfgs_solve( for j in range(n_free): J[i, j] = jac_exprs[i][j].eval(env) - # Gradient of f = sum(r_i * dr_i/dx_j) = J^T @ r - grad = J.T @ r_vals + # Gradient of f w.r.t. x = J^T @ r + grad_x = J.T @ r_vals + + # Chain rule: df/dy = df/dx * dx/dy = grad_x * w_inv_sqrt + if w_inv_sqrt is not None: + grad = grad_x * w_inv_sqrt + else: + grad = grad_x return f, grad x0 = params.get_free_vector().copy() + # Transform initial guess to scaled space + if w_sqrt is not None: + y0 = x0 * w_sqrt + else: + y0 = x0 + result = _scipy_minimize( objective_and_grad, - x0, + y0, method="L-BFGS-B", jac=True, options={"maxiter": max_iter, "ftol": tol * tol, "gtol": tol}, ) - # Apply final result - params.set_free_vector(result.x) + # Apply final result (transform back from scaled space) + if w_inv_sqrt is not None: + params.set_free_vector(result.x * w_inv_sqrt) + else: + params.set_free_vector(result.x) if quat_groups: _renormalize_quats(params, quat_groups) diff --git a/kindred_solver/constraints.py b/kindred_solver/constraints.py index 47eacb1..b9093c5 100644 --- a/kindred_solver/constraints.py +++ b/kindred_solver/constraints.py @@ -77,9 +77,15 @@ class DistancePointPointConstraint(ConstraintBase): self.marker_j_pos = marker_j_pos self.distance = distance + def world_points(self) -> tuple[tuple[Expr, Expr, Expr], tuple[Expr, Expr, Expr]]: + """Return (world_point_i, world_point_j) expression tuples.""" + return ( + self.body_i.world_point(*self.marker_i_pos), + self.body_j.world_point(*self.marker_j_pos), + ) + def residuals(self) -> List[Expr]: - wx_i, wy_i, wz_i = self.body_i.world_point(*self.marker_i_pos) - wx_j, wy_j, wz_j = self.body_j.world_point(*self.marker_j_pos) + (wx_i, wy_i, wz_i), (wx_j, wy_j, wz_j) = self.world_points() dx = wx_i - wx_j dy = wy_i - wy_j dz = wz_i - wz_j diff --git a/kindred_solver/diagnostics.py b/kindred_solver/diagnostics.py new file mode 100644 index 0000000..6747b8b --- /dev/null +++ b/kindred_solver/diagnostics.py @@ -0,0 +1,299 @@ +"""Per-entity DOF diagnostics and overconstrained detection. + +Provides per-body remaining degrees of freedom, human-readable free +motion labels, and redundant/conflicting constraint identification. +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import List + +import numpy as np + +from .entities import RigidBody +from .expr import Expr +from .params import ParamTable + +# -- Per-entity DOF ----------------------------------------------------------- + + +@dataclass +class EntityDOF: + """DOF report for a single entity (rigid body).""" + + entity_id: str + remaining_dof: int # 0 = well-constrained + free_motions: list[str] = field(default_factory=list) + + +def per_entity_dof( + residuals: list[Expr], + params: ParamTable, + bodies: dict[str, RigidBody], + rank_tol: float = 1e-8, +) -> list[EntityDOF]: + """Compute remaining DOF for each non-grounded body. + + For each body, extracts the Jacobian columns corresponding to its + 7 parameters, performs SVD to find constrained directions, and + classifies null-space vectors as translations or rotations. + """ + free = params.free_names() + n_res = len(residuals) + env = params.get_env() + + if n_res == 0: + # No constraints — every free body has 6 DOF + result = [] + for pid, body in bodies.items(): + if body.grounded: + continue + result.append( + EntityDOF( + entity_id=pid, + remaining_dof=6, + free_motions=[ + "translation along X", + "translation along Y", + "translation along Z", + "rotation about X", + "rotation about Y", + "rotation about Z", + ], + ) + ) + return result + + # Build column index mapping: param_name -> column index in free list + free_index = {name: i for i, name in enumerate(free)} + + # Build full Jacobian (for efficiency, compute once) + n_free = len(free) + J_full = np.empty((n_res, n_free)) + for i, r in enumerate(residuals): + for j, name in enumerate(free): + J_full[i, j] = r.diff(name).simplify().eval(env) + + result = [] + for pid, body in bodies.items(): + if body.grounded: + continue + + # Find column indices for this body's params + pfx = pid + "/" + body_param_names = [ + pfx + "tx", + pfx + "ty", + pfx + "tz", + pfx + "qw", + pfx + "qx", + pfx + "qy", + pfx + "qz", + ] + col_indices = [free_index[n] for n in body_param_names if n in free_index] + + if not col_indices: + # All params fixed (shouldn't happen for non-grounded, but be safe) + result.append(EntityDOF(entity_id=pid, remaining_dof=0)) + continue + + # Extract submatrix: all residual rows, only this body's columns + J_sub = J_full[:, col_indices] + + # SVD + U, sv, Vt = np.linalg.svd(J_sub, full_matrices=True) + constrained = int(np.sum(sv > rank_tol)) + + # Subtract 1 for the quaternion unit-norm constraint (already in residuals) + # The quat norm residual constrains 1 direction in the 7-D param space, + # so effective body DOF = 7 - 1 - constrained_by_other_constraints. + # But the quat norm IS one of the residual rows, so it's already counted + # in `constrained`. So: remaining = len(col_indices) - constrained + # But the quat norm takes 1 from 7 → 6 geometric DOF, and constrained + # includes the quat norm row. So remaining = 7 - constrained, which gives + # geometric remaining DOF directly. + remaining = len(col_indices) - constrained + + # Classify null-space vectors as free motions + free_motions = [] + if remaining > 0 and Vt.shape[0] > constrained: + null_space = Vt[constrained:] # rows = null vectors in param space + + # Map column indices back to param types + param_types = [] + for n in body_param_names: + if n in free_index: + if n.endswith(("/tx", "/ty", "/tz")): + param_types.append("t") + else: + param_types.append("q") + + for null_vec in null_space: + label = _classify_motion( + null_vec, param_types, body_param_names, free_index + ) + if label: + free_motions.append(label) + + result.append( + EntityDOF( + entity_id=pid, + remaining_dof=remaining, + free_motions=free_motions, + ) + ) + + return result + + +def _classify_motion( + null_vec: np.ndarray, + param_types: list[str], + body_param_names: list[str], + free_index: dict[str, int], +) -> str: + """Classify a null-space vector as translation, rotation, or helical.""" + # Split components into translation and rotation parts + trans_indices = [i for i, t in enumerate(param_types) if t == "t"] + rot_indices = [i for i, t in enumerate(param_types) if t == "q"] + + trans_norm = np.linalg.norm(null_vec[trans_indices]) if trans_indices else 0.0 + rot_norm = np.linalg.norm(null_vec[rot_indices]) if rot_indices else 0.0 + + total = trans_norm + rot_norm + if total < 1e-14: + return "" + + trans_frac = trans_norm / total + rot_frac = rot_norm / total + + # Determine dominant axis + if trans_frac > 0.8: + # Pure translation + axis = _dominant_axis(null_vec, trans_indices) + return f"translation along {axis}" + elif rot_frac > 0.8: + # Pure rotation + axis = _dominant_axis(null_vec, rot_indices) + return f"rotation about {axis}" + else: + # Mixed — helical + axis = _dominant_axis(null_vec, trans_indices) + return f"helical motion along {axis}" + + +def _dominant_axis(vec: np.ndarray, indices: list[int]) -> str: + """Find the dominant axis (X/Y/Z) among the given component indices.""" + if not indices: + return "?" + components = np.abs(vec[indices]) + # Map to axis names — first 3 in group are X/Y/Z + axis_names = ["X", "Y", "Z"] + if len(components) >= 3: + idx = int(np.argmax(components[:3])) + return axis_names[idx] + elif len(components) == 1: + return axis_names[0] + else: + idx = int(np.argmax(components)) + return axis_names[min(idx, 2)] + + +# -- Overconstrained detection ------------------------------------------------ + + +@dataclass +class ConstraintDiag: + """Diagnostic for a single constraint.""" + + constraint_index: int + kind: str # "redundant" | "conflicting" + detail: str + + +def find_overconstrained( + residuals: list[Expr], + params: ParamTable, + residual_ranges: list[tuple[int, int, int]], + rank_tol: float = 1e-8, +) -> list[ConstraintDiag]: + """Identify redundant and conflicting constraints. + + Algorithm (following SolvSpace's FindWhichToRemoveToFixJacobian): + 1. Build full Jacobian J, compute rank. + 2. If rank == n_residuals, not overconstrained — return empty. + 3. For each constraint: remove its rows, check if rank is preserved + → if so, the constraint is **redundant**. + 4. Compute left null space, project residual vector F → if a + constraint's residuals contribute to this projection, it is + **conflicting** (redundant + unsatisfied). + """ + free = params.free_names() + n_free = len(free) + n_res = len(residuals) + + if n_free == 0 or n_res == 0: + return [] + + env = params.get_env() + + # Build Jacobian and residual vector + J = np.empty((n_res, n_free)) + r_vec = np.empty(n_res) + for i, r in enumerate(residuals): + r_vec[i] = r.eval(env) + for j, name in enumerate(free): + J[i, j] = r.diff(name).simplify().eval(env) + + # Full rank + sv_full = np.linalg.svd(J, compute_uv=False) + full_rank = int(np.sum(sv_full > rank_tol)) + + if full_rank >= n_res: + return [] # not overconstrained + + # Left null space: columns of U beyond rank + U, sv, Vt = np.linalg.svd(J, full_matrices=True) + left_null = U[:, full_rank:] # shape (n_res, n_res - rank) + + # Project residual onto left null space + null_residual = left_null.T @ r_vec # shape (n_res - rank,) + residual_projection = left_null @ null_residual # back to residual space + + diags = [] + for start, end, c_idx in residual_ranges: + # Remove this constraint's rows and check rank + mask = np.ones(n_res, dtype=bool) + mask[start:end] = False + J_reduced = J[mask] + + if J_reduced.shape[0] == 0: + continue + + sv_reduced = np.linalg.svd(J_reduced, compute_uv=False) + reduced_rank = int(np.sum(sv_reduced > rank_tol)) + + if reduced_rank >= full_rank: + # Removing this constraint preserves rank → redundant + # Check if it's also conflicting (contributes to unsatisfied null projection) + constraint_proj = np.linalg.norm(residual_projection[start:end]) + if constraint_proj > rank_tol: + kind = "conflicting" + detail = ( + f"Constraint {c_idx} is conflicting (redundant and unsatisfied)" + ) + else: + kind = "redundant" + detail = ( + f"Constraint {c_idx} is redundant (can be removed without effect)" + ) + diags.append( + ConstraintDiag( + constraint_index=c_idx, + kind=kind, + detail=detail, + ) + ) + + return diags diff --git a/kindred_solver/newton.py b/kindred_solver/newton.py index 94f3710..697d86a 100644 --- a/kindred_solver/newton.py +++ b/kindred_solver/newton.py @@ -17,6 +17,8 @@ def newton_solve( quat_groups: List[tuple[str, str, str, str]] | None = None, max_iter: int = 50, tol: float = 1e-10, + post_step: "Callable[[ParamTable], None] | None" = None, + weight_vector: "np.ndarray | None" = None, ) -> bool: """Solve ``residuals == 0`` by Newton-Raphson. @@ -33,6 +35,14 @@ def newton_solve( Maximum Newton iterations. tol: Convergence threshold on ``||r||``. + post_step: + Optional callback invoked after each parameter update, before + quaternion renormalization. Used for half-space correction. + weight_vector: + Optional 1-D array of length ``n_free``. When provided, the + lstsq step is column-scaled to produce the weighted + minimum-norm solution (prefer small movements in + high-weight parameters). Returns True if converged within *max_iter* iterations. """ @@ -67,13 +77,24 @@ def newton_solve( J[i, j] = jac_exprs[i][j].eval(env) # Solve J @ dx = -r (least-squares handles rank-deficient) - dx, _, _, _ = np.linalg.lstsq(J, -r_vec, rcond=None) + if weight_vector is not None: + # Column-scale J by W^{-1/2} for weighted minimum-norm + w_inv_sqrt = 1.0 / np.sqrt(weight_vector) + J_scaled = J * w_inv_sqrt[np.newaxis, :] + dx_scaled, _, _, _ = np.linalg.lstsq(J_scaled, -r_vec, rcond=None) + dx = dx_scaled * w_inv_sqrt + else: + dx, _, _, _ = np.linalg.lstsq(J, -r_vec, rcond=None) # Update parameters x = params.get_free_vector() x += dx params.set_free_vector(x) + # Half-space correction (before quat renormalization) + if post_step: + post_step(params) + # Re-normalize quaternions if quat_groups: _renormalize_quats(params, quat_groups) diff --git a/kindred_solver/params.py b/kindred_solver/params.py index a5b7c36..cbee3aa 100644 --- a/kindred_solver/params.py +++ b/kindred_solver/params.py @@ -81,3 +81,26 @@ class ParamTable: """Bulk-update free parameters from a 1-D array.""" for i, name in enumerate(self._free_order): self._values[name] = float(vec[i]) + + def snapshot(self) -> Dict[str, float]: + """Capture current values as a checkpoint.""" + return dict(self._values) + + def restore(self, snap: Dict[str, float]): + """Restore parameter values from a checkpoint.""" + for name, val in snap.items(): + if name in self._values: + self._values[name] = val + + def movement_cost( + self, + start: Dict[str, float], + weights: Dict[str, float] | None = None, + ) -> float: + """Weighted sum of squared displacements from start.""" + cost = 0.0 + for name in self._free_order: + w = weights.get(name, 1.0) if weights else 1.0 + delta = self._values[name] - start.get(name, self._values[name]) + cost += delta * delta * w + return cost diff --git a/kindred_solver/preference.py b/kindred_solver/preference.py new file mode 100644 index 0000000..a3e2db5 --- /dev/null +++ b/kindred_solver/preference.py @@ -0,0 +1,325 @@ +"""Solution preference: half-space tracking and minimum-movement weighting. + +Half-space tracking preserves the initial configuration branch across +Newton iterations. For constraints with multiple valid solutions +(e.g. distance can be satisfied on either side), we record which +"half-space" the initial state lives in and correct the solver step +if it crosses to the wrong branch. + +Minimum-movement weighting scales the Newton/BFGS step so that +quaternion parameters (rotation) are penalised more than translation +parameters, yielding the physically-nearest solution. +""" + +from __future__ import annotations + +import math +from dataclasses import dataclass, field +from typing import Callable, List + +import numpy as np + +from .constraints import ( + AngleConstraint, + ConstraintBase, + DistancePointPointConstraint, + ParallelConstraint, + PerpendicularConstraint, +) +from .geometry import cross3, dot3, marker_z_axis +from .params import ParamTable + + +@dataclass +class HalfSpace: + """Tracks which branch of a branching constraint the solution should stay in.""" + + constraint_index: int # index in ctx.constraints + reference_sign: float # +1.0 or -1.0, captured at setup + indicator_fn: Callable[[dict[str, float]], float] # returns signed value + param_names: list[str] = field(default_factory=list) # params to flip + correction_fn: Callable[[ParamTable, float], None] | None = None + + +def compute_half_spaces( + constraint_objs: list[ConstraintBase], + constraint_indices: list[int], + params: ParamTable, +) -> list[HalfSpace]: + """Build half-space trackers for all branching constraints. + + Evaluates each constraint's indicator function at the current + parameter values to capture the reference sign. + """ + env = params.get_env() + half_spaces: list[HalfSpace] = [] + + for i, obj in enumerate(constraint_objs): + hs = _build_half_space(obj, constraint_indices[i], env, params) + if hs is not None: + half_spaces.append(hs) + + return half_spaces + + +def apply_half_space_correction( + params: ParamTable, + half_spaces: list[HalfSpace], +) -> None: + """Check each half-space and correct if the solver crossed a branch. + + Called as a post_step callback from newton_solve. + """ + if not half_spaces: + return + + env = params.get_env() + for hs in half_spaces: + current_val = hs.indicator_fn(env) + current_sign = ( + math.copysign(1.0, current_val) + if abs(current_val) > 1e-14 + else hs.reference_sign + ) + if current_sign != hs.reference_sign and hs.correction_fn is not None: + hs.correction_fn(params, current_val) + # Re-read env after correction for subsequent half-spaces + env = params.get_env() + + +def _build_half_space( + obj: ConstraintBase, + constraint_idx: int, + env: dict[str, float], + params: ParamTable, +) -> HalfSpace | None: + """Build a HalfSpace for a branching constraint, or None if not branching.""" + + if isinstance(obj, DistancePointPointConstraint) and obj.distance > 0: + return _distance_half_space(obj, constraint_idx, env, params) + + if isinstance(obj, ParallelConstraint): + return _parallel_half_space(obj, constraint_idx, env, params) + + if isinstance(obj, AngleConstraint): + return _angle_half_space(obj, constraint_idx, env, params) + + if isinstance(obj, PerpendicularConstraint): + return _perpendicular_half_space(obj, constraint_idx, env, params) + + return None + + +def _distance_half_space( + obj: DistancePointPointConstraint, + constraint_idx: int, + env: dict[str, float], + params: ParamTable, +) -> HalfSpace | None: + """Half-space for DistancePointPoint: track displacement direction. + + The indicator is the dot product of the current displacement with + the reference displacement direction. If the solver flips to the + opposite side, we reflect the moving body's position. + """ + p_i, p_j = obj.world_points() + + # Evaluate reference displacement direction + dx = p_i[0].eval(env) - p_j[0].eval(env) + dy = p_i[1].eval(env) - p_j[1].eval(env) + dz = p_i[2].eval(env) - p_j[2].eval(env) + dist = math.sqrt(dx * dx + dy * dy + dz * dz) + + if dist < 1e-14: + return None # points coincident, no branch to track + + # Reference unit direction + nx, ny, nz = dx / dist, dy / dist, dz / dist + + # Build indicator: dot(displacement, reference_direction) + # Use Expr evaluation for speed + disp_x, disp_y, disp_z = p_i[0] - p_j[0], p_i[1] - p_j[1], p_i[2] - p_j[2] + + def indicator(e: dict[str, float]) -> float: + return disp_x.eval(e) * nx + disp_y.eval(e) * ny + disp_z.eval(e) * nz + + ref_sign = math.copysign(1.0, indicator(env)) + + # Correction: reflect body_j position along reference direction + # (or body_i if body_j is grounded) + moving_body = obj.body_j if not obj.body_j.grounded else obj.body_i + if moving_body.grounded: + return None # both grounded, nothing to correct + + px_name = f"{moving_body.part_id}/tx" + py_name = f"{moving_body.part_id}/ty" + pz_name = f"{moving_body.part_id}/tz" + + sign_flip = -1.0 if moving_body is obj.body_j else 1.0 + + def correction(p: ParamTable, _val: float) -> None: + # Reflect displacement: negate the component along reference direction + e = p.get_env() + cur_dx = disp_x.eval(e) + cur_dy = disp_y.eval(e) + cur_dz = disp_z.eval(e) + # Project displacement onto reference direction + proj = cur_dx * nx + cur_dy * ny + cur_dz * nz + # Reflect: subtract 2*proj*n from the moving body's position + if not p.is_fixed(px_name): + p.set_value(px_name, p.get_value(px_name) + sign_flip * 2.0 * proj * nx) + if not p.is_fixed(py_name): + p.set_value(py_name, p.get_value(py_name) + sign_flip * 2.0 * proj * ny) + if not p.is_fixed(pz_name): + p.set_value(pz_name, p.get_value(pz_name) + sign_flip * 2.0 * proj * nz) + + return HalfSpace( + constraint_index=constraint_idx, + reference_sign=ref_sign, + indicator_fn=indicator, + param_names=[px_name, py_name, pz_name], + correction_fn=correction, + ) + + +def _parallel_half_space( + obj: ParallelConstraint, + constraint_idx: int, + env: dict[str, float], + params: ParamTable, +) -> HalfSpace: + """Half-space for Parallel: track same-direction vs opposite-direction. + + Indicator: dot(z_i, z_j). Positive = same direction, negative = opposite. + """ + z_i = marker_z_axis(obj.body_i, obj.marker_i_quat) + z_j = marker_z_axis(obj.body_j, obj.marker_j_quat) + dot_expr = dot3(z_i, z_j) + + def indicator(e: dict[str, float]) -> float: + return dot_expr.eval(e) + + ref_val = indicator(env) + ref_sign = math.copysign(1.0, ref_val) if abs(ref_val) > 1e-14 else 1.0 + + # No geometric correction — just let the indicator track. + # The Newton solver naturally handles this via the cross-product residual. + # We only need to detect and report branch flips. + return HalfSpace( + constraint_index=constraint_idx, + reference_sign=ref_sign, + indicator_fn=indicator, + ) + + +# ============================================================================ +# Minimum-movement weighting +# ============================================================================ + +# Scale factor so that a 1-radian rotation is penalised as much as a +# (180/pi)-unit translation. This makes the weighted minimum-norm +# step prefer translating over rotating for the same residual reduction. +QUAT_WEIGHT = (180.0 / math.pi) ** 2 # ~3283 + + +def build_weight_vector(params: ParamTable) -> np.ndarray: + """Build diagonal weight vector: 1.0 for translation, QUAT_WEIGHT for quaternion. + + Returns a 1-D array of length ``params.n_free()``. + """ + free = params.free_names() + w = np.ones(len(free)) + quat_suffixes = ("/qw", "/qx", "/qy", "/qz") + for i, name in enumerate(free): + if any(name.endswith(s) for s in quat_suffixes): + w[i] = QUAT_WEIGHT + return w + + +def _angle_half_space( + obj: AngleConstraint, + constraint_idx: int, + env: dict[str, float], + params: ParamTable, +) -> HalfSpace | None: + """Half-space for Angle: track sign of sin(angle) via cross product. + + For angle constraints, the dot product is fixed (= cos(angle)), + but sin can be +/-. We track the cross product magnitude sign. + """ + if abs(obj.angle) < 1e-14 or abs(obj.angle - math.pi) < 1e-14: + return None # 0 or 180 degrees — no branch ambiguity + + z_i = marker_z_axis(obj.body_i, obj.marker_i_quat) + z_j = marker_z_axis(obj.body_j, obj.marker_j_quat) + cx, cy, cz = cross3(z_i, z_j) + + # Use the magnitude of the cross product's z-component as indicator + # (or whichever component is largest at setup time) + cx_val = cx.eval(env) + cy_val = cy.eval(env) + cz_val = cz.eval(env) + + # Pick the dominant cross product component + components = [ + (abs(cx_val), cx, cx_val), + (abs(cy_val), cy, cy_val), + (abs(cz_val), cz, cz_val), + ] + _, best_expr, best_val = max(components, key=lambda t: t[0]) + + if abs(best_val) < 1e-14: + return None + + def indicator(e: dict[str, float]) -> float: + return best_expr.eval(e) + + ref_sign = math.copysign(1.0, best_val) + + return HalfSpace( + constraint_index=constraint_idx, + reference_sign=ref_sign, + indicator_fn=indicator, + ) + + +def _perpendicular_half_space( + obj: PerpendicularConstraint, + constraint_idx: int, + env: dict[str, float], + params: ParamTable, +) -> HalfSpace | None: + """Half-space for Perpendicular: track which quadrant. + + The dot product is constrained to 0, but the cross product sign + distinguishes which "side" of perpendicular. + """ + z_i = marker_z_axis(obj.body_i, obj.marker_i_quat) + z_j = marker_z_axis(obj.body_j, obj.marker_j_quat) + cx, cy, cz = cross3(z_i, z_j) + + # Pick the dominant cross product component + cx_val = cx.eval(env) + cy_val = cy.eval(env) + cz_val = cz.eval(env) + + components = [ + (abs(cx_val), cx, cx_val), + (abs(cy_val), cy, cy_val), + (abs(cz_val), cz, cz_val), + ] + _, best_expr, best_val = max(components, key=lambda t: t[0]) + + if abs(best_val) < 1e-14: + return None + + def indicator(e: dict[str, float]) -> float: + return best_expr.eval(e) + + ref_sign = math.copysign(1.0, best_val) + + return HalfSpace( + constraint_index=constraint_idx, + reference_sign=ref_sign, + indicator_fn=indicator, + ) diff --git a/tests/test_diagnostics.py b/tests/test_diagnostics.py new file mode 100644 index 0000000..79a26b4 --- /dev/null +++ b/tests/test_diagnostics.py @@ -0,0 +1,296 @@ +"""Tests for per-entity DOF diagnostics and overconstrained detection.""" + +import math + +import numpy as np +import pytest +from kindred_solver.constraints import ( + CoincidentConstraint, + CylindricalConstraint, + DistancePointPointConstraint, + FixedConstraint, + ParallelConstraint, + RevoluteConstraint, +) +from kindred_solver.diagnostics import ( + ConstraintDiag, + EntityDOF, + find_overconstrained, + per_entity_dof, +) +from kindred_solver.entities import RigidBody +from kindred_solver.params import ParamTable + + +def _make_two_bodies( + params, + pos_a=(0, 0, 0), + pos_b=(5, 0, 0), + quat_a=(1, 0, 0, 0), + quat_b=(1, 0, 0, 0), + ground_a=True, + ground_b=False, +): + body_a = RigidBody( + "a", params, position=pos_a, quaternion=quat_a, grounded=ground_a + ) + body_b = RigidBody( + "b", params, position=pos_b, quaternion=quat_b, grounded=ground_b + ) + return body_a, body_b + + +def _build_residuals_and_ranges(constraint_objs, bodies, params): + """Build residuals list, quat norms, and residual_ranges.""" + all_residuals = [] + residual_ranges = [] + row = 0 + for i, obj in enumerate(constraint_objs): + r = obj.residuals() + n = len(r) + residual_ranges.append((row, row + n, i)) + all_residuals.extend(r) + row += n + + for body in bodies.values(): + if not body.grounded: + all_residuals.append(body.quat_norm_residual()) + + return all_residuals, residual_ranges + + +# ============================================================================ +# Per-entity DOF tests +# ============================================================================ + + +class TestPerEntityDOF: + """Per-entity DOF computation.""" + + def test_unconstrained_body_6dof(self): + """Unconstrained non-grounded body has 6 DOF.""" + params = ParamTable() + body = RigidBody( + "b", params, position=(0, 0, 0), quaternion=(1, 0, 0, 0), grounded=False + ) + bodies = {"b": body} + + # Only quat norm constraint + residuals = [body.quat_norm_residual()] + + result = per_entity_dof(residuals, params, bodies) + assert len(result) == 1 + assert result[0].entity_id == "b" + assert result[0].remaining_dof == 6 + assert len(result[0].free_motions) == 6 + + def test_fixed_body_0dof(self): + """Body welded to ground has 0 DOF.""" + params = ParamTable() + body_a, body_b = _make_two_bodies(params) + bodies = {"a": body_a, "b": body_b} + + c = FixedConstraint( + body_a, + (0, 0, 0), + (1, 0, 0, 0), + body_b, + (0, 0, 0), + (1, 0, 0, 0), + ) + residuals, _ = _build_residuals_and_ranges([c], bodies, params) + + result = per_entity_dof(residuals, params, bodies) + # Only non-grounded body (b) reported + assert len(result) == 1 + assert result[0].entity_id == "b" + assert result[0].remaining_dof == 0 + assert len(result[0].free_motions) == 0 + + def test_revolute_1dof(self): + """Revolute joint leaves 1 DOF (rotation about Z).""" + params = ParamTable() + body_a, body_b = _make_two_bodies(params, pos_b=(0, 0, 0)) + bodies = {"a": body_a, "b": body_b} + + c = RevoluteConstraint( + body_a, + (0, 0, 0), + (1, 0, 0, 0), + body_b, + (0, 0, 0), + (1, 0, 0, 0), + ) + residuals, _ = _build_residuals_and_ranges([c], bodies, params) + + result = per_entity_dof(residuals, params, bodies) + assert len(result) == 1 + assert result[0].remaining_dof == 1 + # Should have one free motion that mentions rotation + assert len(result[0].free_motions) == 1 + assert "rotation" in result[0].free_motions[0].lower() + + def test_cylindrical_2dof(self): + """Cylindrical joint leaves 2 DOF (rotation about Z + translation along Z).""" + params = ParamTable() + body_a, body_b = _make_two_bodies(params, pos_b=(0, 0, 0)) + bodies = {"a": body_a, "b": body_b} + + c = CylindricalConstraint( + body_a, + (0, 0, 0), + (1, 0, 0, 0), + body_b, + (0, 0, 0), + (1, 0, 0, 0), + ) + residuals, _ = _build_residuals_and_ranges([c], bodies, params) + + result = per_entity_dof(residuals, params, bodies) + assert len(result) == 1 + assert result[0].remaining_dof == 2 + assert len(result[0].free_motions) == 2 + + def test_coincident_3dof(self): + """Coincident (ball) joint leaves 3 DOF (3 rotations).""" + params = ParamTable() + body_a, body_b = _make_two_bodies(params, pos_b=(0, 0, 0)) + bodies = {"a": body_a, "b": body_b} + + c = CoincidentConstraint(body_a, (0, 0, 0), body_b, (0, 0, 0)) + residuals, _ = _build_residuals_and_ranges([c], bodies, params) + + result = per_entity_dof(residuals, params, bodies) + assert len(result) == 1 + assert result[0].remaining_dof == 3 + # All 3 should be rotations + for motion in result[0].free_motions: + assert "rotation" in motion.lower() + + def test_no_constraints_6dof(self): + """No residuals at all gives 6 DOF.""" + params = ParamTable() + body = RigidBody( + "b", params, position=(0, 0, 0), quaternion=(1, 0, 0, 0), grounded=False + ) + bodies = {"b": body} + + result = per_entity_dof([], params, bodies) + assert len(result) == 1 + assert result[0].remaining_dof == 6 + + def test_grounded_body_excluded(self): + """Grounded bodies are not reported.""" + params = ParamTable() + body_a, body_b = _make_two_bodies(params) + bodies = {"a": body_a, "b": body_b} + + residuals = [body_b.quat_norm_residual()] + result = per_entity_dof(residuals, params, bodies) + + entity_ids = [r.entity_id for r in result] + assert "a" not in entity_ids # grounded + assert "b" in entity_ids + + def test_multiple_bodies(self): + """Two free bodies: each gets its own DOF report.""" + params = ParamTable() + body_g = RigidBody( + "g", params, position=(0, 0, 0), quaternion=(1, 0, 0, 0), grounded=True + ) + body_b = RigidBody( + "b", params, position=(5, 0, 0), quaternion=(1, 0, 0, 0), grounded=False + ) + body_c = RigidBody( + "c", params, position=(10, 0, 0), quaternion=(1, 0, 0, 0), grounded=False + ) + bodies = {"g": body_g, "b": body_b, "c": body_c} + + # Fix b to ground, leave c unconstrained + c_fix = FixedConstraint( + body_g, + (0, 0, 0), + (1, 0, 0, 0), + body_b, + (0, 0, 0), + (1, 0, 0, 0), + ) + residuals, _ = _build_residuals_and_ranges([c_fix], bodies, params) + + result = per_entity_dof(residuals, params, bodies) + result_map = {r.entity_id: r for r in result} + + assert result_map["b"].remaining_dof == 0 + assert result_map["c"].remaining_dof == 6 + + +# ============================================================================ +# Overconstrained detection tests +# ============================================================================ + + +class TestFindOverconstrained: + """Redundant and conflicting constraint detection.""" + + def test_well_constrained_no_diagnostics(self): + """Well-constrained system produces no diagnostics.""" + params = ParamTable() + body_a, body_b = _make_two_bodies(params, pos_b=(0, 0, 0)) + bodies = {"a": body_a, "b": body_b} + + c = FixedConstraint( + body_a, + (0, 0, 0), + (1, 0, 0, 0), + body_b, + (0, 0, 0), + (1, 0, 0, 0), + ) + residuals, ranges = _build_residuals_and_ranges([c], bodies, params) + + diags = find_overconstrained(residuals, params, ranges) + assert len(diags) == 0 + + def test_duplicate_coincident_redundant(self): + """Duplicate coincident constraint is flagged as redundant.""" + params = ParamTable() + body_a, body_b = _make_two_bodies(params, pos_b=(0, 0, 0)) + bodies = {"a": body_a, "b": body_b} + + c1 = CoincidentConstraint(body_a, (0, 0, 0), body_b, (0, 0, 0)) + c2 = CoincidentConstraint(body_a, (0, 0, 0), body_b, (0, 0, 0)) + residuals, ranges = _build_residuals_and_ranges([c1, c2], bodies, params) + + diags = find_overconstrained(residuals, params, ranges) + assert len(diags) > 0 + # At least one should be redundant + kinds = {d.kind for d in diags} + assert "redundant" in kinds + + def test_conflicting_distance(self): + """Distance constraint that can't be satisfied is flagged as conflicting.""" + params = ParamTable() + body_a, body_b = _make_two_bodies(params, pos_b=(0, 0, 0)) + bodies = {"a": body_a, "b": body_b} + + # Coincident forces distance=0, but distance constraint says 50 + c1 = CoincidentConstraint(body_a, (0, 0, 0), body_b, (0, 0, 0)) + c2 = DistancePointPointConstraint( + body_a, + (0, 0, 0), + body_b, + (0, 0, 0), + distance=50.0, + ) + residuals, ranges = _build_residuals_and_ranges([c1, c2], bodies, params) + + diags = find_overconstrained(residuals, params, ranges) + assert len(diags) > 0 + kinds = {d.kind for d in diags} + assert "conflicting" in kinds + + def test_empty_system_no_diagnostics(self): + """Empty system has no diagnostics.""" + params = ParamTable() + diags = find_overconstrained([], params, []) + assert len(diags) == 0 diff --git a/tests/test_params.py b/tests/test_params.py index d005c84..b849ad2 100644 --- a/tests/test_params.py +++ b/tests/test_params.py @@ -99,3 +99,50 @@ class TestParamTable: pt.unfix("a") assert pt.free_names() == ["a"] assert pt.n_free() == 1 + + def test_snapshot_restore_roundtrip(self): + """Snapshot captures values; restore brings them back.""" + pt = ParamTable() + pt.add("x", 1.0) + pt.add("y", 2.0) + pt.add("z", 3.0, fixed=True) + snap = pt.snapshot() + pt.set_value("x", 99.0) + pt.set_value("y", 88.0) + pt.set_value("z", 77.0) + pt.restore(snap) + assert pt.get_value("x") == 1.0 + assert pt.get_value("y") == 2.0 + assert pt.get_value("z") == 3.0 + + def test_snapshot_is_independent_copy(self): + """Mutating snapshot dict does not affect the table.""" + pt = ParamTable() + pt.add("a", 5.0) + snap = pt.snapshot() + snap["a"] = 999.0 + assert pt.get_value("a") == 5.0 + + def test_movement_cost_no_weights(self): + """Movement cost is sum of squared displacements for free params.""" + pt = ParamTable() + pt.add("x", 0.0) + pt.add("y", 0.0) + pt.add("z", 0.0, fixed=True) + snap = pt.snapshot() + pt.set_value("x", 3.0) + pt.set_value("y", 4.0) + pt.set_value("z", 100.0) # fixed — ignored + assert pt.movement_cost(snap) == pytest.approx(25.0) + + def test_movement_cost_with_weights(self): + """Weighted movement cost scales each displacement.""" + pt = ParamTable() + pt.add("a", 0.0) + pt.add("b", 0.0) + snap = pt.snapshot() + pt.set_value("a", 1.0) + pt.set_value("b", 1.0) + weights = {"a": 4.0, "b": 9.0} + # cost = 1^2*4 + 1^2*9 = 13 + assert pt.movement_cost(snap, weights) == pytest.approx(13.0) diff --git a/tests/test_preference.py b/tests/test_preference.py new file mode 100644 index 0000000..f7ced00 --- /dev/null +++ b/tests/test_preference.py @@ -0,0 +1,384 @@ +"""Tests for solution preference: half-space tracking and corrections.""" + +import math + +import numpy as np +import pytest +from kindred_solver.constraints import ( + AngleConstraint, + DistancePointPointConstraint, + ParallelConstraint, + PerpendicularConstraint, +) +from kindred_solver.entities import RigidBody +from kindred_solver.newton import newton_solve +from kindred_solver.params import ParamTable +from kindred_solver.preference import ( + apply_half_space_correction, + compute_half_spaces, +) + + +def _make_two_bodies( + params, + pos_a=(0, 0, 0), + pos_b=(5, 0, 0), + quat_a=(1, 0, 0, 0), + quat_b=(1, 0, 0, 0), + ground_a=True, + ground_b=False, +): + """Create two bodies with given positions/orientations.""" + body_a = RigidBody( + "a", params, position=pos_a, quaternion=quat_a, grounded=ground_a + ) + body_b = RigidBody( + "b", params, position=pos_b, quaternion=quat_b, grounded=ground_b + ) + return body_a, body_b + + +class TestDistanceHalfSpace: + """Half-space tracking for DistancePointPoint constraint.""" + + def test_positive_x_stays_positive(self): + """Body starting at +X should stay at +X after solve.""" + params = ParamTable() + body_a, body_b = _make_two_bodies(params, pos_b=(3, 0, 0)) + c = DistancePointPointConstraint( + body_a, + (0, 0, 0), + body_b, + (0, 0, 0), + distance=5.0, + ) + hs = compute_half_spaces([c], [0], params) + assert len(hs) == 1 + + # Solve with half-space correction + residuals = c.residuals() + residuals.append(body_b.quat_norm_residual()) + quat_groups = [body_b.quat_param_names()] + + def post_step(p): + apply_half_space_correction(p, hs) + + converged = newton_solve( + residuals, + params, + quat_groups=quat_groups, + post_step=post_step, + ) + assert converged + env = params.get_env() + # Body b should be at +X (x > 0), not -X + bx = env["b/tx"] + assert bx > 0, f"Expected positive X, got {bx}" + # Distance should be 5 + dist = math.sqrt(bx**2 + env["b/ty"] ** 2 + env["b/tz"] ** 2) + assert dist == pytest.approx(5.0, abs=1e-8) + + def test_negative_x_stays_negative(self): + """Body starting at -X should stay at -X after solve.""" + params = ParamTable() + body_a, body_b = _make_two_bodies(params, pos_b=(-3, 0, 0)) + c = DistancePointPointConstraint( + body_a, + (0, 0, 0), + body_b, + (0, 0, 0), + distance=5.0, + ) + hs = compute_half_spaces([c], [0], params) + assert len(hs) == 1 + + residuals = c.residuals() + residuals.append(body_b.quat_norm_residual()) + quat_groups = [body_b.quat_param_names()] + + def post_step(p): + apply_half_space_correction(p, hs) + + converged = newton_solve( + residuals, + params, + quat_groups=quat_groups, + post_step=post_step, + ) + assert converged + env = params.get_env() + bx = env["b/tx"] + assert bx < 0, f"Expected negative X, got {bx}" + + def test_zero_distance_no_halfspace(self): + """Zero distance constraint has no branch ambiguity.""" + params = ParamTable() + body_a, body_b = _make_two_bodies(params, pos_b=(3, 0, 0)) + c = DistancePointPointConstraint( + body_a, + (0, 0, 0), + body_b, + (0, 0, 0), + distance=0.0, + ) + hs = compute_half_spaces([c], [0], params) + assert len(hs) == 0 + + +class TestParallelHalfSpace: + """Half-space tracking for Parallel constraint.""" + + def test_same_direction_tracked(self): + """Same-direction parallel: positive reference sign.""" + params = ParamTable() + body_a, body_b = _make_two_bodies(params) + c = ParallelConstraint(body_a, (1, 0, 0, 0), body_b, (1, 0, 0, 0)) + hs = compute_half_spaces([c], [0], params) + assert len(hs) == 1 + assert hs[0].reference_sign == 1.0 + + def test_opposite_direction_tracked(self): + """Opposite-direction parallel: negative reference sign.""" + params = ParamTable() + # Rotate body_b by 180 degrees about X: Z-axis flips + q_flip = (0, 1, 0, 0) # 180 deg about X + body_a, body_b = _make_two_bodies(params, quat_b=q_flip) + c = ParallelConstraint(body_a, (1, 0, 0, 0), body_b, (1, 0, 0, 0)) + hs = compute_half_spaces([c], [0], params) + assert len(hs) == 1 + assert hs[0].reference_sign == -1.0 + + +class TestAngleHalfSpace: + """Half-space tracking for Angle constraint.""" + + def test_90_degree_angle(self): + """90-degree angle constraint creates a half-space.""" + params = ParamTable() + # Rotate body_b by 90 degrees about X + q_90x = (math.cos(math.pi / 4), math.sin(math.pi / 4), 0, 0) + body_a, body_b = _make_two_bodies(params, quat_b=q_90x) + c = AngleConstraint( + body_a, + (1, 0, 0, 0), + body_b, + (1, 0, 0, 0), + angle=math.pi / 2, + ) + hs = compute_half_spaces([c], [0], params) + assert len(hs) == 1 + + def test_zero_angle_no_halfspace(self): + """0-degree angle has no branch ambiguity.""" + params = ParamTable() + body_a, body_b = _make_two_bodies(params) + c = AngleConstraint( + body_a, + (1, 0, 0, 0), + body_b, + (1, 0, 0, 0), + angle=0.0, + ) + hs = compute_half_spaces([c], [0], params) + assert len(hs) == 0 + + def test_180_angle_no_halfspace(self): + """180-degree angle has no branch ambiguity.""" + params = ParamTable() + body_a, body_b = _make_two_bodies(params) + c = AngleConstraint( + body_a, + (1, 0, 0, 0), + body_b, + (1, 0, 0, 0), + angle=math.pi, + ) + hs = compute_half_spaces([c], [0], params) + assert len(hs) == 0 + + +class TestPerpendicularHalfSpace: + """Half-space tracking for Perpendicular constraint.""" + + def test_perpendicular_tracked(self): + """Perpendicular constraint creates a half-space.""" + params = ParamTable() + # Rotate body_b by 90 degrees about X + q_90x = (math.cos(math.pi / 4), math.sin(math.pi / 4), 0, 0) + body_a, body_b = _make_two_bodies(params, quat_b=q_90x) + c = PerpendicularConstraint( + body_a, + (1, 0, 0, 0), + body_b, + (1, 0, 0, 0), + ) + hs = compute_half_spaces([c], [0], params) + assert len(hs) == 1 + + +class TestNewtonPostStep: + """Verify Newton post_step callback works correctly.""" + + def test_callback_fires(self): + """post_step callback is invoked during Newton iterations.""" + params = ParamTable() + x = params.add("x", 2.0) + from kindred_solver.expr import Const + + residuals = [x - Const(5.0)] + + call_count = [0] + + def counter(p): + call_count[0] += 1 + + converged = newton_solve(residuals, params, post_step=counter) + assert converged + assert call_count[0] >= 1 + + def test_callback_does_not_break_convergence(self): + """A no-op callback doesn't prevent convergence.""" + params = ParamTable() + x = params.add("x", 1.0) + y = params.add("y", 1.0) + from kindred_solver.expr import Const + + residuals = [x - Const(3.0), y - Const(7.0)] + + converged = newton_solve(residuals, params, post_step=lambda p: None) + assert converged + assert params.get_value("x") == pytest.approx(3.0) + assert params.get_value("y") == pytest.approx(7.0) + + +class TestMixedHalfSpaces: + """Multiple branching constraints in one system.""" + + def test_multiple_constraints(self): + """compute_half_spaces handles mixed constraint types.""" + params = ParamTable() + body_a, body_b = _make_two_bodies(params, pos_b=(5, 0, 0)) + + dist_c = DistancePointPointConstraint( + body_a, + (0, 0, 0), + body_b, + (0, 0, 0), + distance=5.0, + ) + par_c = ParallelConstraint(body_a, (1, 0, 0, 0), body_b, (1, 0, 0, 0)) + + hs = compute_half_spaces([dist_c, par_c], [0, 1], params) + assert len(hs) == 2 + + +class TestBuildWeightVector: + """Weight vector construction.""" + + def test_translation_weight_one(self): + """Translation params get weight 1.0.""" + from kindred_solver.preference import build_weight_vector + + params = ParamTable() + params.add("body/tx", 0.0) + params.add("body/ty", 0.0) + params.add("body/tz", 0.0) + w = build_weight_vector(params) + np.testing.assert_array_equal(w, [1.0, 1.0, 1.0]) + + def test_quaternion_weight_high(self): + """Quaternion params get QUAT_WEIGHT.""" + from kindred_solver.preference import QUAT_WEIGHT, build_weight_vector + + params = ParamTable() + params.add("body/qw", 1.0) + params.add("body/qx", 0.0) + params.add("body/qy", 0.0) + params.add("body/qz", 0.0) + w = build_weight_vector(params) + np.testing.assert_array_equal(w, [QUAT_WEIGHT] * 4) + + def test_mixed_params(self): + """Mixed translation and quaternion params get correct weights.""" + from kindred_solver.preference import QUAT_WEIGHT, build_weight_vector + + params = ParamTable() + params.add("b/tx", 0.0) + params.add("b/qw", 1.0) + params.add("b/ty", 0.0) + params.add("b/qx", 0.0) + w = build_weight_vector(params) + assert w[0] == pytest.approx(1.0) + assert w[1] == pytest.approx(QUAT_WEIGHT) + assert w[2] == pytest.approx(1.0) + assert w[3] == pytest.approx(QUAT_WEIGHT) + + def test_fixed_params_excluded(self): + """Fixed params are not in free list, so not in weight vector.""" + from kindred_solver.preference import build_weight_vector + + params = ParamTable() + params.add("b/tx", 0.0, fixed=True) + params.add("b/ty", 0.0) + w = build_weight_vector(params) + assert len(w) == 1 + assert w[0] == pytest.approx(1.0) + + +class TestWeightedNewton: + """Weighted minimum-norm Newton solve.""" + + def test_well_constrained_same_result(self): + """Weighted and unweighted produce identical results for unique solution.""" + from kindred_solver.expr import Const + + # Fully determined system: x = 3, y = 7 + params1 = ParamTable() + x1 = params1.add("x", 1.0) + y1 = params1.add("y", 1.0) + r1 = [x1 - Const(3.0), y1 - Const(7.0)] + + params2 = ParamTable() + x2 = params2.add("x", 1.0) + y2 = params2.add("y", 1.0) + r2 = [x2 - Const(3.0), y2 - Const(7.0)] + + newton_solve(r1, params1) + newton_solve(r2, params2, weight_vector=np.array([1.0, 100.0])) + + assert params1.get_value("x") == pytest.approx( + params2.get_value("x"), abs=1e-10 + ) + assert params1.get_value("y") == pytest.approx( + params2.get_value("y"), abs=1e-10 + ) + + def test_underconstrained_prefers_low_weight(self): + """Under-constrained: weighted solve moves high-weight params less.""" + from kindred_solver.expr import Const + + # 1 equation, 2 unknowns: x + y = 10 (from x=0, y=0) + params_unw = ParamTable() + xu = params_unw.add("x", 0.0) + yu = params_unw.add("y", 0.0) + ru = [xu + yu - Const(10.0)] + + params_w = ParamTable() + xw = params_w.add("x", 0.0) + yw = params_w.add("y", 0.0) + rw = [xw + yw - Const(10.0)] + + # Unweighted: lstsq gives equal movement + newton_solve(ru, params_unw) + + # Weighted: y is 100x more expensive to move + newton_solve(rw, params_w, weight_vector=np.array([1.0, 100.0])) + + # Both should satisfy x + y = 10 + assert params_unw.get_value("x") + params_unw.get_value("y") == pytest.approx( + 10.0 + ) + assert params_w.get_value("x") + params_w.get_value("y") == pytest.approx(10.0) + + # Weighted solve should move y less than x + assert abs(params_w.get_value("y")) < abs(params_w.get_value("x"))