- Add per-entity DOF analysis via Jacobian SVD (diagnostics.py) - Add overconstrained detection: redundant vs conflicting constraints - Add half-space tracking to preserve configuration branch (preference.py) - Add minimum-movement weighting for least-squares solve - Extend BFGS fallback with weight vector and quaternion renormalization - Add snapshot/restore and env accessor to ParamTable - Fix DistancePointPointConstraint sign for half-space tracking
326 lines
10 KiB
Python
326 lines
10 KiB
Python
"""Solution preference: half-space tracking and minimum-movement weighting.
|
|
|
|
Half-space tracking preserves the initial configuration branch across
Newton iterations.  For constraints with multiple valid solutions
(e.g. a distance can be satisfied on either side), we record which
"half-space" the initial state lives in.  Distance constraints are
corrected when a solver step crosses to the wrong branch; parallel,
angle and perpendicular constraints are tracked without a correction.
|
|
|
|
Minimum-movement weighting scales the Newton/BFGS step so that
|
|
quaternion parameters (rotation) are penalised more than translation
|
|
parameters, yielding the physically-nearest solution.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import math
|
|
from dataclasses import dataclass, field
|
|
from typing import Callable, List
|
|
|
|
import numpy as np
|
|
|
|
from .constraints import (
|
|
AngleConstraint,
|
|
ConstraintBase,
|
|
DistancePointPointConstraint,
|
|
ParallelConstraint,
|
|
PerpendicularConstraint,
|
|
)
|
|
from .geometry import cross3, dot3, marker_z_axis
|
|
from .params import ParamTable
|
|
|
|
|
|
@dataclass
class HalfSpace:
    """Record of the solution branch a branching constraint should stay on.

    Attributes:
        constraint_index: Index of the constraint in ``ctx.constraints``.
        reference_sign: ``+1.0`` or ``-1.0``, captured when the tracker is
            set up.
        indicator_fn: Maps a parameter environment to a signed value whose
            sign identifies the current branch.
        param_names: Names of the parameters a correction may flip; empty
            by default.
        correction_fn: Optional callback taking the parameter table and the
            current indicator value; ``None`` means the branch is tracked
            but never corrected.
    """

    constraint_index: int
    reference_sign: float
    indicator_fn: Callable[[dict[str, float]], float]
    # default_factory gives every tracker its own list instance.
    param_names: list[str] = field(default_factory=list)
    correction_fn: Callable[[ParamTable, float], None] | None = None
|
|
|
|
|
|
def compute_half_spaces(
    constraint_objs: list[ConstraintBase],
    constraint_indices: list[int],
    params: ParamTable,
) -> list[HalfSpace]:
    """Collect a half-space tracker for every branching constraint.

    A single environment snapshot is taken from ``params`` up front, so the
    reference sign of every tracker is captured against the same parameter
    values.

    Args:
        constraint_objs: Constraint objects to inspect.
        constraint_indices: For each position in ``constraint_objs``, the
            constraint index to store on the resulting tracker.
        params: Parameter table supplying the current values.

    Returns:
        The trackers that were built, in constraint order.  Constraints
        for which ``_build_half_space`` returns ``None`` are skipped.
    """
    snapshot = params.get_env()
    # Indices are looked up by position, so a too-short index list raises
    # IndexError instead of silently dropping constraints.
    built = (
        _build_half_space(constraint, constraint_indices[pos], snapshot, params)
        for pos, constraint in enumerate(constraint_objs)
    )
    return [tracker for tracker in built if tracker is not None]
|
|
|
|
|
|
def apply_half_space_correction(
    params: ParamTable,
    half_spaces: list[HalfSpace],
) -> None:
    """Push the parameters back onto their reference branches where needed.

    Every tracker's indicator is evaluated against the current parameter
    values.  When its sign disagrees with the recorded reference sign and
    the tracker has a correction callback, the callback is invoked with
    the parameter table and the offending indicator value.

    Documented by the original author as the ``post_step`` callback of
    ``newton_solve``.

    Args:
        params: Parameter table; mutated only through the trackers'
            correction callbacks.
        half_spaces: Trackers to check.  An empty list is a no-op.
    """
    if not half_spaces:
        return

    env = params.get_env()
    for tracker in half_spaces:
        value = tracker.indicator_fn(env)

        # An indicator that is numerically zero carries no sign
        # information, so it is treated as still on the reference side.
        if abs(value) > 1e-14:
            sign = math.copysign(1.0, value)
        else:
            sign = tracker.reference_sign

        if sign == tracker.reference_sign or tracker.correction_fn is None:
            continue

        tracker.correction_fn(params, value)
        # The correction changed parameter values; later trackers must be
        # evaluated against the updated environment.
        env = params.get_env()
|
|
|
|
|
|
def _build_half_space(
    obj: ConstraintBase,
    constraint_idx: int,
    env: dict[str, float],
    params: ParamTable,
) -> HalfSpace | None:
    """Dispatch to the half-space builder matching the constraint's type.

    Args:
        obj: Constraint to inspect.
        constraint_idx: Index to store on the resulting tracker.
        env: Parameter environment used to capture the reference sign.
        params: Parameter table, forwarded to the builder.

    Returns:
        The tracker produced by the matching builder, or ``None`` when the
        constraint is not one of the branching kinds handled here (or the
        builder itself declines to build one).
    """
    # The type checks run in a fixed order; a distance constraint only
    # qualifies when its target distance is strictly positive.
    if isinstance(obj, DistancePointPointConstraint) and obj.distance > 0:
        builder = _distance_half_space
    elif isinstance(obj, ParallelConstraint):
        builder = _parallel_half_space
    elif isinstance(obj, AngleConstraint):
        builder = _angle_half_space
    elif isinstance(obj, PerpendicularConstraint):
        builder = _perpendicular_half_space
    else:
        return None

    return builder(obj, constraint_idx, env, params)
|
|
|
|
|
|
def _distance_half_space(
    obj: DistancePointPointConstraint,
    constraint_idx: int,
    env: dict[str, float],
    params: ParamTable,
) -> HalfSpace | None:
    """Half-space for DistancePointPoint: track displacement direction.

    The indicator is the dot product of the current displacement
    ``p_i - p_j`` with the unit displacement direction captured at setup.
    If the solver flips the displacement to the opposite side, the
    correction mirrors it back by translating the moving body along the
    reference direction.

    Args:
        obj: Distance constraint providing ``world_points()``, ``body_i``
            and ``body_j``.
        constraint_idx: Index to store on the resulting tracker.
        env: Parameter environment at setup time.
        params: Parameter table (unused here; kept for a uniform builder
            signature).

    Returns:
        A tracker with a correction callback, or ``None`` when the two
        points coincide at setup or both bodies are grounded.
    """
    p_i, p_j = obj.world_points()

    # Evaluate the reference displacement p_i - p_j.
    dx = p_i[0].eval(env) - p_j[0].eval(env)
    dy = p_i[1].eval(env) - p_j[1].eval(env)
    dz = p_i[2].eval(env) - p_j[2].eval(env)
    dist = math.sqrt(dx * dx + dy * dy + dz * dz)

    if dist < 1e-14:
        return None  # points coincident, no branch to track

    # Reference unit direction
    nx, ny, nz = dx / dist, dy / dist, dz / dist

    # Symbolic displacement, shared by the indicator and the correction.
    disp_x, disp_y, disp_z = p_i[0] - p_j[0], p_i[1] - p_j[1], p_i[2] - p_j[2]

    def indicator(e: dict[str, float]) -> float:
        return disp_x.eval(e) * nx + disp_y.eval(e) * ny + disp_z.eval(e) * nz

    ref_sign = math.copysign(1.0, indicator(env))

    # Correct by moving body_j, or body_i if body_j is grounded.
    moving_body = obj.body_j if not obj.body_j.grounded else obj.body_i
    if moving_body.grounded:
        return None  # both grounded, nothing to correct

    px_name = f"{moving_body.part_id}/tx"
    py_name = f"{moving_body.part_id}/ty"
    pz_name = f"{moving_body.part_id}/tz"

    # Mirroring the displacement means disp -> disp - 2*proj*n.  Because
    # disp = p_i - p_j, translating body_i by d changes disp by +d while
    # translating body_j by d changes it by -d.  Hence body_i must move by
    # -2*proj*n and body_j by +2*proj*n.  (The previous signs were the
    # other way round, which tripled the projection instead of negating
    # it.)  Assumes each world point moves one-to-one with its body's
    # tx/ty/tz parameters.
    shift_sign = 1.0 if moving_body is obj.body_j else -1.0

    def correction(p: ParamTable, _val: float) -> None:
        # Signed length of the current displacement along the reference
        # direction; negative once the solver has crossed the branch.
        proj = indicator(p.get_env())
        step = shift_sign * 2.0 * proj
        for name, component in ((px_name, nx), (py_name, ny), (pz_name, nz)):
            # Fixed coordinates are left alone, so the reflection may be
            # only partial when some translation parameters are locked.
            if not p.is_fixed(name):
                p.set_value(name, p.get_value(name) + step * component)

    return HalfSpace(
        constraint_index=constraint_idx,
        reference_sign=ref_sign,
        indicator_fn=indicator,
        param_names=[px_name, py_name, pz_name],
        correction_fn=correction,
    )
|
|
|
|
|
|
def _parallel_half_space(
    obj: ParallelConstraint,
    constraint_idx: int,
    env: dict[str, float],
    params: ParamTable,
) -> HalfSpace:
    """Half-space for Parallel: same-direction vs opposite-direction axes.

    The indicator is ``dot(z_i, z_j)`` of the two marker z-axes: positive
    when they point the same way, negative when they are opposed.

    Args:
        obj: Parallel constraint providing both bodies and marker
            quaternions.
        constraint_idx: Index to store on the resulting tracker.
        env: Parameter environment at setup time.
        params: Parameter table (unused here; kept for a uniform builder
            signature).

    Returns:
        A tracker without a correction callback.
    """
    axis_i = marker_z_axis(obj.body_i, obj.marker_i_quat)
    axis_j = marker_z_axis(obj.body_j, obj.marker_j_quat)
    alignment = dot3(axis_i, axis_j)

    # A numerically zero dot product has no usable sign; default to the
    # same-direction branch in that case.
    initial = alignment.eval(env)
    if abs(initial) > 1e-14:
        ref_sign = math.copysign(1.0, initial)
    else:
        ref_sign = 1.0

    # Original author's note: no geometric correction is applied -- the
    # Newton solver is expected to handle this via the cross-product
    # residual, so this tracker only detects and reports branch flips.
    return HalfSpace(
        constraint_index=constraint_idx,
        reference_sign=ref_sign,
        indicator_fn=alignment.eval,
    )
|
|
|
|
|
|
# ============================================================================
# Minimum-movement weighting
# ============================================================================

# Weight for quaternion parameters: the square of the number of degrees in
# one radian.  With it, a 1-radian rotation is penalised as much as a
# (180/pi)-unit translation, so the weighted minimum-norm step prefers
# translating over rotating for the same residual reduction.
QUAT_WEIGHT = (180.0 / math.pi) ** 2  # ~3283


def build_weight_vector(params: ParamTable) -> np.ndarray:
    """Return the diagonal weights for the free parameters.

    Parameters whose name ends in ``/qw``, ``/qx``, ``/qy`` or ``/qz`` get
    ``QUAT_WEIGHT``; every other parameter gets ``1.0``.

    Args:
        params: Parameter table; only ``free_names()`` is consulted.

    Returns:
        A 1-D float array with one entry per name returned by
        ``params.free_names()``, in the same order.
    """
    quat_suffixes = ("/qw", "/qx", "/qy", "/qz")
    weights = [
        QUAT_WEIGHT if name.endswith(quat_suffixes) else 1.0
        for name in params.free_names()
    ]
    # An explicit dtype keeps the empty case a float array of shape (0,).
    return np.array(weights, dtype=float)
|
|
|
|
|
|
def _angle_half_space(
    obj: AngleConstraint,
    constraint_idx: int,
    env: dict[str, float],
    params: ParamTable,
) -> HalfSpace | None:
    """Half-space for Angle: track the sign of sin(angle) via the cross product.

    The constraint pins the dot product of the two marker z-axes to
    ``cos(angle)``, which leaves the sign of the sine free.  The tracker
    follows the signed value of whichever component of ``z_i x z_j`` has
    the largest magnitude at setup time.

    Args:
        obj: Angle constraint providing ``angle`` (radians), both bodies
            and marker quaternions.
        constraint_idx: Index to store on the resulting tracker.
        env: Parameter environment at setup time.
        params: Parameter table (unused here; kept for a uniform builder
            signature).

    Returns:
        A tracker without a correction callback, or ``None`` when the
        target angle is 0 or pi (no branch ambiguity) or the cross product
        vanishes at setup.
    """
    target = obj.angle
    if abs(target) < 1e-14 or abs(target - math.pi) < 1e-14:
        return None  # 0 or 180 degrees -- no branch ambiguity

    axis_i = marker_z_axis(obj.body_i, obj.marker_i_quat)
    axis_j = marker_z_axis(obj.body_j, obj.marker_j_quat)
    cx, cy, cz = cross3(axis_i, axis_j)

    # Pair every cross-product component with its value at setup; max()
    # keeps the first entry on ties, so x wins over y, and y over z.
    candidates = [(expr, expr.eval(env)) for expr in (cx, cy, cz)]
    dominant_expr, dominant_val = max(candidates, key=lambda pair: abs(pair[1]))

    if abs(dominant_val) < 1e-14:
        return None  # axes (anti)parallel at setup: no sign to record

    return HalfSpace(
        constraint_index=constraint_idx,
        reference_sign=math.copysign(1.0, dominant_val),
        indicator_fn=dominant_expr.eval,
    )
|
|
|
|
|
|
def _perpendicular_half_space(
    obj: PerpendicularConstraint,
    constraint_idx: int,
    env: dict[str, float],
    params: ParamTable,
) -> HalfSpace | None:
    """Half-space for Perpendicular: track which side of perpendicular.

    The constraint pins the dot product of the two marker z-axes to zero;
    the sign of the cross product tells the two perpendicular
    configurations apart.  The tracker follows the signed value of
    whichever component of ``z_i x z_j`` has the largest magnitude at
    setup time.

    Args:
        obj: Perpendicular constraint providing both bodies and marker
            quaternions.
        constraint_idx: Index to store on the resulting tracker.
        env: Parameter environment at setup time.
        params: Parameter table (unused here; kept for a uniform builder
            signature).

    Returns:
        A tracker without a correction callback, or ``None`` when the
        cross product vanishes at setup.
    """
    axis_i = marker_z_axis(obj.body_i, obj.marker_i_quat)
    axis_j = marker_z_axis(obj.body_j, obj.marker_j_quat)
    cx, cy, cz = cross3(axis_i, axis_j)

    # Pair every cross-product component with its value at setup; max()
    # keeps the first entry on ties, so x wins over y, and y over z.
    candidates = [(expr, expr.eval(env)) for expr in (cx, cy, cz)]
    dominant_expr, dominant_val = max(candidates, key=lambda pair: abs(pair[1]))

    if abs(dominant_val) < 1e-14:
        return None  # no usable sign at setup

    return HalfSpace(
        constraint_index=constraint_idx,
        reference_sign=math.copysign(1.0, dominant_val),
        indicator_fn=dominant_expr.eval,
    )
|