Files
solver/kindred_solver/preference.py
forbes-0023 c2ebcc3169 fix(solver): prevent orientation flips during interactive drag
Add half-space tracking for all compound constraints with branch
ambiguity: Planar, Revolute, Concentric, Cylindrical, Slider, Screw,
Universal, PointInPlane, and LineInPlane.  Previously only
DistancePointPoint, Parallel, Angle, and Perpendicular were tracked,
so the Newton-Raphson solver could converge to the wrong branch for
compound constraints — causing parts to drift through plane
constraints while honoring revolute joints.

Add quaternion continuity enforcement in drag_step(): after solving,
each non-dragged body's quaternion is checked against its pre-step
value and negated if in the opposite hemisphere (standard SLERP
short-arc correction).  This prevents the C++ validateNewPlacements()
from rejecting valid solutions as 'flipped orientation' due to the
quaternion double-cover ambiguity (q and -q encode the same rotation
but measure as ~340° apart).
2026-02-24 20:46:42 -06:00

668 lines
22 KiB
Python

"""Solution preference: half-space tracking and minimum-movement weighting.
Half-space tracking preserves the initial configuration branch across
Newton iterations. For constraints with multiple valid solutions
(e.g. distance can be satisfied on either side), we record which
"half-space" the initial state lives in and correct the solver step
if it crosses to the wrong branch.
Minimum-movement weighting scales the Newton/BFGS step so that
quaternion parameters (rotation) are penalised more than translation
parameters, yielding the physically-nearest solution.
"""
from __future__ import annotations
import math
from dataclasses import dataclass, field
from typing import Callable, List
import numpy as np
from .constraints import (
AngleConstraint,
ConcentricConstraint,
ConstraintBase,
CylindricalConstraint,
DistancePointPointConstraint,
LineInPlaneConstraint,
ParallelConstraint,
PerpendicularConstraint,
PlanarConstraint,
PointInPlaneConstraint,
RevoluteConstraint,
ScrewConstraint,
SliderConstraint,
UniversalConstraint,
)
from .geometry import cross3, dot3, marker_z_axis, point_plane_distance, sub3
from .params import ParamTable
@dataclass
class HalfSpace:
    """Record of the solution branch one branching constraint must stay in.

    Attributes:
        constraint_index: Index of the tracked constraint in ``ctx.constraints``.
        reference_sign: Sign of the indicator captured at setup (+1.0 or -1.0).
        indicator_fn: Maps a parameter environment to a signed value; its
            sign identifies the branch the current state is in.
        param_names: Names of the parameters a correction may flip.
        correction_fn: Optional callback invoked as
            ``correction_fn(params, indicator_value)`` when the indicator sign
            differs from ``reference_sign``. ``None`` means the branch is
            tracked but never corrected geometrically.
    """

    constraint_index: int
    reference_sign: float
    indicator_fn: Callable[[dict[str, float]], float]
    param_names: list[str] = field(default_factory=list)
    correction_fn: Callable[[ParamTable, float], None] | None = None
def compute_half_spaces(
    constraint_objs: list[ConstraintBase],
    constraint_indices: list[int],
    params: ParamTable,
) -> list[HalfSpace]:
    """Create a half-space tracker for every branching constraint.

    Each constraint is paired with the entry of ``constraint_indices`` at the
    same position and handed to ``_build_half_space`` together with a single
    snapshot of the current parameter values, which is where the reference
    signs come from. Constraints for which no tracker is built are omitted
    from the result.
    """
    env = params.get_env()
    # Lazily built in input order; positional indexing keeps the original
    # failure mode (IndexError) if the two lists disagree in length.
    candidates = (
        _build_half_space(obj, constraint_indices[position], env, params)
        for position, obj in enumerate(constraint_objs)
    )
    return [tracker for tracker in candidates if tracker is not None]
def apply_half_space_correction(
    params: ParamTable,
    half_spaces: list[HalfSpace],
) -> None:
    """Pull the parameters back onto the reference branch where needed.

    Intended to run as a post_step callback from newton_solve. Every
    half-space indicator is evaluated in order; when its sign disagrees with
    the recorded reference sign and a correction callback exists, the
    callback is invoked with the parameter table and the indicator value.
    """
    if not half_spaces:
        return

    env = params.get_env()
    for tracker in half_spaces:
        value = tracker.indicator_fn(env)

        # Values within 1e-14 of zero carry no reliable sign and are treated
        # as still being on the reference branch.
        if not abs(value) > 1e-14:
            continue
        if math.copysign(1.0, value) == tracker.reference_sign:
            continue
        if tracker.correction_fn is None:
            continue

        tracker.correction_fn(params, value)
        # The correction may have changed parameters, so later half-spaces
        # must be evaluated against a fresh snapshot.
        env = params.get_env()
def _build_half_space(
    obj: ConstraintBase,
    constraint_idx: int,
    env: dict[str, float],
    params: ParamTable,
) -> HalfSpace | None:
    """Dispatch to the half-space builder matching the constraint's type.

    Returns whatever the selected builder returns, or ``None`` when the
    constraint type is not one of the tracked kinds.
    """
    # Distance constraints are only tracked for a strictly positive target
    # distance; otherwise the object falls through to the remaining checks.
    if isinstance(obj, DistancePointPointConstraint) and obj.distance > 0:
        return _distance_half_space(obj, constraint_idx, env, params)

    # Builders that also take the parameter table, probed in this order.
    param_aware_builders = (
        (ParallelConstraint, _parallel_half_space),
        (AngleConstraint, _angle_half_space),
        (PerpendicularConstraint, _perpendicular_half_space),
        (PlanarConstraint, _planar_half_space),
        (RevoluteConstraint, _revolute_half_space),
        (ConcentricConstraint, _concentric_half_space),
        (PointInPlaneConstraint, _point_in_plane_half_space),
        (LineInPlaneConstraint, _line_in_plane_half_space),
    )
    for constraint_type, builder in param_aware_builders:
        if isinstance(obj, constraint_type):
            return builder(obj, constraint_idx, env, params)

    # These three share the plain axis-direction tracker.
    axis_types = (CylindricalConstraint, SliderConstraint, ScrewConstraint)
    if isinstance(obj, axis_types):
        return _axis_direction_half_space(obj, constraint_idx, env)

    if isinstance(obj, UniversalConstraint):
        return _universal_half_space(obj, constraint_idx, env)

    return None
def _distance_half_space(
    obj: DistancePointPointConstraint,
    constraint_idx: int,
    env: dict[str, float],
    params: ParamTable,
) -> HalfSpace | None:
    """Half-space for DistancePointPoint: track displacement direction.

    The indicator is the dot product of the current displacement
    ``p_i - p_j`` with the unit displacement direction captured at setup.
    If the solver flips to the opposite side, the correction translates the
    movable body so that the displacement component along the reference
    direction is negated (the displacement is reflected).

    Returns ``None`` when there is nothing to track: the two points are
    coincident at setup, or both bodies are grounded.
    """
    p_i, p_j = obj.world_points()

    # Reference displacement at setup.
    dx = p_i[0].eval(env) - p_j[0].eval(env)
    dy = p_i[1].eval(env) - p_j[1].eval(env)
    dz = p_i[2].eval(env) - p_j[2].eval(env)
    dist = math.sqrt(dx * dx + dy * dy + dz * dz)
    if dist < 1e-14:
        return None  # points coincident, no branch to track

    # body_j is moved by preference; body_i only when body_j is grounded.
    moving_body = obj.body_j if not obj.body_j.grounded else obj.body_i
    if moving_body.grounded:
        return None  # both grounded, nothing to correct

    # Reference unit direction.
    nx, ny, nz = dx / dist, dy / dist, dz / dist

    # Displacement expressions, built once and re-evaluated per call.
    disp_x, disp_y, disp_z = p_i[0] - p_j[0], p_i[1] - p_j[1], p_i[2] - p_j[2]

    def indicator(e: dict[str, float]) -> float:
        return disp_x.eval(e) * nx + disp_y.eval(e) * ny + disp_z.eval(e) * nz

    ref_sign = math.copysign(1.0, indicator(env))

    px_name = f"{moving_body.part_id}/tx"
    py_name = f"{moving_body.part_id}/ty"
    pz_name = f"{moving_body.part_id}/tz"
    axis_params = ((px_name, nx), (py_name, ny), (pz_name, nz))

    # Target: disp_new = disp - 2*proj*n, so that dot(disp_new, n) == -proj.
    # With disp = p_i - p_j, translating body_j by delta changes disp by
    # -delta, while translating body_i by delta changes it by +delta. Hence
    # body_j must move by +2*proj*n and body_i by -2*proj*n. The previous
    # signs were the other way round, which produced dot(disp, n) == 3*proj
    # and pushed the body further into the wrong half-space.
    # NOTE(review): assumes the world points translate 1:1 with the body's
    # tx/ty/tz parameters — confirm against world_points().
    step_sign = 1.0 if moving_body is obj.body_j else -1.0

    def correction(p: ParamTable, _val: float) -> None:
        e = p.get_env()
        # Component of the current displacement along the reference direction.
        proj = disp_x.eval(e) * nx + disp_y.eval(e) * ny + disp_z.eval(e) * nz
        shift = step_sign * 2.0 * proj
        for name, n_comp in axis_params:
            if not p.is_fixed(name):
                p.set_value(name, p.get_value(name) + shift * n_comp)

    return HalfSpace(
        constraint_index=constraint_idx,
        reference_sign=ref_sign,
        indicator_fn=indicator,
        param_names=[px_name, py_name, pz_name],
        correction_fn=correction,
    )
def _parallel_half_space(
    obj: ParallelConstraint,
    constraint_idx: int,
    env: dict[str, float],
    params: ParamTable,
) -> HalfSpace:
    """Half-space for Parallel: same-direction versus opposite-direction axes.

    The indicator is dot(z_i, z_j): positive when the two marker Z-axes point
    the same way, negative when they oppose each other. No correction
    callback is attached; the half-space only records which branch the
    configuration started in.
    """
    alignment = dot3(
        marker_z_axis(obj.body_i, obj.marker_i_quat),
        marker_z_axis(obj.body_j, obj.marker_j_quat),
    )

    def indicator(e: dict[str, float]) -> float:
        return alignment.eval(e)

    initial = indicator(env)
    # A dot product within 1e-14 of zero has no usable sign; default to +1.
    if abs(initial) > 1e-14:
        ref_sign = math.copysign(1.0, initial)
    else:
        ref_sign = 1.0

    return HalfSpace(
        constraint_index=constraint_idx,
        reference_sign=ref_sign,
        indicator_fn=indicator,
    )
def _planar_half_space(
    obj: PlanarConstraint,
    constraint_idx: int,
    env: dict[str, float],
    params: ParamTable,
) -> HalfSpace | None:
    """Half-space for Planar: track the plane side or the normal direction.

    A Planar constraint has parallel normals plus point-in-plane, and both
    parts have branch ambiguity. Which one is tracked depends on the setup
    state:

    * point off the plane (|signed distance| >= 1e-10): the indicator is the
      signed distance from marker_i to the plane of marker_j;
    * point on the plane: the indicator is dot(z_i, z_j), the normal
      alignment, so the part cannot flip through the plane unnoticed.

    The correction reflects the movable body's position through the plane.
    Returns ``None`` when both bodies are grounded.
    """
    # Signed distance of marker_i's origin from marker_j's plane.
    p_i = obj.body_i.world_point(*obj.marker_i_pos)
    p_j = obj.body_j.world_point(*obj.marker_j_pos)
    z_j = marker_z_axis(obj.body_j, obj.marker_j_quat)
    d_expr = point_plane_distance(p_i, p_j, z_j)
    d_val = d_expr.eval(env)

    # Normal alignment (same quantity the Parallel half-space tracks).
    z_i = marker_z_axis(obj.body_i, obj.marker_i_quat)
    dot_expr = dot3(z_i, z_j)
    dot_val = dot_expr.eval(env)

    if abs(d_val) < 1e-10:
        # Point is on the plane — track normal direction instead.
        tracked_expr = dot_expr
        ref_sign = math.copysign(1.0, dot_val) if abs(dot_val) > 1e-14 else 1.0
    else:
        # Point is off-plane — track which side.
        tracked_expr = d_expr
        ref_sign = math.copysign(1.0, d_val)

    def indicator(e: dict[str, float]) -> float:
        return tracked_expr.eval(e)

    # body_j is moved by preference; body_i only when body_j is grounded.
    moving_body = obj.body_j if not obj.body_j.grounded else obj.body_i
    if moving_body.grounded:
        return None

    axis_names = tuple(f"{moving_body.part_id}/{axis}" for axis in ("tx", "ty", "tz"))

    # Target: d_new == -d. With d = dot(p_i - p_j, n), translating body_j by
    # delta changes d by -dot(delta, n) and translating body_i changes it by
    # +dot(delta, n). Hence body_j moves by +2*d*n and body_i by -2*d*n. The
    # previous signs were inverted and yielded d_new == 3*d.
    # NOTE(review): assumes point_plane_distance(p, origin, n) is
    # dot(p - origin, n) and that world points translate 1:1 with tx/ty/tz —
    # confirm against the geometry helpers.
    step_sign = 1.0 if moving_body is obj.body_j else -1.0

    def correction(p: ParamTable, _val: float) -> None:
        # NOTE(review): when the normal-direction indicator is active this
        # still only translates the body; a translation cannot undo a normal
        # flip, so in that mode the correction merely mirrors any residual
        # offset.
        e = p.get_env()
        d_cur = d_expr.eval(e)
        normal = [component.eval(e) for component in z_j]
        n_len = math.sqrt(sum(c * c for c in normal))
        if n_len < 1e-15:
            return  # degenerate normal, no direction to move along
        shift = step_sign * 2.0 * d_cur
        for name, component in zip(axis_names, normal):
            if not p.is_fixed(name):
                p.set_value(name, p.get_value(name) + shift * (component / n_len))

    return HalfSpace(
        constraint_index=constraint_idx,
        reference_sign=ref_sign,
        indicator_fn=indicator,
        correction_fn=correction,
    )
def _revolute_half_space(
    obj: RevoluteConstraint,
    constraint_idx: int,
    env: dict[str, float],
    params: ParamTable,
) -> HalfSpace | None:
    """Half-space for Revolute: keep the hinge axis from inverting.

    A revolute joint combines coincident origins with parallel Z-axes, and
    the parallel part has the same direction ambiguity as Parallel. The
    indicator is dot(z_i, z_j); no correction callback is attached.
    """
    axis_i = marker_z_axis(obj.body_i, obj.marker_i_quat)
    axis_j = marker_z_axis(obj.body_j, obj.marker_j_quat)
    alignment = dot3(axis_i, axis_j)

    initial = alignment.eval(env)
    # A dot product within 1e-14 of zero has no usable sign; default to +1.
    ref_sign = 1.0
    if abs(initial) > 1e-14:
        ref_sign = math.copysign(1.0, initial)

    def indicator(e: dict[str, float]) -> float:
        return alignment.eval(e)

    return HalfSpace(
        constraint_index=constraint_idx,
        reference_sign=ref_sign,
        indicator_fn=indicator,
    )
def _concentric_half_space(
    obj: ConcentricConstraint,
    constraint_idx: int,
    env: dict[str, float],
    params: ParamTable,
) -> HalfSpace | None:
    """Half-space for Concentric: keep the shared axis direction.

    Concentric combines parallel axes with point-on-line; the parallel axes
    may flip direction. The indicator is dot(z_i, z_j) and no correction
    callback is attached.
    """
    alignment = dot3(
        marker_z_axis(obj.body_i, obj.marker_i_quat),
        marker_z_axis(obj.body_j, obj.marker_j_quat),
    )
    initial = alignment.eval(env)

    # A dot product within 1e-14 of zero has no usable sign; default to +1.
    if abs(initial) > 1e-14:
        ref_sign = math.copysign(1.0, initial)
    else:
        ref_sign = 1.0

    def indicator(e: dict[str, float]) -> float:
        return alignment.eval(e)

    return HalfSpace(
        constraint_index=constraint_idx,
        reference_sign=ref_sign,
        indicator_fn=indicator,
    )
def _point_in_plane_half_space(
    obj: PointInPlaneConstraint,
    constraint_idx: int,
    env: dict[str, float],
    params: ParamTable,
) -> HalfSpace | None:
    """Half-space for PointInPlane: track which side of the plane.

    The plane can be approached from either side; the indicator is the
    signed distance of marker_i's point from the plane through marker_j's
    point with marker_j's Z-axis as normal. When the sign flips, the
    correction reflects the movable body's position through the plane.

    Returns ``None`` when the point already lies on the plane at setup
    (|distance| < 1e-10) or when both bodies are grounded.
    """
    p_i = obj.body_i.world_point(*obj.marker_i_pos)
    p_j = obj.body_j.world_point(*obj.marker_j_pos)
    n_j = marker_z_axis(obj.body_j, obj.marker_j_quat)
    d_expr = point_plane_distance(p_i, p_j, n_j)

    d_val = d_expr.eval(env)
    if abs(d_val) < 1e-10:
        return None  # already on the plane, no branch to track
    ref_sign = math.copysign(1.0, d_val)

    # body_j is moved by preference; body_i only when body_j is grounded.
    moving_body = obj.body_j if not obj.body_j.grounded else obj.body_i
    if moving_body.grounded:
        return None

    axis_names = tuple(f"{moving_body.part_id}/{axis}" for axis in ("tx", "ty", "tz"))

    # Target: d_new == -d. With d = dot(p_i - p_j, n), translating body_j by
    # delta changes d by -dot(delta, n) and translating body_i changes it by
    # +dot(delta, n). Hence body_j moves by +2*d*n and body_i by -2*d*n. The
    # previous signs were inverted and yielded d_new == 3*d.
    # NOTE(review): assumes point_plane_distance(p, origin, n) is
    # dot(p - origin, n) and that world points translate 1:1 with tx/ty/tz —
    # confirm against the geometry helpers.
    step_sign = 1.0 if moving_body is obj.body_j else -1.0

    def correction(p: ParamTable, _val: float) -> None:
        e = p.get_env()
        d_cur = d_expr.eval(e)
        normal = [component.eval(e) for component in n_j]
        n_len = math.sqrt(sum(c * c for c in normal))
        if n_len < 1e-15:
            return  # degenerate normal, no direction to move along
        shift = step_sign * 2.0 * d_cur
        for name, component in zip(axis_names, normal):
            if not p.is_fixed(name):
                p.set_value(name, p.get_value(name) + shift * (component / n_len))

    return HalfSpace(
        constraint_index=constraint_idx,
        reference_sign=ref_sign,
        indicator_fn=lambda e: d_expr.eval(e),
        correction_fn=correction,
    )
def _line_in_plane_half_space(
    obj: LineInPlaneConstraint,
    constraint_idx: int,
    env: dict[str, float],
    params: ParamTable,
) -> HalfSpace | None:
    """Half-space for LineInPlane: track which side of the plane.

    Same plane-side ambiguity as PointInPlane: the indicator is the signed
    distance of marker_i's point from the plane through marker_j's point
    with marker_j's Z-axis as normal. When the sign flips, the correction
    reflects the movable body's position through the plane.

    Returns ``None`` when the point already lies on the plane at setup
    (|distance| < 1e-10) or when both bodies are grounded.
    """
    p_i = obj.body_i.world_point(*obj.marker_i_pos)
    p_j = obj.body_j.world_point(*obj.marker_j_pos)
    n_j = marker_z_axis(obj.body_j, obj.marker_j_quat)
    d_expr = point_plane_distance(p_i, p_j, n_j)

    d_val = d_expr.eval(env)
    if abs(d_val) < 1e-10:
        return None  # already on the plane, no branch to track
    ref_sign = math.copysign(1.0, d_val)

    # body_j is moved by preference; body_i only when body_j is grounded.
    moving_body = obj.body_j if not obj.body_j.grounded else obj.body_i
    if moving_body.grounded:
        return None

    axis_names = tuple(f"{moving_body.part_id}/{axis}" for axis in ("tx", "ty", "tz"))

    # Target: d_new == -d. With d = dot(p_i - p_j, n), translating body_j by
    # delta changes d by -dot(delta, n) and translating body_i changes it by
    # +dot(delta, n). Hence body_j moves by +2*d*n and body_i by -2*d*n. The
    # previous signs were inverted and yielded d_new == 3*d.
    # NOTE(review): assumes point_plane_distance(p, origin, n) is
    # dot(p - origin, n) and that world points translate 1:1 with tx/ty/tz —
    # confirm against the geometry helpers.
    step_sign = 1.0 if moving_body is obj.body_j else -1.0

    def correction(p: ParamTable, _val: float) -> None:
        e = p.get_env()
        d_cur = d_expr.eval(e)
        normal = [component.eval(e) for component in n_j]
        n_len = math.sqrt(sum(c * c for c in normal))
        if n_len < 1e-15:
            return  # degenerate normal, no direction to move along
        shift = step_sign * 2.0 * d_cur
        for name, component in zip(axis_names, normal):
            if not p.is_fixed(name):
                p.set_value(name, p.get_value(name) + shift * (component / n_len))

    return HalfSpace(
        constraint_index=constraint_idx,
        reference_sign=ref_sign,
        indicator_fn=lambda e: d_expr.eval(e),
        correction_fn=correction,
    )
def _axis_direction_half_space(
    obj,
    constraint_idx: int,
    env: dict[str, float],
) -> HalfSpace | None:
    """Half-space shared by constraints with parallel Z-axes.

    Used for Cylindrical, Slider and Screw. The indicator is dot(z_i, z_j),
    whose sign tells whether the two marker axes point the same way; no
    correction callback is attached.
    """
    axis_i = marker_z_axis(obj.body_i, obj.marker_i_quat)
    axis_j = marker_z_axis(obj.body_j, obj.marker_j_quat)
    alignment = dot3(axis_i, axis_j)

    initial = alignment.eval(env)
    # A dot product within 1e-14 of zero has no usable sign; default to +1.
    has_sign = abs(initial) > 1e-14
    ref_sign = math.copysign(1.0, initial) if has_sign else 1.0

    def indicator(e: dict[str, float]) -> float:
        return alignment.eval(e)

    return HalfSpace(
        constraint_index=constraint_idx,
        reference_sign=ref_sign,
        indicator_fn=indicator,
    )
def _universal_half_space(
    obj: UniversalConstraint,
    constraint_idx: int,
    env: dict[str, float],
) -> HalfSpace | None:
    """Half-space for Universal: which side of perpendicular the axes are on.

    Universal constrains dot(z_i, z_j) to zero, so the dot product cannot
    separate the two branches. The indicator is the component of
    cross(z_i, z_j) with the largest magnitude at setup time. Returns
    ``None`` when even that component is below 1e-14.
    """
    axis_i = marker_z_axis(obj.body_i, obj.marker_i_quat)
    axis_j = marker_z_axis(obj.body_j, obj.marker_j_quat)
    cx, cy, cz = cross3(axis_i, axis_j)

    # (expression, setup value) per component; max() keeps the first of any
    # equally large candidates.
    candidates = [
        (cx, cx.eval(env)),
        (cy, cy.eval(env)),
        (cz, cz.eval(env)),
    ]
    dominant_expr, dominant_val = max(candidates, key=lambda c: abs(c[1]))
    if abs(dominant_val) < 1e-14:
        return None

    def indicator(e: dict[str, float]) -> float:
        return dominant_expr.eval(e)

    return HalfSpace(
        constraint_index=constraint_idx,
        reference_sign=math.copysign(1.0, dominant_val),
        indicator_fn=indicator,
    )
# ============================================================================
# Minimum-movement weighting
# ============================================================================

# Weight applied to quaternion parameters: with (180/pi)^2, a 1-radian
# rotation is penalised as much as a (180/pi)-unit translation, so the
# weighted minimum-norm step prefers translating over rotating for the same
# residual reduction.
QUAT_WEIGHT = (180.0 / math.pi) ** 2  # ~3283


def build_weight_vector(params: ParamTable) -> np.ndarray:
    """Return the diagonal weights for the free parameters.

    The result is a 1-D float array with one entry per name returned by
    ``params.free_names()``, in the same order: ``QUAT_WEIGHT`` for names
    ending in a quaternion suffix (``/qw``, ``/qx``, ``/qy``, ``/qz``) and
    1.0 for everything else.
    """
    quat_suffixes = ("/qw", "/qx", "/qy", "/qz")
    weights = [
        QUAT_WEIGHT if name.endswith(quat_suffixes) else 1.0
        for name in params.free_names()
    ]
    return np.array(weights, dtype=float)
def _angle_half_space(
    obj: AngleConstraint,
    constraint_idx: int,
    env: dict[str, float],
    params: ParamTable,
) -> HalfSpace | None:
    """Half-space for Angle: track the sign of sin(angle) via the cross product.

    The constraint fixes the dot product (cos of the angle) while the sine
    can take either sign. The indicator is the component of cross(z_i, z_j)
    with the largest magnitude at setup time. Returns ``None`` for target
    angles within 1e-14 of 0 or pi, and when the dominant cross-product
    component is below 1e-14.
    """
    # 0 or 180 degrees — no branch ambiguity.
    if abs(obj.angle) < 1e-14:
        return None
    if abs(obj.angle - math.pi) < 1e-14:
        return None

    axis_i = marker_z_axis(obj.body_i, obj.marker_i_quat)
    axis_j = marker_z_axis(obj.body_j, obj.marker_j_quat)
    cx, cy, cz = cross3(axis_i, axis_j)

    # (expression, setup value) per component; max() keeps the first of any
    # equally large candidates.
    candidates = [
        (cx, cx.eval(env)),
        (cy, cy.eval(env)),
        (cz, cz.eval(env)),
    ]
    dominant_expr, dominant_val = max(candidates, key=lambda c: abs(c[1]))
    if abs(dominant_val) < 1e-14:
        return None

    return HalfSpace(
        constraint_index=constraint_idx,
        reference_sign=math.copysign(1.0, dominant_val),
        indicator_fn=lambda e: dominant_expr.eval(e),
    )
def _perpendicular_half_space(
    obj: PerpendicularConstraint,
    constraint_idx: int,
    env: dict[str, float],
    params: ParamTable,
) -> HalfSpace | None:
    """Half-space for Perpendicular: track which quadrant the axes are in.

    The dot product is constrained to zero, so the branch is identified by
    the sign of the cross product instead. The indicator is the component of
    cross(z_i, z_j) with the largest magnitude at setup time. Returns
    ``None`` when that component is below 1e-14.
    """
    axis_i = marker_z_axis(obj.body_i, obj.marker_i_quat)
    axis_j = marker_z_axis(obj.body_j, obj.marker_j_quat)
    cx, cy, cz = cross3(axis_i, axis_j)

    # (expression, setup value) per component; max() keeps the first of any
    # equally large candidates.
    candidates = [
        (cx, cx.eval(env)),
        (cy, cy.eval(env)),
        (cz, cz.eval(env)),
    ]
    dominant_expr, dominant_val = max(candidates, key=lambda c: abs(c[1]))
    if abs(dominant_val) < 1e-14:
        return None

    return HalfSpace(
        constraint_index=constraint_idx,
        reference_sign=math.copysign(1.0, dominant_val),
        indicator_fn=lambda e: dominant_expr.eval(e),
    )