Files
solver/kindred_solver/preference.py
forbes-0023 b4b8724ff1 feat(solver): diagnostics, half-space preference, and weight vectors (phase 4)
- Add per-entity DOF analysis via Jacobian SVD (diagnostics.py)
- Add overconstrained detection: redundant vs conflicting constraints
- Add half-space tracking to preserve configuration branch (preference.py)
- Add minimum-movement weighting for least-squares solve
- Extend BFGS fallback with weight vector and quaternion renormalization
- Add snapshot/restore and env accessor to ParamTable
- Fix DistancePointPointConstraint sign for half-space tracking
2026-02-20 23:32:45 -06:00

326 lines
10 KiB
Python

"""Solution preference: half-space tracking and minimum-movement weighting.

Half-space tracking preserves the initial configuration branch across
Newton iterations. For constraints with multiple valid solutions
(e.g. a point-point distance can be satisfied on either side), we record
which "half-space" the initial state lives in and, where a correction is
available, correct the solver step if it crosses to the wrong branch.
Only DistancePointPoint trackers carry a correction; the Parallel, Angle
and Perpendicular trackers evaluate their indicator but never modify
parameters.

Minimum-movement weighting builds a per-parameter weight vector in which
quaternion parameters (rotation) are weighted more heavily than
translation parameters, so that a weighted Newton/BFGS step prefers the
physically-nearest solution.
"""
from __future__ import annotations
import math
from dataclasses import dataclass, field
from typing import Callable, List
import numpy as np
from .constraints import (
AngleConstraint,
ConstraintBase,
DistancePointPointConstraint,
ParallelConstraint,
PerpendicularConstraint,
)
from .geometry import cross3, dot3, marker_z_axis
from .params import ParamTable
@dataclass
class HalfSpace:
    """Branch tracker for one constraint that admits several solutions.

    Attributes:
        constraint_index: Position of the tracked constraint in
            ``ctx.constraints``.
        reference_sign: ``+1.0`` or ``-1.0``; the sign of the indicator
            captured when the tracker was built.
        indicator_fn: Maps a parameter environment (name -> value) to a
            signed scalar whose sign identifies the current branch.
        param_names: Names of the parameters a correction may adjust
            (empty when the tracker has no correction).
        correction_fn: Optional callback ``(params, indicator_value)`` that
            pushes the state back onto the reference branch; ``None`` means
            the branch is tracked but never corrected.
    """

    constraint_index: int
    reference_sign: float
    indicator_fn: Callable[[dict[str, float]], float]
    param_names: list[str] = field(default_factory=list)
    correction_fn: Callable[[ParamTable, float], None] | None = None
def compute_half_spaces(
    constraint_objs: list[ConstraintBase],
    constraint_indices: list[int],
    params: ParamTable,
) -> list[HalfSpace]:
    """Create a tracker for every constraint that has a branch to preserve.

    ``constraint_indices[k]`` is the index recorded for
    ``constraint_objs[k]``.  Every builder receives the same snapshot of
    the current parameter values, and that snapshot fixes each tracker's
    reference sign.  Constraints for which no tracker is built are skipped.
    """
    snapshot = params.get_env()
    candidates = (
        _build_half_space(obj, constraint_indices[pos], snapshot, params)
        for pos, obj in enumerate(constraint_objs)
    )
    return [tracker for tracker in candidates if tracker is not None]
def apply_half_space_correction(
    params: ParamTable,
    half_spaces: list[HalfSpace],
) -> None:
    """Push every tracked constraint back onto its reference branch.

    Meant to be used as the ``post_step`` hook of ``newton_solve``.  Each
    tracker's indicator is evaluated on the current parameters.  When its
    sign differs from the reference sign and the tracker has a
    ``correction_fn``, that callback is invoked with the indicator value.
    The environment is then re-read so later trackers see the corrected
    parameters.
    """
    if not half_spaces:
        return
    env = params.get_env()
    for tracker in half_spaces:
        value = tracker.indicator_fn(env)
        # A value this close to zero carries no reliable sign; treat it as
        # lying on the reference branch.
        if abs(value) > 1e-14:
            observed = math.copysign(1.0, value)
        else:
            observed = tracker.reference_sign
        if observed == tracker.reference_sign or tracker.correction_fn is None:
            continue
        tracker.correction_fn(params, value)
        env = params.get_env()
def _build_half_space(
    obj: ConstraintBase,
    constraint_idx: int,
    env: dict[str, float],
    params: ParamTable,
) -> HalfSpace | None:
    """Dispatch to the half-space builder that matches ``obj``'s type.

    Types are checked in a fixed order (distance, parallel, angle,
    perpendicular) and the first match wins.  A distance constraint is only
    tracked when its target distance is strictly positive.  Returns ``None``
    for any other constraint; otherwise returns whatever the chosen builder
    returns, which may itself be ``None``.
    """
    if isinstance(obj, DistancePointPointConstraint) and obj.distance > 0:
        builder = _distance_half_space
    elif isinstance(obj, ParallelConstraint):
        builder = _parallel_half_space
    elif isinstance(obj, AngleConstraint):
        builder = _angle_half_space
    elif isinstance(obj, PerpendicularConstraint):
        builder = _perpendicular_half_space
    else:
        return None
    return builder(obj, constraint_idx, env, params)
def _distance_half_space(
    obj: DistancePointPointConstraint,
    constraint_idx: int,
    env: dict[str, float],
    params: ParamTable,
) -> HalfSpace | None:
    """Half-space for DistancePointPoint: track the displacement direction.

    Let ``d = p_i - p_j`` and let ``n`` be the unit direction of ``d`` at
    setup.  The indicator is ``dot(d, n)``, which is positive at setup, so
    the reference sign is +1.  If a solver step flips ``d`` to the other
    side, the correction reflects ``d`` across the plane normal to ``n``
    (``d -> d - 2*dot(d, n)*n``) by translating one body: ``body_j`` if it
    is not grounded, otherwise ``body_i``.  Translation components that are
    fixed in the ParamTable are left untouched.

    Returns ``None`` when the two points coincide at setup (no direction to
    track) or when both bodies are grounded (nothing to move).
    """
    p_i, p_j = obj.world_points()

    # Reference displacement direction at setup.
    dx = p_i[0].eval(env) - p_j[0].eval(env)
    dy = p_i[1].eval(env) - p_j[1].eval(env)
    dz = p_i[2].eval(env) - p_j[2].eval(env)
    dist = math.sqrt(dx * dx + dy * dy + dz * dz)
    if dist < 1e-14:
        return None  # points coincident, no branch to track
    nx, ny, nz = dx / dist, dy / dist, dz / dist

    # Symbolic displacement, re-evaluated against later environments.
    disp_x = p_i[0] - p_j[0]
    disp_y = p_i[1] - p_j[1]
    disp_z = p_i[2] - p_j[2]

    def indicator(e: dict[str, float]) -> float:
        return disp_x.eval(e) * nx + disp_y.eval(e) * ny + disp_z.eval(e) * nz

    ref_sign = math.copysign(1.0, indicator(env))

    moving_body = obj.body_j if not obj.body_j.grounded else obj.body_i
    if moving_body.grounded:
        return None  # both grounded, nothing to correct

    px_name = f"{moving_body.part_id}/tx"
    py_name = f"{moving_body.part_id}/ty"
    pz_name = f"{moving_body.part_id}/tz"

    # The reflection must change d by -2*proj*n.  Assuming tx/ty/tz enter
    # the world point additively, translating body_j by delta changes
    # d = p_i - p_j by -delta, and translating body_i by delta changes it
    # by +delta.  So body_j moves by +2*proj*n and body_i by -2*proj*n.
    sign_flip = 1.0 if moving_body is obj.body_j else -1.0

    def correction(p: ParamTable, _val: float) -> None:
        # Component of the current displacement along the reference direction
        # (negative when the solver has crossed to the other side).
        proj = indicator(p.get_env())
        step = sign_flip * 2.0 * proj
        for name, n_k in ((px_name, nx), (py_name, ny), (pz_name, nz)):
            if not p.is_fixed(name):
                p.set_value(name, p.get_value(name) + step * n_k)

    return HalfSpace(
        constraint_index=constraint_idx,
        reference_sign=ref_sign,
        indicator_fn=indicator,
        param_names=[px_name, py_name, pz_name],
        correction_fn=correction,
    )
def _parallel_half_space(
    obj: ParallelConstraint,
    constraint_idx: int,
    env: dict[str, float],
    params: ParamTable,
) -> HalfSpace:
    """Half-space for Parallel: same direction versus opposite direction.

    The indicator is ``dot(z_i, z_j)`` of the two marker Z axes.  It is
    positive when they point the same way and negative when they are
    anti-parallel.  A reference value within 1e-14 of zero defaults to the
    +1 branch.

    No ``correction_fn`` is attached, so this tracker only observes the
    branch.  Keeping the axes on the captured branch is presumably left to
    the Newton solver's cross-product residual.
    """
    axis_i = marker_z_axis(obj.body_i, obj.marker_i_quat)
    axis_j = marker_z_axis(obj.body_j, obj.marker_j_quat)
    alignment = dot3(axis_i, axis_j)

    def indicator(e: dict[str, float]) -> float:
        return alignment.eval(e)

    initial = indicator(env)
    if abs(initial) > 1e-14:
        ref_sign = math.copysign(1.0, initial)
    else:
        ref_sign = 1.0

    return HalfSpace(
        constraint_index=constraint_idx,
        reference_sign=ref_sign,
        indicator_fn=indicator,
    )
# ============================================================================
# Minimum-movement weighting
# ============================================================================
# Weight assigned by build_weight_vector to every free quaternion component
# (names ending in /qw, /qx, /qy, /qz); all other free parameters get 1.0.
# Intent: a 1-radian rotation is penalised as much as a (180/pi)-unit
# translation, so the weighted minimum-norm step prefers translating over
# rotating for the same residual reduction.
# NOTE(review): the exact trade-off depends on how the solver consumes these
# weights (as w or as sqrt(w)) and on a small rotation by theta changing the
# quaternion components by roughly theta/2. Confirm the stated equivalence
# against the solver.
QUAT_WEIGHT = (180.0 / math.pi) ** 2  # ~3283
def build_weight_vector(params: ParamTable) -> np.ndarray:
    """Return one weight per free parameter for the minimum-movement solve.

    Quaternion components (names ending in ``/qw``, ``/qx``, ``/qy`` or
    ``/qz``) receive ``QUAT_WEIGHT``; every other free parameter receives
    1.0.  The result is a 1-D float array aligned with
    ``params.free_names()``, i.e. of length ``params.n_free()``.
    """
    rotation_suffixes = ("/qw", "/qx", "/qy", "/qz")
    weights = [
        QUAT_WEIGHT if name.endswith(rotation_suffixes) else 1.0
        for name in params.free_names()
    ]
    return np.asarray(weights, dtype=float)
def _angle_half_space(
    obj: AngleConstraint,
    constraint_idx: int,
    env: dict[str, float],
    params: ParamTable,
) -> HalfSpace | None:
    """Half-space for Angle: track the sign of the axes' cross product.

    The constraint pins ``dot(z_i, z_j)`` (cos of the angle), which leaves
    the sign of sin ambiguous.  The tracker uses, as its indicator, the
    component of ``cross(z_i, z_j)`` with the largest magnitude at setup
    (the earliest component wins on ties).

    Returns ``None`` when the target angle is within 1e-14 of 0 or pi (no
    ambiguity; the comparison with ``math.pi`` implies radians) or when the
    dominant cross-product component is within 1e-14 of zero at setup.  No
    ``correction_fn`` is attached.
    """
    if abs(obj.angle) < 1e-14 or abs(obj.angle - math.pi) < 1e-14:
        return None

    axis_i = marker_z_axis(obj.body_i, obj.marker_i_quat)
    axis_j = marker_z_axis(obj.body_j, obj.marker_j_quat)
    cx, cy, cz = cross3(axis_i, axis_j)
    exprs = (cx, cy, cz)
    values = [expr.eval(env) for expr in exprs]

    dominant = max(range(3), key=lambda idx: abs(values[idx]))
    if abs(values[dominant]) < 1e-14:
        return None

    chosen = exprs[dominant]

    def indicator(e: dict[str, float]) -> float:
        return chosen.eval(e)

    return HalfSpace(
        constraint_index=constraint_idx,
        reference_sign=math.copysign(1.0, values[dominant]),
        indicator_fn=indicator,
    )
def _perpendicular_half_space(
    obj: PerpendicularConstraint,
    constraint_idx: int,
    env: dict[str, float],
    params: ParamTable,
) -> HalfSpace | None:
    """Half-space for Perpendicular: track which side of perpendicular.

    The constraint drives ``dot(z_i, z_j)`` to zero.  The sign of the
    dominant component of ``cross(z_i, z_j)`` at setup distinguishes the
    two sides.  Returns ``None`` when that component is within 1e-14 of
    zero.  No ``correction_fn`` is attached.
    """
    axis_i = marker_z_axis(obj.body_i, obj.marker_i_quat)
    axis_j = marker_z_axis(obj.body_j, obj.marker_j_quat)
    cx, cy, cz = cross3(axis_i, axis_j)

    # Scan for the largest-magnitude component; a strict ">" keeps the
    # earliest component on ties.
    best_expr = None
    best_val = None
    for expr, val in ((cx, cx.eval(env)), (cy, cy.eval(env)), (cz, cz.eval(env))):
        if best_expr is None or abs(val) > abs(best_val):
            best_expr, best_val = expr, val

    if abs(best_val) < 1e-14:
        return None

    def indicator(e: dict[str, float]) -> float:
        return best_expr.eval(e)

    return HalfSpace(
        constraint_index=constraint_idx,
        reference_sign=math.copysign(1.0, best_val),
        indicator_fn=indicator,
    )