Compare commits
19 Commits
feat/addon
...
f85dc047e8
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f85dc047e8 | ||
| 6c2ddb6494 | |||
| 5802d45a7f | |||
| 9d86bb203e | |||
|
|
c2ebcc3169 | ||
| e7e4266f3d | |||
|
|
0825578778 | ||
|
|
8e521b4519 | ||
|
|
bfb787157c | ||
|
|
e0468cd3c1 | ||
|
|
64b1e24467 | ||
|
|
d20b38e760 | ||
| 318a1c17da | |||
|
|
adaa0f9a69 | ||
|
|
9dad25e947 | ||
|
|
b4b8724ff1 | ||
| 3f5f7905b5 | |||
|
|
92ae57751f | ||
|
|
533ca91774 |
29
Init.py
29
Init.py
@@ -1,11 +1,40 @@
|
||||
"""Register the Kindred solver with the KCSolve solver registry."""
|
||||
|
||||
import logging
|
||||
|
||||
import FreeCAD
|
||||
|
||||
|
||||
class _FreeCADLogHandler(logging.Handler):
|
||||
"""Route Python logging to FreeCAD's Console."""
|
||||
|
||||
def emit(self, record):
|
||||
msg = self.format(record) + "\n"
|
||||
if record.levelno >= logging.ERROR:
|
||||
FreeCAD.Console.PrintError(msg)
|
||||
elif record.levelno >= logging.WARNING:
|
||||
FreeCAD.Console.PrintWarning(msg)
|
||||
elif record.levelno >= logging.INFO:
|
||||
FreeCAD.Console.PrintLog(msg)
|
||||
else:
|
||||
FreeCAD.Console.PrintLog(msg)
|
||||
|
||||
|
||||
def _setup_logging():
|
||||
"""Attach FreeCAD log handler to the kindred_solver logger."""
|
||||
logger = logging.getLogger("kindred_solver")
|
||||
if not logger.handlers:
|
||||
handler = _FreeCADLogHandler()
|
||||
handler.setFormatter(logging.Formatter("%(name)s: %(message)s"))
|
||||
logger.addHandler(handler)
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
||||
|
||||
try:
|
||||
import kcsolve
|
||||
from kindred_solver import KindredSolver
|
||||
|
||||
_setup_logging()
|
||||
kcsolve.register_solver("kindred", KindredSolver)
|
||||
FreeCAD.Console.PrintLog("kindred-solver registered\n")
|
||||
except Exception as exc:
|
||||
|
||||
186
kindred_solver/bfgs.py
Normal file
186
kindred_solver/bfgs.py
Normal file
@@ -0,0 +1,186 @@
|
||||
"""L-BFGS-B fallback solver for when Newton-Raphson fails to converge.
|
||||
|
||||
Minimizes f(x) = 0.5 * sum(r_i(x)^2) using scipy's L-BFGS-B with
|
||||
analytic gradient from the Expr DAG's symbolic differentiation.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from typing import Callable, List
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .expr import Expr
|
||||
from .params import ParamTable
|
||||
|
||||
try:
|
||||
from scipy.optimize import minimize as _scipy_minimize
|
||||
|
||||
_HAS_SCIPY = True
|
||||
except ImportError:
|
||||
_HAS_SCIPY = False
|
||||
|
||||
|
||||
def bfgs_solve(
|
||||
residuals: List[Expr],
|
||||
params: ParamTable,
|
||||
quat_groups: List[tuple[str, str, str, str]] | None = None,
|
||||
max_iter: int = 200,
|
||||
tol: float = 1e-10,
|
||||
weight_vector: "np.ndarray | None" = None,
|
||||
jac_exprs: "List[List[Expr]] | None" = None,
|
||||
compiled_eval: "Callable | None" = None,
|
||||
) -> bool:
|
||||
"""Solve ``residuals == 0`` by minimizing sum of squared residuals.
|
||||
|
||||
Falls back gracefully to False if scipy is not available.
|
||||
|
||||
When *weight_vector* is provided, residuals are scaled by
|
||||
``sqrt(w)`` so that the objective becomes
|
||||
``0.5 * sum(w_i * r_i^2)`` — equivalent to weighted least-squares.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
jac_exprs:
|
||||
Pre-built symbolic Jacobian (list-of-lists of Expr).
|
||||
compiled_eval:
|
||||
Pre-compiled evaluation function from :mod:`codegen`.
|
||||
|
||||
Returns True if converged (||r|| < tol).
|
||||
"""
|
||||
if not _HAS_SCIPY:
|
||||
return False
|
||||
|
||||
free = params.free_names()
|
||||
n_free = len(free)
|
||||
n_res = len(residuals)
|
||||
|
||||
if n_free == 0 or n_res == 0:
|
||||
return True
|
||||
|
||||
# Build symbolic gradient expressions once: d(r_i)/d(x_j)
|
||||
if jac_exprs is None:
|
||||
jac_exprs = []
|
||||
for r in residuals:
|
||||
row = []
|
||||
for name in free:
|
||||
row.append(r.diff(name).simplify())
|
||||
jac_exprs.append(row)
|
||||
|
||||
# Try compilation if not provided
|
||||
if compiled_eval is None:
|
||||
from .codegen import try_compile_system
|
||||
|
||||
compiled_eval = try_compile_system(residuals, jac_exprs, n_res, n_free)
|
||||
|
||||
# Pre-compute scaling for weighted minimum-norm
|
||||
if weight_vector is not None:
|
||||
w_sqrt = np.sqrt(weight_vector)
|
||||
w_inv_sqrt = 1.0 / w_sqrt
|
||||
else:
|
||||
w_sqrt = None
|
||||
w_inv_sqrt = None
|
||||
|
||||
# Pre-allocate arrays reused across objective calls
|
||||
r_vals = np.empty(n_res)
|
||||
J = np.zeros((n_res, n_free))
|
||||
|
||||
def objective_and_grad(y_vec):
|
||||
# Transform back from scaled space if weighted
|
||||
if w_inv_sqrt is not None:
|
||||
x_vec = y_vec * w_inv_sqrt
|
||||
else:
|
||||
x_vec = y_vec
|
||||
|
||||
# Update params
|
||||
params.set_free_vector(x_vec)
|
||||
if quat_groups:
|
||||
_renormalize_quats(params, quat_groups)
|
||||
|
||||
if compiled_eval is not None:
|
||||
J[:] = 0.0
|
||||
compiled_eval(params.env_ref(), r_vals, J)
|
||||
else:
|
||||
env = params.get_env()
|
||||
for i, r in enumerate(residuals):
|
||||
r_vals[i] = r.eval(env)
|
||||
for i in range(n_res):
|
||||
for j in range(n_free):
|
||||
J[i, j] = jac_exprs[i][j].eval(env)
|
||||
|
||||
f = 0.5 * np.dot(r_vals, r_vals)
|
||||
|
||||
# Gradient of f w.r.t. x = J^T @ r
|
||||
grad_x = J.T @ r_vals
|
||||
|
||||
# Chain rule: df/dy = df/dx * dx/dy = grad_x * w_inv_sqrt
|
||||
if w_inv_sqrt is not None:
|
||||
grad = grad_x * w_inv_sqrt
|
||||
else:
|
||||
grad = grad_x
|
||||
|
||||
return f, grad
|
||||
|
||||
x0 = params.get_free_vector().copy()
|
||||
|
||||
# Transform initial guess to scaled space
|
||||
if w_sqrt is not None:
|
||||
y0 = x0 * w_sqrt
|
||||
else:
|
||||
y0 = x0
|
||||
|
||||
result = _scipy_minimize(
|
||||
objective_and_grad,
|
||||
y0,
|
||||
method="L-BFGS-B",
|
||||
jac=True,
|
||||
options={"maxiter": max_iter, "ftol": tol * tol, "gtol": tol},
|
||||
)
|
||||
|
||||
# Apply final result (transform back from scaled space)
|
||||
if w_inv_sqrt is not None:
|
||||
params.set_free_vector(result.x * w_inv_sqrt)
|
||||
else:
|
||||
params.set_free_vector(result.x)
|
||||
if quat_groups:
|
||||
_renormalize_quats(params, quat_groups)
|
||||
|
||||
# Check convergence on actual residual norm
|
||||
if compiled_eval is not None:
|
||||
compiled_eval(params.env_ref(), r_vals, J)
|
||||
else:
|
||||
env = params.get_env()
|
||||
for i, r in enumerate(residuals):
|
||||
r_vals[i] = r.eval(env)
|
||||
return bool(np.linalg.norm(r_vals) < tol)
|
||||
|
||||
|
||||
def _renormalize_quats(
|
||||
params: ParamTable,
|
||||
groups: List[tuple[str, str, str, str]],
|
||||
):
|
||||
"""Project quaternion params back onto the unit sphere."""
|
||||
for qw_name, qx_name, qy_name, qz_name in groups:
|
||||
if (
|
||||
params.is_fixed(qw_name)
|
||||
and params.is_fixed(qx_name)
|
||||
and params.is_fixed(qy_name)
|
||||
and params.is_fixed(qz_name)
|
||||
):
|
||||
continue
|
||||
w = params.get_value(qw_name)
|
||||
x = params.get_value(qx_name)
|
||||
y = params.get_value(qy_name)
|
||||
z = params.get_value(qz_name)
|
||||
norm = math.sqrt(w * w + x * x + y * y + z * z)
|
||||
if norm < 1e-15:
|
||||
params.set_value(qw_name, 1.0)
|
||||
params.set_value(qx_name, 0.0)
|
||||
params.set_value(qy_name, 0.0)
|
||||
params.set_value(qz_name, 0.0)
|
||||
else:
|
||||
params.set_value(qw_name, w / norm)
|
||||
params.set_value(qx_name, x / norm)
|
||||
params.set_value(qy_name, y / norm)
|
||||
params.set_value(qz_name, z / norm)
|
||||
308
kindred_solver/codegen.py
Normal file
308
kindred_solver/codegen.py
Normal file
@@ -0,0 +1,308 @@
|
||||
"""Compile Expr DAGs into flat Python functions for fast evaluation.
|
||||
|
||||
The compilation pipeline:
|
||||
1. Collect all Expr nodes to be evaluated (residuals + Jacobian entries).
|
||||
2. Identify common subexpressions (CSE) by ``id()`` — the Expr DAG
|
||||
already shares node objects via ParamTable's Var instances.
|
||||
3. Generate a single Python function body that computes CSE temps,
|
||||
then fills ``r_vec`` and ``J`` arrays in-place.
|
||||
4. Compile with ``compile()`` + ``exec()`` and return the callable.
|
||||
|
||||
The generated function signature is::
|
||||
|
||||
fn(env: dict[str, float], r_vec: ndarray, J: ndarray) -> None
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import math
|
||||
from collections import Counter
|
||||
from typing import Callable, List
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .expr import Const, Expr, Var
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
# Namespace injected into compiled functions.
|
||||
_CODEGEN_NS = {
|
||||
"_sin": math.sin,
|
||||
"_cos": math.cos,
|
||||
"_sqrt": math.sqrt,
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CSE (Common Subexpression Elimination)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _collect_nodes(expr: Expr, counts: Counter, visited: set[int]) -> None:
|
||||
"""Walk *expr* and count how many times each node ``id()`` appears."""
|
||||
eid = id(expr)
|
||||
counts[eid] += 1
|
||||
if eid in visited:
|
||||
return
|
||||
visited.add(eid)
|
||||
|
||||
# Recurse into children
|
||||
if isinstance(expr, (Const, Var)):
|
||||
return
|
||||
if hasattr(expr, "child"):
|
||||
_collect_nodes(expr.child, counts, visited)
|
||||
elif hasattr(expr, "a"):
|
||||
_collect_nodes(expr.a, counts, visited)
|
||||
_collect_nodes(expr.b, counts, visited)
|
||||
elif hasattr(expr, "base"):
|
||||
_collect_nodes(expr.base, counts, visited)
|
||||
_collect_nodes(expr.exp, counts, visited)
|
||||
|
||||
|
||||
def _build_cse(
|
||||
exprs: list[Expr],
|
||||
) -> tuple[dict[int, str], list[tuple[str, Expr]]]:
|
||||
"""Identify shared sub-trees and assign temporary variable names.
|
||||
|
||||
Returns:
|
||||
id_to_temp: mapping from ``id(node)`` to temp variable name
|
||||
temps_ordered: ``(temp_name, expr)`` pairs in dependency order
|
||||
"""
|
||||
counts: Counter = Counter()
|
||||
visited: set[int] = set()
|
||||
id_to_expr: dict[int, Expr] = {}
|
||||
|
||||
for expr in exprs:
|
||||
_collect_nodes(expr, counts, visited)
|
||||
|
||||
# Map id -> Expr for nodes we visited
|
||||
for expr in exprs:
|
||||
_map_ids(expr, id_to_expr)
|
||||
|
||||
# Nodes referenced more than once and not trivial (Const/Var) become temps
|
||||
shared_ids = set()
|
||||
for eid, cnt in counts.items():
|
||||
if cnt > 1:
|
||||
node = id_to_expr.get(eid)
|
||||
if node is not None and not isinstance(node, (Const, Var)):
|
||||
shared_ids.add(eid)
|
||||
|
||||
if not shared_ids:
|
||||
return {}, []
|
||||
|
||||
# Topological order: a temp must be computed before any temp that uses it.
|
||||
# Walk each shared node, collect in post-order.
|
||||
ordered_ids: list[int] = []
|
||||
order_visited: set[int] = set()
|
||||
|
||||
def _topo(expr: Expr) -> None:
|
||||
eid = id(expr)
|
||||
if eid in order_visited:
|
||||
return
|
||||
order_visited.add(eid)
|
||||
if isinstance(expr, (Const, Var)):
|
||||
return
|
||||
if hasattr(expr, "child"):
|
||||
_topo(expr.child)
|
||||
elif hasattr(expr, "a"):
|
||||
_topo(expr.a)
|
||||
_topo(expr.b)
|
||||
elif hasattr(expr, "base"):
|
||||
_topo(expr.base)
|
||||
_topo(expr.exp)
|
||||
if eid in shared_ids:
|
||||
ordered_ids.append(eid)
|
||||
|
||||
for expr in exprs:
|
||||
_topo(expr)
|
||||
|
||||
id_to_temp: dict[int, str] = {}
|
||||
temps_ordered: list[tuple[str, Expr]] = []
|
||||
for i, eid in enumerate(ordered_ids):
|
||||
name = f"_c{i}"
|
||||
id_to_temp[eid] = name
|
||||
temps_ordered.append((name, id_to_expr[eid]))
|
||||
|
||||
return id_to_temp, temps_ordered
|
||||
|
||||
|
||||
def _map_ids(expr: Expr, mapping: dict[int, Expr]) -> None:
|
||||
"""Populate id -> Expr mapping for all nodes in *expr*."""
|
||||
eid = id(expr)
|
||||
if eid in mapping:
|
||||
return
|
||||
mapping[eid] = expr
|
||||
if isinstance(expr, (Const, Var)):
|
||||
return
|
||||
if hasattr(expr, "child"):
|
||||
_map_ids(expr.child, mapping)
|
||||
elif hasattr(expr, "a"):
|
||||
_map_ids(expr.a, mapping)
|
||||
_map_ids(expr.b, mapping)
|
||||
elif hasattr(expr, "base"):
|
||||
_map_ids(expr.base, mapping)
|
||||
_map_ids(expr.exp, mapping)
|
||||
|
||||
|
||||
def _expr_to_code(expr: Expr, id_to_temp: dict[int, str]) -> str:
|
||||
"""Emit code for *expr*, substituting temp names for shared nodes."""
|
||||
eid = id(expr)
|
||||
temp = id_to_temp.get(eid)
|
||||
if temp is not None:
|
||||
return temp
|
||||
return expr.to_code()
|
||||
|
||||
|
||||
def _expr_to_code_recursive(expr: Expr, id_to_temp: dict[int, str]) -> str:
|
||||
"""Emit code for *expr*, recursing into children but respecting temps."""
|
||||
eid = id(expr)
|
||||
temp = id_to_temp.get(eid)
|
||||
if temp is not None:
|
||||
return temp
|
||||
|
||||
# For leaf nodes, just use to_code() directly
|
||||
if isinstance(expr, (Const, Var)):
|
||||
return expr.to_code()
|
||||
|
||||
# For non-leaf nodes, recurse into children with temp substitution
|
||||
from .expr import Add, Cos, Div, Mul, Neg, Pow, Sin, Sqrt, Sub
|
||||
|
||||
if isinstance(expr, Neg):
|
||||
return f"(-{_expr_to_code_recursive(expr.child, id_to_temp)})"
|
||||
if isinstance(expr, Sin):
|
||||
return f"_sin({_expr_to_code_recursive(expr.child, id_to_temp)})"
|
||||
if isinstance(expr, Cos):
|
||||
return f"_cos({_expr_to_code_recursive(expr.child, id_to_temp)})"
|
||||
if isinstance(expr, Sqrt):
|
||||
return f"_sqrt({_expr_to_code_recursive(expr.child, id_to_temp)})"
|
||||
if isinstance(expr, Add):
|
||||
a = _expr_to_code_recursive(expr.a, id_to_temp)
|
||||
b = _expr_to_code_recursive(expr.b, id_to_temp)
|
||||
return f"({a} + {b})"
|
||||
if isinstance(expr, Sub):
|
||||
a = _expr_to_code_recursive(expr.a, id_to_temp)
|
||||
b = _expr_to_code_recursive(expr.b, id_to_temp)
|
||||
return f"({a} - {b})"
|
||||
if isinstance(expr, Mul):
|
||||
a = _expr_to_code_recursive(expr.a, id_to_temp)
|
||||
b = _expr_to_code_recursive(expr.b, id_to_temp)
|
||||
return f"({a} * {b})"
|
||||
if isinstance(expr, Div):
|
||||
a = _expr_to_code_recursive(expr.a, id_to_temp)
|
||||
b = _expr_to_code_recursive(expr.b, id_to_temp)
|
||||
return f"({a} / {b})"
|
||||
if isinstance(expr, Pow):
|
||||
base = _expr_to_code_recursive(expr.base, id_to_temp)
|
||||
exp = _expr_to_code_recursive(expr.exp, id_to_temp)
|
||||
return f"({base} ** {exp})"
|
||||
|
||||
# Fallback — should not happen for known node types
|
||||
return expr.to_code()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Sparsity detection
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _find_nonzero_entries(
|
||||
jac_exprs: list[list[Expr]],
|
||||
) -> list[tuple[int, int]]:
|
||||
"""Return ``(row, col)`` pairs for non-zero Jacobian entries."""
|
||||
nz = []
|
||||
for i, row in enumerate(jac_exprs):
|
||||
for j, expr in enumerate(row):
|
||||
if isinstance(expr, Const) and expr.value == 0.0:
|
||||
continue
|
||||
nz.append((i, j))
|
||||
return nz
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Code generation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def compile_system(
|
||||
residuals: list[Expr],
|
||||
jac_exprs: list[list[Expr]],
|
||||
n_res: int,
|
||||
n_free: int,
|
||||
) -> Callable[[dict, np.ndarray, np.ndarray], None]:
|
||||
"""Compile residuals + Jacobian into a single evaluation function.
|
||||
|
||||
Returns a callable ``fn(env, r_vec, J)`` that fills *r_vec* and *J*
|
||||
in-place. *J* must be pre-zeroed by the caller (only non-zero
|
||||
entries are written).
|
||||
"""
|
||||
# Detect non-zero Jacobian entries
|
||||
nz_entries = _find_nonzero_entries(jac_exprs)
|
||||
|
||||
# Collect all expressions for CSE analysis
|
||||
all_exprs: list[Expr] = list(residuals)
|
||||
nz_jac_exprs: list[Expr] = [jac_exprs[i][j] for i, j in nz_entries]
|
||||
all_exprs.extend(nz_jac_exprs)
|
||||
|
||||
# CSE
|
||||
id_to_temp, temps_ordered = _build_cse(all_exprs)
|
||||
|
||||
# Generate function body
|
||||
lines: list[str] = ["def _eval(env, r_vec, J):"]
|
||||
|
||||
# Temporaries — temporarily remove each temp's own id so its RHS
|
||||
# is expanded rather than self-referencing.
|
||||
for temp_name, temp_expr in temps_ordered:
|
||||
eid = id(temp_expr)
|
||||
saved = id_to_temp.pop(eid)
|
||||
code = _expr_to_code_recursive(temp_expr, id_to_temp)
|
||||
id_to_temp[eid] = saved
|
||||
lines.append(f" {temp_name} = {code}")
|
||||
|
||||
# Residuals
|
||||
for i, r in enumerate(residuals):
|
||||
code = _expr_to_code_recursive(r, id_to_temp)
|
||||
lines.append(f" r_vec[{i}] = {code}")
|
||||
|
||||
# Jacobian (sparse)
|
||||
for idx, (i, j) in enumerate(nz_entries):
|
||||
code = _expr_to_code_recursive(nz_jac_exprs[idx], id_to_temp)
|
||||
lines.append(f" J[{i}, {j}] = {code}")
|
||||
|
||||
source = "\n".join(lines)
|
||||
|
||||
# Compile
|
||||
code_obj = compile(source, "<kindred_codegen>", "exec")
|
||||
ns = dict(_CODEGEN_NS)
|
||||
exec(code_obj, ns)
|
||||
|
||||
fn = ns["_eval"]
|
||||
|
||||
n_temps = len(temps_ordered)
|
||||
n_nz = len(nz_entries)
|
||||
n_total = n_res * n_free
|
||||
log.debug(
|
||||
"codegen: compiled %d residuals + %d/%d Jacobian entries, %d CSE temps",
|
||||
n_res,
|
||||
n_nz,
|
||||
n_total,
|
||||
n_temps,
|
||||
)
|
||||
|
||||
return fn
|
||||
|
||||
|
||||
def try_compile_system(
|
||||
residuals: list[Expr],
|
||||
jac_exprs: list[list[Expr]],
|
||||
n_res: int,
|
||||
n_free: int,
|
||||
) -> Callable[[dict, np.ndarray, np.ndarray], None] | None:
|
||||
"""Compile with automatic fallback. Returns ``None`` on failure."""
|
||||
try:
|
||||
return compile_system(residuals, jac_exprs, n_res, n_free)
|
||||
except Exception:
|
||||
log.debug(
|
||||
"codegen: compilation failed, falling back to tree-walk eval", exc_info=True
|
||||
)
|
||||
return None
|
||||
@@ -2,14 +2,28 @@
|
||||
|
||||
Each constraint takes two RigidBody entities and marker transforms,
|
||||
then generates residual expressions that equal zero when satisfied.
|
||||
|
||||
Phase 1 constraints: Coincident, DistancePointPoint, Fixed
|
||||
Phase 2 constraints: all remaining BaseJointKind types from Types.h
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from typing import List
|
||||
|
||||
from .entities import RigidBody
|
||||
from .expr import Const, Expr
|
||||
from .geometry import (
|
||||
cross3,
|
||||
dot3,
|
||||
marker_x_axis,
|
||||
marker_y_axis,
|
||||
marker_z_axis,
|
||||
point_line_perp_components,
|
||||
point_plane_distance,
|
||||
sub3,
|
||||
)
|
||||
|
||||
|
||||
class ConstraintBase:
|
||||
@@ -63,9 +77,15 @@ class DistancePointPointConstraint(ConstraintBase):
|
||||
self.marker_j_pos = marker_j_pos
|
||||
self.distance = distance
|
||||
|
||||
def world_points(self) -> tuple[tuple[Expr, Expr, Expr], tuple[Expr, Expr, Expr]]:
|
||||
"""Return (world_point_i, world_point_j) expression tuples."""
|
||||
return (
|
||||
self.body_i.world_point(*self.marker_i_pos),
|
||||
self.body_j.world_point(*self.marker_j_pos),
|
||||
)
|
||||
|
||||
def residuals(self) -> List[Expr]:
|
||||
wx_i, wy_i, wz_i = self.body_i.world_point(*self.marker_i_pos)
|
||||
wx_j, wy_j, wz_j = self.body_j.world_point(*self.marker_j_pos)
|
||||
(wx_i, wy_i, wz_i), (wx_j, wy_j, wz_j) = self.world_points()
|
||||
dx = wx_i - wx_j
|
||||
dy = wy_i - wy_j
|
||||
dz = wz_i - wz_j
|
||||
@@ -145,6 +165,704 @@ class FixedConstraint(ConstraintBase):
|
||||
return pos_res + ori_res
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Phase 2: Point constraints
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class PointOnLineConstraint(ConstraintBase):
|
||||
"""Point constrained to a line — 2 DOF removed.
|
||||
|
||||
marker_i origin lies on the line through marker_j origin along
|
||||
marker_j Z-axis. 2 residuals: perpendicular distance components.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
body_i: RigidBody,
|
||||
marker_i_pos: tuple[float, float, float],
|
||||
marker_i_quat: tuple[float, float, float, float],
|
||||
body_j: RigidBody,
|
||||
marker_j_pos: tuple[float, float, float],
|
||||
marker_j_quat: tuple[float, float, float, float],
|
||||
):
|
||||
self.body_i = body_i
|
||||
self.body_j = body_j
|
||||
self.marker_i_pos = marker_i_pos
|
||||
self.marker_j_pos = marker_j_pos
|
||||
self.marker_j_quat = marker_j_quat
|
||||
|
||||
def residuals(self) -> List[Expr]:
|
||||
p_i = self.body_i.world_point(*self.marker_i_pos)
|
||||
p_j = self.body_j.world_point(*self.marker_j_pos)
|
||||
z_j = marker_z_axis(self.body_j, self.marker_j_quat)
|
||||
cx, cy, cz = point_line_perp_components(p_i, p_j, z_j)
|
||||
return [cx, cy, cz]
|
||||
|
||||
|
||||
class PointInPlaneConstraint(ConstraintBase):
|
||||
"""Point constrained to a plane — 1 DOF removed.
|
||||
|
||||
marker_i origin lies in the plane through marker_j origin with
|
||||
normal = marker_j Z-axis. Optional offset via params[0].
|
||||
1 residual: signed distance to plane.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
body_i: RigidBody,
|
||||
marker_i_pos: tuple[float, float, float],
|
||||
marker_i_quat: tuple[float, float, float, float],
|
||||
body_j: RigidBody,
|
||||
marker_j_pos: tuple[float, float, float],
|
||||
marker_j_quat: tuple[float, float, float, float],
|
||||
offset: float = 0.0,
|
||||
):
|
||||
self.body_i = body_i
|
||||
self.body_j = body_j
|
||||
self.marker_i_pos = marker_i_pos
|
||||
self.marker_j_pos = marker_j_pos
|
||||
self.marker_j_quat = marker_j_quat
|
||||
self.offset = offset
|
||||
|
||||
def residuals(self) -> List[Expr]:
|
||||
p_i = self.body_i.world_point(*self.marker_i_pos)
|
||||
p_j = self.body_j.world_point(*self.marker_j_pos)
|
||||
n_j = marker_z_axis(self.body_j, self.marker_j_quat)
|
||||
d = point_plane_distance(p_i, p_j, n_j)
|
||||
if self.offset != 0.0:
|
||||
d = d - Const(self.offset)
|
||||
return [d]
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Phase 2: Axis orientation constraints
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class ParallelConstraint(ConstraintBase):
|
||||
"""Parallel axes — 2 DOF removed.
|
||||
|
||||
marker Z-axes are parallel: z_i x z_j = 0.
|
||||
3 cross-product residuals (rank 2 at the solution). Using all 3
|
||||
avoids a singularity when the axes lie in the XY plane, where
|
||||
dropping cz would leave the constraint blind to yaw rotations.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
body_i: RigidBody,
|
||||
marker_i_quat: tuple[float, float, float, float],
|
||||
body_j: RigidBody,
|
||||
marker_j_quat: tuple[float, float, float, float],
|
||||
):
|
||||
self.body_i = body_i
|
||||
self.body_j = body_j
|
||||
self.marker_i_quat = marker_i_quat
|
||||
self.marker_j_quat = marker_j_quat
|
||||
|
||||
def residuals(self) -> List[Expr]:
|
||||
z_i = marker_z_axis(self.body_i, self.marker_i_quat)
|
||||
z_j = marker_z_axis(self.body_j, self.marker_j_quat)
|
||||
cx, cy, cz = cross3(z_i, z_j)
|
||||
return [cx, cy, cz]
|
||||
|
||||
|
||||
class PerpendicularConstraint(ConstraintBase):
|
||||
"""Perpendicular axes — 1 DOF removed.
|
||||
|
||||
marker Z-axes are perpendicular: z_i . z_j = 0.
|
||||
1 residual.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
body_i: RigidBody,
|
||||
marker_i_quat: tuple[float, float, float, float],
|
||||
body_j: RigidBody,
|
||||
marker_j_quat: tuple[float, float, float, float],
|
||||
):
|
||||
self.body_i = body_i
|
||||
self.body_j = body_j
|
||||
self.marker_i_quat = marker_i_quat
|
||||
self.marker_j_quat = marker_j_quat
|
||||
|
||||
def residuals(self) -> List[Expr]:
|
||||
z_i = marker_z_axis(self.body_i, self.marker_i_quat)
|
||||
z_j = marker_z_axis(self.body_j, self.marker_j_quat)
|
||||
return [dot3(z_i, z_j)]
|
||||
|
||||
|
||||
class AngleConstraint(ConstraintBase):
|
||||
"""Angle between axes — 1 DOF removed.
|
||||
|
||||
z_i . z_j = cos(angle).
|
||||
1 residual.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
body_i: RigidBody,
|
||||
marker_i_quat: tuple[float, float, float, float],
|
||||
body_j: RigidBody,
|
||||
marker_j_quat: tuple[float, float, float, float],
|
||||
angle: float,
|
||||
):
|
||||
self.body_i = body_i
|
||||
self.body_j = body_j
|
||||
self.marker_i_quat = marker_i_quat
|
||||
self.marker_j_quat = marker_j_quat
|
||||
self.angle = angle
|
||||
|
||||
def residuals(self) -> List[Expr]:
|
||||
z_i = marker_z_axis(self.body_i, self.marker_i_quat)
|
||||
z_j = marker_z_axis(self.body_j, self.marker_j_quat)
|
||||
return [dot3(z_i, z_j) - Const(math.cos(self.angle))]
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Phase 2: Axis/surface constraints
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class ConcentricConstraint(ConstraintBase):
|
||||
"""Coaxial / concentric — 4 DOF removed.
|
||||
|
||||
Axes are collinear: parallel Z-axes (2) + point-on-line (2).
|
||||
Optional distance offset along axis via params.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
body_i: RigidBody,
|
||||
marker_i_pos: tuple[float, float, float],
|
||||
marker_i_quat: tuple[float, float, float, float],
|
||||
body_j: RigidBody,
|
||||
marker_j_pos: tuple[float, float, float],
|
||||
marker_j_quat: tuple[float, float, float, float],
|
||||
distance: float = 0.0,
|
||||
):
|
||||
self.body_i = body_i
|
||||
self.body_j = body_j
|
||||
self.marker_i_pos = marker_i_pos
|
||||
self.marker_i_quat = marker_i_quat
|
||||
self.marker_j_pos = marker_j_pos
|
||||
self.marker_j_quat = marker_j_quat
|
||||
self.distance = distance
|
||||
|
||||
def residuals(self) -> List[Expr]:
|
||||
# Parallel axes (3 cross-product residuals, rank 2 at solution)
|
||||
z_i = marker_z_axis(self.body_i, self.marker_i_quat)
|
||||
z_j = marker_z_axis(self.body_j, self.marker_j_quat)
|
||||
cx, cy, cz = cross3(z_i, z_j)
|
||||
|
||||
# Point-on-line: marker_i origin on line through marker_j along z_j
|
||||
p_i = self.body_i.world_point(*self.marker_i_pos)
|
||||
p_j = self.body_j.world_point(*self.marker_j_pos)
|
||||
lx, ly, lz = point_line_perp_components(p_i, p_j, z_j)
|
||||
|
||||
return [cx, cy, cz, lx, ly, lz]
|
||||
|
||||
|
||||
class TangentConstraint(ConstraintBase):
|
||||
"""Face-on-face tangency — 1 DOF removed.
|
||||
|
||||
Signed distance between marker origins along marker_j normal = 0.
|
||||
1 residual: (p_i - p_j) . z_j
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
body_i: RigidBody,
|
||||
marker_i_pos: tuple[float, float, float],
|
||||
marker_i_quat: tuple[float, float, float, float],
|
||||
body_j: RigidBody,
|
||||
marker_j_pos: tuple[float, float, float],
|
||||
marker_j_quat: tuple[float, float, float, float],
|
||||
):
|
||||
self.body_i = body_i
|
||||
self.body_j = body_j
|
||||
self.marker_i_pos = marker_i_pos
|
||||
self.marker_j_pos = marker_j_pos
|
||||
self.marker_j_quat = marker_j_quat
|
||||
|
||||
def residuals(self) -> List[Expr]:
|
||||
p_i = self.body_i.world_point(*self.marker_i_pos)
|
||||
p_j = self.body_j.world_point(*self.marker_j_pos)
|
||||
n_j = marker_z_axis(self.body_j, self.marker_j_quat)
|
||||
return [point_plane_distance(p_i, p_j, n_j)]
|
||||
|
||||
|
||||
class PlanarConstraint(ConstraintBase):
|
||||
"""Coplanar faces — 3 DOF removed.
|
||||
|
||||
Parallel normals (2) + point-in-plane (1). Optional offset.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
body_i: RigidBody,
|
||||
marker_i_pos: tuple[float, float, float],
|
||||
marker_i_quat: tuple[float, float, float, float],
|
||||
body_j: RigidBody,
|
||||
marker_j_pos: tuple[float, float, float],
|
||||
marker_j_quat: tuple[float, float, float, float],
|
||||
offset: float = 0.0,
|
||||
):
|
||||
self.body_i = body_i
|
||||
self.body_j = body_j
|
||||
self.marker_i_pos = marker_i_pos
|
||||
self.marker_i_quat = marker_i_quat
|
||||
self.marker_j_pos = marker_j_pos
|
||||
self.marker_j_quat = marker_j_quat
|
||||
self.offset = offset
|
||||
|
||||
def residuals(self) -> List[Expr]:
|
||||
# Parallel normals (3 cross-product residuals, rank 2 at solution)
|
||||
z_i = marker_z_axis(self.body_i, self.marker_i_quat)
|
||||
z_j = marker_z_axis(self.body_j, self.marker_j_quat)
|
||||
cx, cy, cz = cross3(z_i, z_j)
|
||||
|
||||
# Point-in-plane
|
||||
p_i = self.body_i.world_point(*self.marker_i_pos)
|
||||
p_j = self.body_j.world_point(*self.marker_j_pos)
|
||||
d = point_plane_distance(p_i, p_j, z_j)
|
||||
if self.offset != 0.0:
|
||||
d = d - Const(self.offset)
|
||||
|
||||
return [cx, cy, cz, d]
|
||||
|
||||
|
||||
class LineInPlaneConstraint(ConstraintBase):
|
||||
"""Line constrained to a plane — 2 DOF removed.
|
||||
|
||||
Line defined by marker_i Z-axis lies in plane defined by marker_j normal.
|
||||
2 residuals: point-in-plane (1) + line direction perpendicular to normal (1).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
body_i: RigidBody,
|
||||
marker_i_pos: tuple[float, float, float],
|
||||
marker_i_quat: tuple[float, float, float, float],
|
||||
body_j: RigidBody,
|
||||
marker_j_pos: tuple[float, float, float],
|
||||
marker_j_quat: tuple[float, float, float, float],
|
||||
offset: float = 0.0,
|
||||
):
|
||||
self.body_i = body_i
|
||||
self.body_j = body_j
|
||||
self.marker_i_pos = marker_i_pos
|
||||
self.marker_i_quat = marker_i_quat
|
||||
self.marker_j_pos = marker_j_pos
|
||||
self.marker_j_quat = marker_j_quat
|
||||
self.offset = offset
|
||||
|
||||
def residuals(self) -> List[Expr]:
|
||||
p_i = self.body_i.world_point(*self.marker_i_pos)
|
||||
p_j = self.body_j.world_point(*self.marker_j_pos)
|
||||
n_j = marker_z_axis(self.body_j, self.marker_j_quat)
|
||||
z_i = marker_z_axis(self.body_i, self.marker_i_quat)
|
||||
|
||||
# Point in plane
|
||||
d = point_plane_distance(p_i, p_j, n_j)
|
||||
if self.offset != 0.0:
|
||||
d = d - Const(self.offset)
|
||||
|
||||
# Line direction perpendicular to plane normal
|
||||
dir_dot = dot3(z_i, n_j)
|
||||
|
||||
return [d, dir_dot]
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Phase 2: Kinematic joints
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class BallConstraint(ConstraintBase):
|
||||
"""Spherical joint — 3 DOF removed.
|
||||
|
||||
Coincident marker origins. Same as CoincidentConstraint.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
body_i: RigidBody,
|
||||
marker_i_pos: tuple[float, float, float],
|
||||
body_j: RigidBody,
|
||||
marker_j_pos: tuple[float, float, float],
|
||||
):
|
||||
self._inner = CoincidentConstraint(body_i, marker_i_pos, body_j, marker_j_pos)
|
||||
|
||||
def residuals(self) -> List[Expr]:
|
||||
return self._inner.residuals()
|
||||
|
||||
|
||||
class RevoluteConstraint(ConstraintBase):
|
||||
"""Hinge joint — 5 DOF removed.
|
||||
|
||||
Coincident origins (3) + parallel Z-axes (2).
|
||||
1 rotational DOF remains (about the common Z-axis).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
body_i: RigidBody,
|
||||
marker_i_pos: tuple[float, float, float],
|
||||
marker_i_quat: tuple[float, float, float, float],
|
||||
body_j: RigidBody,
|
||||
marker_j_pos: tuple[float, float, float],
|
||||
marker_j_quat: tuple[float, float, float, float],
|
||||
):
|
||||
self.body_i = body_i
|
||||
self.body_j = body_j
|
||||
self.marker_i_pos = marker_i_pos
|
||||
self.marker_i_quat = marker_i_quat
|
||||
self.marker_j_pos = marker_j_pos
|
||||
self.marker_j_quat = marker_j_quat
|
||||
|
||||
def residuals(self) -> List[Expr]:
|
||||
# Coincident origins
|
||||
p_i = self.body_i.world_point(*self.marker_i_pos)
|
||||
p_j = self.body_j.world_point(*self.marker_j_pos)
|
||||
pos = [p_i[0] - p_j[0], p_i[1] - p_j[1], p_i[2] - p_j[2]]
|
||||
|
||||
# Parallel Z-axes (3 cross-product residuals, rank 2 at solution)
|
||||
z_i = marker_z_axis(self.body_i, self.marker_i_quat)
|
||||
z_j = marker_z_axis(self.body_j, self.marker_j_quat)
|
||||
cx, cy, cz = cross3(z_i, z_j)
|
||||
|
||||
return pos + [cx, cy, cz]
|
||||
|
||||
|
||||
class CylindricalConstraint(ConstraintBase):
|
||||
"""Cylindrical joint — 4 DOF removed.
|
||||
|
||||
Parallel Z-axes (2) + point-on-line (2).
|
||||
2 DOF remain: rotation about and translation along the common axis.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
body_i: RigidBody,
|
||||
marker_i_pos: tuple[float, float, float],
|
||||
marker_i_quat: tuple[float, float, float, float],
|
||||
body_j: RigidBody,
|
||||
marker_j_pos: tuple[float, float, float],
|
||||
marker_j_quat: tuple[float, float, float, float],
|
||||
):
|
||||
self.body_i = body_i
|
||||
self.body_j = body_j
|
||||
self.marker_i_pos = marker_i_pos
|
||||
self.marker_i_quat = marker_i_quat
|
||||
self.marker_j_pos = marker_j_pos
|
||||
self.marker_j_quat = marker_j_quat
|
||||
|
||||
def residuals(self) -> List[Expr]:
|
||||
# Parallel Z-axes (3 cross-product residuals, rank 2 at solution)
|
||||
z_i = marker_z_axis(self.body_i, self.marker_i_quat)
|
||||
z_j = marker_z_axis(self.body_j, self.marker_j_quat)
|
||||
cx, cy, cz = cross3(z_i, z_j)
|
||||
|
||||
# Point-on-line
|
||||
p_i = self.body_i.world_point(*self.marker_i_pos)
|
||||
p_j = self.body_j.world_point(*self.marker_j_pos)
|
||||
lx, ly, lz = point_line_perp_components(p_i, p_j, z_j)
|
||||
|
||||
return [cx, cy, cz, lx, ly, lz]
|
||||
|
||||
|
||||
class SliderConstraint(ConstraintBase):
|
||||
"""Prismatic / slider joint — 5 DOF removed.
|
||||
|
||||
Parallel Z-axes (2) + point-on-line (2) + rotation lock (1).
|
||||
1 DOF remains: translation along the common Z-axis.
|
||||
|
||||
Rotation lock: x_i . y_j = 0 (prevents twist about Z).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
body_i: RigidBody,
|
||||
marker_i_pos: tuple[float, float, float],
|
||||
marker_i_quat: tuple[float, float, float, float],
|
||||
body_j: RigidBody,
|
||||
marker_j_pos: tuple[float, float, float],
|
||||
marker_j_quat: tuple[float, float, float, float],
|
||||
):
|
||||
self.body_i = body_i
|
||||
self.body_j = body_j
|
||||
self.marker_i_pos = marker_i_pos
|
||||
self.marker_i_quat = marker_i_quat
|
||||
self.marker_j_pos = marker_j_pos
|
||||
self.marker_j_quat = marker_j_quat
|
||||
|
||||
def residuals(self) -> List[Expr]:
|
||||
# Parallel Z-axes (3 cross-product residuals, rank 2 at solution)
|
||||
z_i = marker_z_axis(self.body_i, self.marker_i_quat)
|
||||
z_j = marker_z_axis(self.body_j, self.marker_j_quat)
|
||||
cx, cy, cz = cross3(z_i, z_j)
|
||||
|
||||
# Point-on-line
|
||||
p_i = self.body_i.world_point(*self.marker_i_pos)
|
||||
p_j = self.body_j.world_point(*self.marker_j_pos)
|
||||
lx, ly, lz = point_line_perp_components(p_i, p_j, z_j)
|
||||
|
||||
# Rotation lock: x_i . y_j = 0
|
||||
x_i = marker_x_axis(self.body_i, self.marker_i_quat)
|
||||
y_j = marker_y_axis(self.body_j, self.marker_j_quat)
|
||||
twist = dot3(x_i, y_j)
|
||||
|
||||
return [cx, cy, cz, lx, ly, lz, twist]
|
||||
|
||||
|
||||
class ScrewConstraint(ConstraintBase):
|
||||
"""Helical / screw joint — 5 DOF removed.
|
||||
|
||||
Cylindrical (4) + coupled rotation-translation via pitch (1).
|
||||
1 DOF remains: screw motion (rotation + proportional translation).
|
||||
|
||||
The coupling residual uses the relative quaternion's Z-component
|
||||
(proportional to the rotation angle for small angles) and the axial
|
||||
displacement: axial_disp - pitch * (2 * qz_rel / qw_rel) / (2*pi) = 0.
|
||||
For the Newton solver operating near the solution, the linear
|
||||
approximation angle ≈ 2 * qz_rel is adequate.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
body_i: RigidBody,
|
||||
marker_i_pos: tuple[float, float, float],
|
||||
marker_i_quat: tuple[float, float, float, float],
|
||||
body_j: RigidBody,
|
||||
marker_j_pos: tuple[float, float, float],
|
||||
marker_j_quat: tuple[float, float, float, float],
|
||||
pitch: float = 1.0,
|
||||
):
|
||||
self.body_i = body_i
|
||||
self.body_j = body_j
|
||||
self.marker_i_pos = marker_i_pos
|
||||
self.marker_i_quat = marker_i_quat
|
||||
self.marker_j_pos = marker_j_pos
|
||||
self.marker_j_quat = marker_j_quat
|
||||
self.pitch = pitch
|
||||
|
||||
def residuals(self) -> List[Expr]:
|
||||
# Cylindrical residuals (5: 3 parallel + 2 point-on-line)
|
||||
z_i = marker_z_axis(self.body_i, self.marker_i_quat)
|
||||
z_j = marker_z_axis(self.body_j, self.marker_j_quat)
|
||||
cx, cy, cz = cross3(z_i, z_j)
|
||||
|
||||
p_i = self.body_i.world_point(*self.marker_i_pos)
|
||||
p_j = self.body_j.world_point(*self.marker_j_pos)
|
||||
lx, ly, lz = point_line_perp_components(p_i, p_j, z_j)
|
||||
|
||||
# Pitch coupling: axial_disp = pitch * angle / (2*pi)
|
||||
# Axial displacement
|
||||
d = sub3(p_i, p_j)
|
||||
axial = dot3(d, z_j)
|
||||
|
||||
# Relative rotation about Z via quaternion
|
||||
# q_rel = conj(q_i_total) * q_j_total
|
||||
qi = _quat_mul_const(
|
||||
self.body_i.qw,
|
||||
self.body_i.qx,
|
||||
self.body_i.qy,
|
||||
self.body_i.qz,
|
||||
*self.marker_i_quat,
|
||||
)
|
||||
qj = _quat_mul_const(
|
||||
self.body_j.qw,
|
||||
self.body_j.qx,
|
||||
self.body_j.qy,
|
||||
self.body_j.qz,
|
||||
*self.marker_j_quat,
|
||||
)
|
||||
rel = _quat_mul_expr(qi[0], -qi[1], -qi[2], -qi[3], qj[0], qj[1], qj[2], qj[3])
|
||||
# For small angles: angle ≈ 2 * qz_rel, but qw_rel ≈ 1
|
||||
# Use sin(angle/2) form: residual = axial - pitch * 2*qz / (2*pi)
|
||||
# = axial - pitch * qz / pi
|
||||
coupling = axial - Const(self.pitch / math.pi) * rel[3]
|
||||
|
||||
return [cx, cy, cz, lx, ly, lz, coupling]
|
||||
|
||||
|
||||
class UniversalConstraint(ConstraintBase):
|
||||
"""Universal / Cardan joint — 4 DOF removed.
|
||||
|
||||
Coincident origins (3) + perpendicular Z-axes (1).
|
||||
2 DOF remain: rotation about each body's Z-axis.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
body_i: RigidBody,
|
||||
marker_i_pos: tuple[float, float, float],
|
||||
marker_i_quat: tuple[float, float, float, float],
|
||||
body_j: RigidBody,
|
||||
marker_j_pos: tuple[float, float, float],
|
||||
marker_j_quat: tuple[float, float, float, float],
|
||||
):
|
||||
self.body_i = body_i
|
||||
self.body_j = body_j
|
||||
self.marker_i_pos = marker_i_pos
|
||||
self.marker_i_quat = marker_i_quat
|
||||
self.marker_j_pos = marker_j_pos
|
||||
self.marker_j_quat = marker_j_quat
|
||||
|
||||
def residuals(self) -> List[Expr]:
|
||||
# Coincident origins
|
||||
p_i = self.body_i.world_point(*self.marker_i_pos)
|
||||
p_j = self.body_j.world_point(*self.marker_j_pos)
|
||||
pos = [p_i[0] - p_j[0], p_i[1] - p_j[1], p_i[2] - p_j[2]]
|
||||
|
||||
# Perpendicular Z-axes
|
||||
z_i = marker_z_axis(self.body_i, self.marker_i_quat)
|
||||
z_j = marker_z_axis(self.body_j, self.marker_j_quat)
|
||||
|
||||
return pos + [dot3(z_i, z_j)]
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Phase 2: Mechanical element constraints
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class GearConstraint(ConstraintBase):
|
||||
"""Gear pair or belt — 1 DOF removed.
|
||||
|
||||
Couples rotation angles: r_i * theta_i + r_j * theta_j = 0.
|
||||
For belts (same-direction rotation), r_j is passed as negative.
|
||||
|
||||
Uses the Z-component of the relative quaternion as a proxy for
|
||||
rotation angle (linear for small angles, which is the regime
|
||||
where Newton operates).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
body_i: RigidBody,
|
||||
marker_i_quat: tuple[float, float, float, float],
|
||||
body_j: RigidBody,
|
||||
marker_j_quat: tuple[float, float, float, float],
|
||||
radius_i: float,
|
||||
radius_j: float,
|
||||
):
|
||||
self.body_i = body_i
|
||||
self.body_j = body_j
|
||||
self.marker_i_quat = marker_i_quat
|
||||
self.marker_j_quat = marker_j_quat
|
||||
self.radius_i = radius_i
|
||||
self.radius_j = radius_j
|
||||
|
||||
def residuals(self) -> List[Expr]:
|
||||
# Rotation angle proxy via relative quaternion Z-component
|
||||
# For body_i: q_rel_i = conj(q_marker_i) * q_body_i * q_marker_i
|
||||
# Simplified: use 2*qz of (conj(marker) * body * marker) as angle proxy
|
||||
qz_i = _rotation_z_component(self.body_i, self.marker_i_quat)
|
||||
qz_j = _rotation_z_component(self.body_j, self.marker_j_quat)
|
||||
|
||||
# r_i * theta_i + r_j * theta_j = 0
|
||||
# Using qz as proportional to theta/2:
|
||||
# r_i * qz_i + r_j * qz_j = 0
|
||||
return [Const(self.radius_i) * qz_i + Const(self.radius_j) * qz_j]
|
||||
|
||||
|
||||
class RackPinionConstraint(ConstraintBase):
|
||||
"""Rack-and-pinion — 1 DOF removed.
|
||||
|
||||
Couples rotation of body_i to translation of body_j along marker_j Z-axis.
|
||||
translation = pitch_radius * theta
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
body_i: RigidBody,
|
||||
marker_i_pos: tuple[float, float, float],
|
||||
marker_i_quat: tuple[float, float, float, float],
|
||||
body_j: RigidBody,
|
||||
marker_j_pos: tuple[float, float, float],
|
||||
marker_j_quat: tuple[float, float, float, float],
|
||||
pitch_radius: float,
|
||||
):
|
||||
self.body_i = body_i
|
||||
self.body_j = body_j
|
||||
self.marker_i_pos = marker_i_pos
|
||||
self.marker_i_quat = marker_i_quat
|
||||
self.marker_j_pos = marker_j_pos
|
||||
self.marker_j_quat = marker_j_quat
|
||||
self.pitch_radius = pitch_radius
|
||||
|
||||
def residuals(self) -> List[Expr]:
|
||||
# Translation of j along its Z-axis
|
||||
p_i = self.body_i.world_point(*self.marker_i_pos)
|
||||
p_j = self.body_j.world_point(*self.marker_j_pos)
|
||||
z_j = marker_z_axis(self.body_j, self.marker_j_quat)
|
||||
d = sub3(p_j, p_i)
|
||||
translation = dot3(d, z_j)
|
||||
|
||||
# Rotation angle of i about its Z-axis
|
||||
qz_i = _rotation_z_component(self.body_i, self.marker_i_quat)
|
||||
|
||||
# translation - pitch_radius * angle = 0
|
||||
# angle ≈ 2 * qz, so: translation - pitch_radius * 2 * qz = 0
|
||||
return [translation - Const(2.0 * self.pitch_radius) * qz_i]
|
||||
|
||||
|
||||
class CamConstraint(ConstraintBase):
|
||||
"""Cam-follower constraint — future, stub."""
|
||||
|
||||
def residuals(self) -> List[Expr]:
|
||||
return []
|
||||
|
||||
|
||||
class SlotConstraint(ConstraintBase):
|
||||
"""Slot constraint — future, stub."""
|
||||
|
||||
def residuals(self) -> List[Expr]:
|
||||
return []
|
||||
|
||||
|
||||
class DistanceCylSphConstraint(ConstraintBase):
|
||||
"""Cylinder-sphere distance — stub.
|
||||
|
||||
Semantics depend on geometry classification; placeholder for now.
|
||||
"""
|
||||
|
||||
def residuals(self) -> List[Expr]:
|
||||
return []
|
||||
|
||||
|
||||
# -- rotation helpers for mechanical constraints ------------------------------
|
||||
|
||||
|
||||
def _rotation_z_component(
|
||||
body: RigidBody,
|
||||
marker_quat: tuple[float, float, float, float],
|
||||
) -> Expr:
|
||||
"""Extract the Z-component of the relative quaternion about a marker axis.
|
||||
|
||||
Returns the qz component of conj(q_marker) * q_body * q_marker,
|
||||
which is proportional to sin(theta/2) where theta is the rotation
|
||||
angle about the marker Z-axis.
|
||||
"""
|
||||
mw, mx, my, mz = marker_quat
|
||||
# q_local = conj(marker) * q_body * marker
|
||||
# Step 1: temp = conj(marker) * q_body
|
||||
cmw, cmx, cmy, cmz = Const(mw), Const(-mx), Const(-my), Const(-mz)
|
||||
# temp = conj(marker) * q_body
|
||||
tw = cmw * body.qw - cmx * body.qx - cmy * body.qy - cmz * body.qz
|
||||
tx = cmw * body.qx + cmx * body.qw + cmy * body.qz - cmz * body.qy
|
||||
ty = cmw * body.qy - cmx * body.qz + cmy * body.qw + cmz * body.qx
|
||||
tz = cmw * body.qz + cmx * body.qy - cmy * body.qx + cmz * body.qw
|
||||
# q_local = temp * marker
|
||||
mmw, mmx, mmy, mmz = Const(mw), Const(mx), Const(my), Const(mz)
|
||||
# rz = tw * mmz + tx * mmy - ty * mmx + tz * mmw
|
||||
rz = tw * mmz + tx * mmy - ty * mmx + tz * mmw
|
||||
return rz
|
||||
|
||||
|
||||
# -- quaternion multiplication helpers ----------------------------------------
|
||||
|
||||
|
||||
|
||||
680
kindred_solver/decompose.py
Normal file
680
kindred_solver/decompose.py
Normal file
@@ -0,0 +1,680 @@
|
||||
"""Graph decomposition for cluster-by-cluster constraint solving.
|
||||
|
||||
Builds a constraint graph from the SolveContext, decomposes it into
|
||||
biconnected components (rigid clusters), orders them via a block-cut
|
||||
tree, and solves each cluster independently. Articulation-point bodies
|
||||
are temporarily fixed when solving adjacent clusters so their solved
|
||||
values propagate as boundary conditions.
|
||||
|
||||
Requires: networkx
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib.util
|
||||
import logging
|
||||
import sys
|
||||
import types as stdlib_types
|
||||
from collections import deque
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, List
|
||||
|
||||
import networkx as nx
|
||||
|
||||
from .bfgs import bfgs_solve
|
||||
from .newton import newton_solve
|
||||
from .prepass import substitution_pass
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .constraints import ConstraintBase
|
||||
from .entities import RigidBody
|
||||
from .params import ParamTable
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# DOF table: BaseJointKind → number of residuals (= DOF removed)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Imported lazily to avoid hard kcsolve dependency in tests.
|
||||
# Use residual_count() accessor instead of this dict directly.
|
||||
_RESIDUAL_COUNT: dict[str, int] | None = None
|
||||
|
||||
|
||||
def _ensure_residual_count() -> dict:
|
||||
"""Build the residual count table on first use."""
|
||||
global _RESIDUAL_COUNT
|
||||
if _RESIDUAL_COUNT is not None:
|
||||
return _RESIDUAL_COUNT
|
||||
|
||||
import kcsolve
|
||||
|
||||
_RESIDUAL_COUNT = {
|
||||
kcsolve.BaseJointKind.Fixed: 6,
|
||||
kcsolve.BaseJointKind.Coincident: 3,
|
||||
kcsolve.BaseJointKind.Ball: 3,
|
||||
kcsolve.BaseJointKind.Revolute: 5,
|
||||
kcsolve.BaseJointKind.Cylindrical: 4,
|
||||
kcsolve.BaseJointKind.Slider: 5,
|
||||
kcsolve.BaseJointKind.Screw: 5,
|
||||
kcsolve.BaseJointKind.Universal: 4,
|
||||
kcsolve.BaseJointKind.Parallel: 2,
|
||||
kcsolve.BaseJointKind.Perpendicular: 1,
|
||||
kcsolve.BaseJointKind.Angle: 1,
|
||||
kcsolve.BaseJointKind.Concentric: 4,
|
||||
kcsolve.BaseJointKind.Tangent: 1,
|
||||
kcsolve.BaseJointKind.Planar: 3,
|
||||
kcsolve.BaseJointKind.LineInPlane: 2,
|
||||
kcsolve.BaseJointKind.PointOnLine: 2,
|
||||
kcsolve.BaseJointKind.PointInPlane: 1,
|
||||
kcsolve.BaseJointKind.DistancePointPoint: 1,
|
||||
kcsolve.BaseJointKind.Gear: 1,
|
||||
kcsolve.BaseJointKind.RackPinion: 1,
|
||||
kcsolve.BaseJointKind.Cam: 0,
|
||||
kcsolve.BaseJointKind.Slot: 0,
|
||||
kcsolve.BaseJointKind.DistanceCylSph: 0,
|
||||
}
|
||||
return _RESIDUAL_COUNT
|
||||
|
||||
|
||||
def residual_count(kind) -> int:
|
||||
"""Number of residuals a constraint type produces."""
|
||||
return _ensure_residual_count().get(kind, 0)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Standalone residual-count table (no kcsolve dependency, string-keyed)
|
||||
# Used by tests that don't have kcsolve available.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_RESIDUAL_COUNT_BY_NAME: dict[str, int] = {
|
||||
"Fixed": 6,
|
||||
"Coincident": 3,
|
||||
"Ball": 3,
|
||||
"Revolute": 5,
|
||||
"Cylindrical": 4,
|
||||
"Slider": 5,
|
||||
"Screw": 5,
|
||||
"Universal": 4,
|
||||
"Parallel": 2,
|
||||
"Perpendicular": 1,
|
||||
"Angle": 1,
|
||||
"Concentric": 4,
|
||||
"Tangent": 1,
|
||||
"Planar": 3,
|
||||
"LineInPlane": 2,
|
||||
"PointOnLine": 2,
|
||||
"PointInPlane": 1,
|
||||
"DistancePointPoint": 1,
|
||||
"Gear": 1,
|
||||
"RackPinion": 1,
|
||||
"Cam": 0,
|
||||
"Slot": 0,
|
||||
"DistanceCylSph": 0,
|
||||
}
|
||||
|
||||
|
||||
def residual_count_by_name(kind_name: str) -> int:
|
||||
"""Number of residuals by constraint type name (no kcsolve needed)."""
|
||||
return _RESIDUAL_COUNT_BY_NAME.get(kind_name, 0)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Data structures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
|
||||
class SolveCluster:
|
||||
"""A cluster of bodies to solve together."""
|
||||
|
||||
bodies: set[str] # Body IDs in this cluster
|
||||
constraint_indices: list[int] # Indices into the constraint list
|
||||
boundary_bodies: set[str] # Articulation points shared with other clusters
|
||||
has_ground: bool # Whether any body in the cluster is grounded
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Graph construction
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def build_constraint_graph(
|
||||
constraints: list,
|
||||
grounded_bodies: set[str],
|
||||
) -> nx.MultiGraph:
|
||||
"""Build a body-level constraint multigraph.
|
||||
|
||||
Nodes: part_id strings (one per body referenced by constraints).
|
||||
Edges: one per active constraint with attributes:
|
||||
- constraint_index: position in the constraints list
|
||||
- weight: number of residuals
|
||||
|
||||
Grounded bodies are tagged with ``grounded=True``.
|
||||
Constraints with 0 residuals (stubs) are excluded.
|
||||
"""
|
||||
G = nx.MultiGraph()
|
||||
|
||||
for idx, c in enumerate(constraints):
|
||||
if not c.activated:
|
||||
continue
|
||||
weight = residual_count(c.type)
|
||||
if weight == 0:
|
||||
continue
|
||||
part_i = c.part_i
|
||||
part_j = c.part_j
|
||||
# Ensure nodes exist
|
||||
if part_i not in G:
|
||||
G.add_node(part_i, grounded=(part_i in grounded_bodies))
|
||||
if part_j not in G:
|
||||
G.add_node(part_j, grounded=(part_j in grounded_bodies))
|
||||
# Store kind_name for pebble game integration
|
||||
kind_name = c.type.name if hasattr(c.type, "name") else str(c.type)
|
||||
G.add_edge(
|
||||
part_i, part_j, constraint_index=idx, weight=weight, kind_name=kind_name
|
||||
)
|
||||
|
||||
return G
|
||||
|
||||
|
||||
def build_constraint_graph_simple(
|
||||
edges: list[tuple[str, str, str, int]],
|
||||
grounded: set[str] | None = None,
|
||||
) -> nx.MultiGraph:
|
||||
"""Build a constraint graph from simple edge tuples (for testing).
|
||||
|
||||
Each edge is ``(body_i, body_j, kind_name, constraint_index)``.
|
||||
"""
|
||||
grounded = grounded or set()
|
||||
G = nx.MultiGraph()
|
||||
for body_i, body_j, kind_name, idx in edges:
|
||||
weight = residual_count_by_name(kind_name)
|
||||
if weight == 0:
|
||||
continue
|
||||
if body_i not in G:
|
||||
G.add_node(body_i, grounded=(body_i in grounded))
|
||||
if body_j not in G:
|
||||
G.add_node(body_j, grounded=(body_j in grounded))
|
||||
G.add_edge(
|
||||
body_i, body_j, constraint_index=idx, weight=weight, kind_name=kind_name
|
||||
)
|
||||
return G
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Decomposition
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def find_clusters(
|
||||
G: nx.MultiGraph,
|
||||
) -> tuple[list[set[str]], set[str]]:
|
||||
"""Find biconnected components and articulation points.
|
||||
|
||||
Returns:
|
||||
clusters: list of body-ID sets (one per biconnected component)
|
||||
articulation_points: body-IDs shared between clusters
|
||||
"""
|
||||
# biconnected_components requires a simple Graph
|
||||
simple = nx.Graph(G)
|
||||
clusters = [set(c) for c in nx.biconnected_components(simple)]
|
||||
artic = set(nx.articulation_points(simple))
|
||||
return clusters, artic
|
||||
|
||||
|
||||
def build_solve_order(
|
||||
G: nx.MultiGraph,
|
||||
clusters: list[set[str]],
|
||||
articulation_points: set[str],
|
||||
grounded_bodies: set[str],
|
||||
) -> list[SolveCluster]:
|
||||
"""Order clusters for solving via the block-cut tree.
|
||||
|
||||
Builds the block-cut tree (bipartite graph of clusters and
|
||||
articulation points), roots it at a grounded cluster, and returns
|
||||
clusters in root-to-leaf order (grounded first, outward to leaves).
|
||||
This ensures boundary bodies are solved before clusters that
|
||||
depend on them.
|
||||
"""
|
||||
if not clusters:
|
||||
return []
|
||||
|
||||
# Single cluster — no ordering needed
|
||||
if len(clusters) == 1:
|
||||
bodies = clusters[0]
|
||||
indices = _constraints_for_bodies(G, bodies)
|
||||
has_ground = bool(bodies & grounded_bodies)
|
||||
return [
|
||||
SolveCluster(
|
||||
bodies=bodies,
|
||||
constraint_indices=indices,
|
||||
boundary_bodies=set(),
|
||||
has_ground=has_ground,
|
||||
)
|
||||
]
|
||||
|
||||
# Build block-cut tree
|
||||
# Nodes: ("C", i) for cluster i, ("A", body_id) for articulation points
|
||||
bct = nx.Graph()
|
||||
for i, cluster in enumerate(clusters):
|
||||
bct.add_node(("C", i))
|
||||
for ap in articulation_points:
|
||||
if ap in cluster:
|
||||
bct.add_edge(("C", i), ("A", ap))
|
||||
|
||||
# Find root: prefer a cluster containing a grounded body
|
||||
root = ("C", 0)
|
||||
for i, cluster in enumerate(clusters):
|
||||
if cluster & grounded_bodies:
|
||||
root = ("C", i)
|
||||
break
|
||||
|
||||
# BFS from root: grounded cluster first, outward to leaves
|
||||
visited = set()
|
||||
order = []
|
||||
queue = deque([root])
|
||||
visited.add(root)
|
||||
while queue:
|
||||
node = queue.popleft()
|
||||
if node[0] == "C":
|
||||
order.append(node[1])
|
||||
for neighbor in bct.neighbors(node):
|
||||
if neighbor not in visited:
|
||||
visited.add(neighbor)
|
||||
queue.append(neighbor)
|
||||
|
||||
# Build SolveCluster objects
|
||||
solve_clusters = []
|
||||
for i in order:
|
||||
bodies = clusters[i]
|
||||
indices = _constraints_for_bodies(G, bodies)
|
||||
boundary = bodies & articulation_points
|
||||
has_ground = bool(bodies & grounded_bodies)
|
||||
solve_clusters.append(
|
||||
SolveCluster(
|
||||
bodies=bodies,
|
||||
constraint_indices=indices,
|
||||
boundary_bodies=boundary,
|
||||
has_ground=has_ground,
|
||||
)
|
||||
)
|
||||
|
||||
return solve_clusters
|
||||
|
||||
|
||||
def _constraints_for_bodies(G: nx.MultiGraph, bodies: set[str]) -> list[int]:
|
||||
"""Collect constraint indices for edges where both endpoints are in bodies."""
|
||||
indices = []
|
||||
seen = set()
|
||||
for u, v, data in G.edges(data=True):
|
||||
idx = data["constraint_index"]
|
||||
if idx in seen:
|
||||
continue
|
||||
if u in bodies and v in bodies:
|
||||
seen.add(idx)
|
||||
indices.append(idx)
|
||||
return sorted(indices)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Top-level decompose entry point
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def decompose(
|
||||
constraints: list,
|
||||
grounded_bodies: set[str],
|
||||
) -> list[SolveCluster]:
|
||||
"""Full decomposition pipeline: graph → clusters → solve order.
|
||||
|
||||
Returns a list of SolveCluster in solve order (leaves first).
|
||||
If the system is a single cluster, returns a 1-element list.
|
||||
"""
|
||||
G = build_constraint_graph(constraints, grounded_bodies)
|
||||
|
||||
# Handle disconnected sub-assemblies
|
||||
all_clusters = []
|
||||
for component_nodes in nx.connected_components(G):
|
||||
sub = G.subgraph(component_nodes).copy()
|
||||
clusters, artic = find_clusters(sub)
|
||||
|
||||
if len(clusters) <= 1:
|
||||
# Single cluster in this component
|
||||
bodies = component_nodes if not clusters else clusters[0]
|
||||
indices = _constraints_for_bodies(sub, bodies)
|
||||
has_ground = bool(bodies & grounded_bodies)
|
||||
all_clusters.append(
|
||||
SolveCluster(
|
||||
bodies=set(bodies),
|
||||
constraint_indices=indices,
|
||||
boundary_bodies=set(),
|
||||
has_ground=has_ground,
|
||||
)
|
||||
)
|
||||
else:
|
||||
ordered = build_solve_order(sub, clusters, artic, grounded_bodies)
|
||||
all_clusters.extend(ordered)
|
||||
|
||||
return all_clusters
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Cluster solver
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def solve_decomposed(
|
||||
clusters: list[SolveCluster],
|
||||
bodies: dict[str, "RigidBody"],
|
||||
constraint_objs: list["ConstraintBase"],
|
||||
constraint_indices_map: list[int],
|
||||
params: "ParamTable",
|
||||
) -> bool:
|
||||
"""Solve clusters in order, fixing boundary bodies between solves.
|
||||
|
||||
Args:
|
||||
clusters: SolveCluster list in solve order (from decompose()).
|
||||
bodies: part_id → RigidBody mapping.
|
||||
constraint_objs: constraint objects (parallel to constraint_indices_map).
|
||||
constraint_indices_map: for each constraint_obj, its index in ctx.constraints.
|
||||
params: shared ParamTable.
|
||||
|
||||
Returns True if all clusters converged.
|
||||
"""
|
||||
log.info(
|
||||
"solve_decomposed: %d clusters, %d bodies, %d constraints",
|
||||
len(clusters),
|
||||
len(bodies),
|
||||
len(constraint_objs),
|
||||
)
|
||||
|
||||
# Build reverse map: constraint_index → position in constraint_objs list
|
||||
idx_to_obj: dict[int, "ConstraintBase"] = {}
|
||||
for pos, ci in enumerate(constraint_indices_map):
|
||||
idx_to_obj[ci] = constraint_objs[pos]
|
||||
|
||||
solved_bodies: set[str] = set()
|
||||
all_converged = True
|
||||
|
||||
for cluster_idx, cluster in enumerate(clusters):
|
||||
# 1. Fix boundary bodies that were already solved
|
||||
fixed_boundary_params: list[str] = []
|
||||
for body_id in cluster.boundary_bodies:
|
||||
if body_id in solved_bodies:
|
||||
body = bodies[body_id]
|
||||
for pname in body._param_names:
|
||||
if not params.is_fixed(pname):
|
||||
params.fix(pname)
|
||||
fixed_boundary_params.append(pname)
|
||||
|
||||
# 2. Collect residuals for this cluster
|
||||
cluster_residuals = []
|
||||
for ci in cluster.constraint_indices:
|
||||
obj = idx_to_obj.get(ci)
|
||||
if obj is not None:
|
||||
cluster_residuals.extend(obj.residuals())
|
||||
|
||||
# 3. Add quat norm residuals for free, non-grounded bodies in this cluster
|
||||
quat_groups = []
|
||||
for body_id in cluster.bodies:
|
||||
body = bodies[body_id]
|
||||
if body.grounded:
|
||||
continue
|
||||
if body_id in cluster.boundary_bodies and body_id in solved_bodies:
|
||||
continue # Already fixed as boundary
|
||||
cluster_residuals.append(body.quat_norm_residual())
|
||||
quat_groups.append(body.quat_param_names())
|
||||
|
||||
# 4. Substitution pass (compiles fixed boundary params to constants)
|
||||
cluster_residuals = substitution_pass(cluster_residuals, params)
|
||||
|
||||
# 5. Newton solve (+ BFGS fallback)
|
||||
if cluster_residuals:
|
||||
log.debug(
|
||||
" cluster[%d]: %d bodies (%d boundary), %d constraints, %d residuals",
|
||||
cluster_idx,
|
||||
len(cluster.bodies),
|
||||
len(cluster.boundary_bodies),
|
||||
len(cluster.constraint_indices),
|
||||
len(cluster_residuals),
|
||||
)
|
||||
converged = newton_solve(
|
||||
cluster_residuals,
|
||||
params,
|
||||
quat_groups=quat_groups,
|
||||
max_iter=100,
|
||||
tol=1e-10,
|
||||
)
|
||||
if not converged:
|
||||
log.info(
|
||||
" cluster[%d]: Newton-Raphson failed, trying BFGS", cluster_idx
|
||||
)
|
||||
converged = bfgs_solve(
|
||||
cluster_residuals,
|
||||
params,
|
||||
quat_groups=quat_groups,
|
||||
max_iter=200,
|
||||
tol=1e-10,
|
||||
)
|
||||
if not converged:
|
||||
log.warning(" cluster[%d]: failed to converge", cluster_idx)
|
||||
all_converged = False
|
||||
|
||||
# 6. Mark this cluster's bodies as solved
|
||||
solved_bodies.update(cluster.bodies)
|
||||
|
||||
# 7. Unfix boundary params
|
||||
for pname in fixed_boundary_params:
|
||||
params.unfix(pname)
|
||||
|
||||
return all_converged
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pebble game integration (rigidity classification)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_PEBBLE_MODULES_LOADED = False
|
||||
_PebbleGame3D = None
|
||||
_PebbleJointType = None
|
||||
_PebbleJoint = None
|
||||
|
||||
|
||||
def _load_pebble_modules():
|
||||
"""Lazily load PebbleGame3D and related types from GNN/solver/datagen/.
|
||||
|
||||
The GNN package has its own import structure (``from solver.datagen.types
|
||||
import ...``) that conflicts with the top-level module layout, so we
|
||||
register shim modules in ``sys.modules`` to make it work.
|
||||
"""
|
||||
global _PEBBLE_MODULES_LOADED, _PebbleGame3D, _PebbleJointType, _PebbleJoint
|
||||
if _PEBBLE_MODULES_LOADED:
|
||||
return
|
||||
|
||||
# Find GNN/solver/datagen relative to this package
|
||||
pkg_dir = Path(__file__).resolve().parent.parent # mods/solver/
|
||||
datagen_dir = pkg_dir / "GNN" / "solver" / "datagen"
|
||||
|
||||
if not datagen_dir.exists():
|
||||
log.warning("GNN/solver/datagen/ not found; pebble game unavailable")
|
||||
_PEBBLE_MODULES_LOADED = True
|
||||
return
|
||||
|
||||
# Register shim modules so ``from solver.datagen.types import ...`` works
|
||||
if "solver" not in sys.modules:
|
||||
sys.modules["solver"] = stdlib_types.ModuleType("solver")
|
||||
if "solver.datagen" not in sys.modules:
|
||||
dg = stdlib_types.ModuleType("solver.datagen")
|
||||
sys.modules["solver.datagen"] = dg
|
||||
sys.modules["solver"].datagen = dg # type: ignore[attr-defined]
|
||||
|
||||
# Load types.py
|
||||
types_path = datagen_dir / "types.py"
|
||||
spec_t = importlib.util.spec_from_file_location(
|
||||
"solver.datagen.types", str(types_path)
|
||||
)
|
||||
types_mod = importlib.util.module_from_spec(spec_t)
|
||||
sys.modules["solver.datagen.types"] = types_mod
|
||||
spec_t.loader.exec_module(types_mod)
|
||||
|
||||
# Load pebble_game.py
|
||||
pg_path = datagen_dir / "pebble_game.py"
|
||||
spec_p = importlib.util.spec_from_file_location(
|
||||
"solver.datagen.pebble_game", str(pg_path)
|
||||
)
|
||||
pg_mod = importlib.util.module_from_spec(spec_p)
|
||||
sys.modules["solver.datagen.pebble_game"] = pg_mod
|
||||
spec_p.loader.exec_module(pg_mod)
|
||||
|
||||
_PebbleGame3D = pg_mod.PebbleGame3D
|
||||
_PebbleJointType = types_mod.JointType
|
||||
_PebbleJoint = types_mod.Joint
|
||||
_PEBBLE_MODULES_LOADED = True
|
||||
|
||||
|
||||
# BaseJointKind name → PebbleGame JointType name.
|
||||
# Types not listed here use manual edge insertion with the residual count.
|
||||
_KIND_NAME_TO_PEBBLE_NAME: dict[str, str] = {
|
||||
"Fixed": "FIXED",
|
||||
"Coincident": "BALL", # Same DOF count (3)
|
||||
"Ball": "BALL",
|
||||
"Revolute": "REVOLUTE",
|
||||
"Cylindrical": "CYLINDRICAL",
|
||||
"Slider": "SLIDER",
|
||||
"Screw": "SCREW",
|
||||
"Universal": "UNIVERSAL",
|
||||
"Planar": "PLANAR",
|
||||
"Perpendicular": "PERPENDICULAR",
|
||||
"DistancePointPoint": "DISTANCE",
|
||||
}
|
||||
# Parallel: pebble game uses 3 DOF, but our solver uses 2.
|
||||
# We handle it with manual edge insertion.
|
||||
|
||||
# Types that need manual edge insertion (no direct JointType mapping,
|
||||
# or DOF mismatch like Parallel).
|
||||
_MANUAL_EDGE_TYPES: set[str] = {
|
||||
"Parallel", # 2 residuals, but JointType.PARALLEL = 3
|
||||
"Angle", # 1 residual, no JointType
|
||||
"Concentric", # 4 residuals, no JointType
|
||||
"Tangent", # 1 residual, no JointType
|
||||
"LineInPlane", # 2 residuals, no JointType
|
||||
"PointOnLine", # 2 residuals, no JointType
|
||||
"PointInPlane", # 1 residual, no JointType
|
||||
"Gear", # 1 residual, no JointType
|
||||
"RackPinion", # 1 residual, no JointType
|
||||
}
|
||||
|
||||
_GROUND_BODY_ID = -1
|
||||
|
||||
|
||||
def classify_cluster_rigidity(
|
||||
cluster: SolveCluster,
|
||||
constraint_graph: nx.MultiGraph,
|
||||
grounded_bodies: set[str],
|
||||
) -> str | None:
|
||||
"""Run pebble game on a cluster and return rigidity classification.
|
||||
|
||||
Returns one of: "well-constrained", "underconstrained",
|
||||
"overconstrained", "mixed", or None if pebble game unavailable.
|
||||
"""
|
||||
import numpy as np
|
||||
|
||||
_load_pebble_modules()
|
||||
if _PebbleGame3D is None:
|
||||
return None
|
||||
|
||||
pg = _PebbleGame3D()
|
||||
|
||||
# Map string body IDs → integer IDs for pebble game
|
||||
body_list = sorted(cluster.bodies)
|
||||
body_to_int: dict[str, int] = {b: i for i, b in enumerate(body_list)}
|
||||
|
||||
for b in body_list:
|
||||
pg.add_body(body_to_int[b])
|
||||
|
||||
# Add virtual ground body if cluster has grounded bodies
|
||||
has_ground = bool(cluster.bodies & grounded_bodies)
|
||||
if has_ground:
|
||||
pg.add_body(_GROUND_BODY_ID)
|
||||
for b in cluster.bodies & grounded_bodies:
|
||||
ground_joint = _PebbleJoint(
|
||||
joint_id=-1,
|
||||
body_a=body_to_int[b],
|
||||
body_b=_GROUND_BODY_ID,
|
||||
joint_type=_PebbleJointType["FIXED"],
|
||||
anchor_a=np.zeros(3),
|
||||
anchor_b=np.zeros(3),
|
||||
)
|
||||
pg.add_joint(ground_joint)
|
||||
|
||||
# Add constraint edges
|
||||
joint_counter = 0
|
||||
zero = np.zeros(3)
|
||||
for u, v, data in constraint_graph.edges(data=True):
|
||||
if u not in cluster.bodies or v not in cluster.bodies:
|
||||
continue
|
||||
ci = data["constraint_index"]
|
||||
if ci not in cluster.constraint_indices:
|
||||
continue
|
||||
|
||||
# Determine the constraint kind name from the graph edge
|
||||
kind_name = data.get("kind_name", "")
|
||||
n_residuals = data.get("weight", 0)
|
||||
|
||||
if not kind_name or n_residuals == 0:
|
||||
continue
|
||||
|
||||
int_u = body_to_int[u]
|
||||
int_v = body_to_int[v]
|
||||
|
||||
pebble_name = _KIND_NAME_TO_PEBBLE_NAME.get(kind_name)
|
||||
if pebble_name and kind_name not in _MANUAL_EDGE_TYPES:
|
||||
# Direct JointType mapping
|
||||
jt = _PebbleJointType[pebble_name]
|
||||
joint = _PebbleJoint(
|
||||
joint_id=joint_counter,
|
||||
body_a=int_u,
|
||||
body_b=int_v,
|
||||
joint_type=jt,
|
||||
anchor_a=zero,
|
||||
anchor_b=zero,
|
||||
)
|
||||
pg.add_joint(joint)
|
||||
joint_counter += 1
|
||||
else:
|
||||
# Manual edge insertion: one DISTANCE edge per residual
|
||||
for _ in range(n_residuals):
|
||||
joint = _PebbleJoint(
|
||||
joint_id=joint_counter,
|
||||
body_a=int_u,
|
||||
body_b=int_v,
|
||||
joint_type=_PebbleJointType["DISTANCE"],
|
||||
anchor_a=zero,
|
||||
anchor_b=zero,
|
||||
)
|
||||
pg.add_joint(joint)
|
||||
joint_counter += 1
|
||||
|
||||
# Classify using raw pebble counts (adjusting for virtual ground)
|
||||
total_dof = pg.get_dof()
|
||||
redundant = pg.get_redundant_count()
|
||||
|
||||
# The virtual ground body contributes 6 pebbles that are never consumed.
|
||||
# Subtract them to get the effective DOF.
|
||||
if has_ground:
|
||||
total_dof -= 6 # virtual ground's unconstrained pebbles
|
||||
baseline = 0
|
||||
else:
|
||||
baseline = 6 # trivial rigid-body motion
|
||||
|
||||
if redundant > 0 and total_dof > baseline:
|
||||
return "mixed"
|
||||
elif redundant > 0:
|
||||
return "overconstrained"
|
||||
elif total_dof > baseline:
|
||||
return "underconstrained"
|
||||
elif total_dof == baseline:
|
||||
return "well-constrained"
|
||||
else:
|
||||
return "overconstrained"
|
||||
312
kindred_solver/diagnostics.py
Normal file
312
kindred_solver/diagnostics.py
Normal file
@@ -0,0 +1,312 @@
|
||||
"""Per-entity DOF diagnostics and overconstrained detection.
|
||||
|
||||
Provides per-body remaining degrees of freedom, human-readable free
|
||||
motion labels, and redundant/conflicting constraint identification.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .entities import RigidBody
|
||||
from .expr import Expr
|
||||
from .params import ParamTable
|
||||
|
||||
# -- Per-entity DOF -----------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
|
||||
class EntityDOF:
|
||||
"""DOF report for a single entity (rigid body)."""
|
||||
|
||||
entity_id: str
|
||||
remaining_dof: int # 0 = well-constrained
|
||||
free_motions: list[str] = field(default_factory=list)
|
||||
|
||||
|
||||
def per_entity_dof(
|
||||
residuals: list[Expr],
|
||||
params: ParamTable,
|
||||
bodies: dict[str, RigidBody],
|
||||
rank_tol: float = 1e-8,
|
||||
jac_exprs: "list[list[Expr]] | None" = None,
|
||||
) -> list[EntityDOF]:
|
||||
"""Compute remaining DOF for each non-grounded body.
|
||||
|
||||
For each body, extracts the Jacobian columns corresponding to its
|
||||
7 parameters, performs SVD to find constrained directions, and
|
||||
classifies null-space vectors as translations or rotations.
|
||||
"""
|
||||
free = params.free_names()
|
||||
n_res = len(residuals)
|
||||
env = params.get_env()
|
||||
|
||||
if n_res == 0:
|
||||
# No constraints — every free body has 6 DOF
|
||||
result = []
|
||||
for pid, body in bodies.items():
|
||||
if body.grounded:
|
||||
continue
|
||||
result.append(
|
||||
EntityDOF(
|
||||
entity_id=pid,
|
||||
remaining_dof=6,
|
||||
free_motions=[
|
||||
"translation along X",
|
||||
"translation along Y",
|
||||
"translation along Z",
|
||||
"rotation about X",
|
||||
"rotation about Y",
|
||||
"rotation about Z",
|
||||
],
|
||||
)
|
||||
)
|
||||
return result
|
||||
|
||||
# Build column index mapping: param_name -> column index in free list
|
||||
free_index = {name: i for i, name in enumerate(free)}
|
||||
|
||||
# Build full Jacobian (for efficiency, compute once)
|
||||
n_free = len(free)
|
||||
J_full = np.empty((n_res, n_free))
|
||||
if jac_exprs is not None:
|
||||
for i in range(n_res):
|
||||
for j in range(n_free):
|
||||
J_full[i, j] = jac_exprs[i][j].eval(env)
|
||||
else:
|
||||
for i, r in enumerate(residuals):
|
||||
for j, name in enumerate(free):
|
||||
J_full[i, j] = r.diff(name).simplify().eval(env)
|
||||
|
||||
result = []
|
||||
for pid, body in bodies.items():
|
||||
if body.grounded:
|
||||
continue
|
||||
|
||||
# Find column indices for this body's params
|
||||
pfx = pid + "/"
|
||||
body_param_names = [
|
||||
pfx + "tx",
|
||||
pfx + "ty",
|
||||
pfx + "tz",
|
||||
pfx + "qw",
|
||||
pfx + "qx",
|
||||
pfx + "qy",
|
||||
pfx + "qz",
|
||||
]
|
||||
col_indices = [free_index[n] for n in body_param_names if n in free_index]
|
||||
|
||||
if not col_indices:
|
||||
# All params fixed (shouldn't happen for non-grounded, but be safe)
|
||||
result.append(EntityDOF(entity_id=pid, remaining_dof=0))
|
||||
continue
|
||||
|
||||
# Extract submatrix: all residual rows, only this body's columns
|
||||
J_sub = J_full[:, col_indices]
|
||||
|
||||
# SVD
|
||||
U, sv, Vt = np.linalg.svd(J_sub, full_matrices=True)
|
||||
constrained = int(np.sum(sv > rank_tol))
|
||||
|
||||
# Subtract 1 for the quaternion unit-norm constraint (already in residuals)
|
||||
# The quat norm residual constrains 1 direction in the 7-D param space,
|
||||
# so effective body DOF = 7 - 1 - constrained_by_other_constraints.
|
||||
# But the quat norm IS one of the residual rows, so it's already counted
|
||||
# in `constrained`. So: remaining = len(col_indices) - constrained
|
||||
# But the quat norm takes 1 from 7 → 6 geometric DOF, and constrained
|
||||
# includes the quat norm row. So remaining = 7 - constrained, which gives
|
||||
# geometric remaining DOF directly.
|
||||
remaining = len(col_indices) - constrained
|
||||
|
||||
# Classify null-space vectors as free motions
|
||||
free_motions = []
|
||||
if remaining > 0 and Vt.shape[0] > constrained:
|
||||
null_space = Vt[constrained:] # rows = null vectors in param space
|
||||
|
||||
# Map column indices back to param types
|
||||
param_types = []
|
||||
for n in body_param_names:
|
||||
if n in free_index:
|
||||
if n.endswith(("/tx", "/ty", "/tz")):
|
||||
param_types.append("t")
|
||||
else:
|
||||
param_types.append("q")
|
||||
|
||||
for null_vec in null_space:
|
||||
label = _classify_motion(
|
||||
null_vec, param_types, body_param_names, free_index
|
||||
)
|
||||
if label:
|
||||
free_motions.append(label)
|
||||
|
||||
result.append(
|
||||
EntityDOF(
|
||||
entity_id=pid,
|
||||
remaining_dof=remaining,
|
||||
free_motions=free_motions,
|
||||
)
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def _classify_motion(
|
||||
null_vec: np.ndarray,
|
||||
param_types: list[str],
|
||||
body_param_names: list[str],
|
||||
free_index: dict[str, int],
|
||||
) -> str:
|
||||
"""Classify a null-space vector as translation, rotation, or helical."""
|
||||
# Split components into translation and rotation parts
|
||||
trans_indices = [i for i, t in enumerate(param_types) if t == "t"]
|
||||
rot_indices = [i for i, t in enumerate(param_types) if t == "q"]
|
||||
|
||||
trans_norm = np.linalg.norm(null_vec[trans_indices]) if trans_indices else 0.0
|
||||
rot_norm = np.linalg.norm(null_vec[rot_indices]) if rot_indices else 0.0
|
||||
|
||||
total = trans_norm + rot_norm
|
||||
if total < 1e-14:
|
||||
return ""
|
||||
|
||||
trans_frac = trans_norm / total
|
||||
rot_frac = rot_norm / total
|
||||
|
||||
# Determine dominant axis
|
||||
if trans_frac > 0.8:
|
||||
# Pure translation
|
||||
axis = _dominant_axis(null_vec, trans_indices)
|
||||
return f"translation along {axis}"
|
||||
elif rot_frac > 0.8:
|
||||
# Pure rotation
|
||||
axis = _dominant_axis(null_vec, rot_indices)
|
||||
return f"rotation about {axis}"
|
||||
else:
|
||||
# Mixed — helical
|
||||
axis = _dominant_axis(null_vec, trans_indices)
|
||||
return f"helical motion along {axis}"
|
||||
|
||||
|
||||
def _dominant_axis(vec: np.ndarray, indices: list[int]) -> str:
|
||||
"""Find the dominant axis (X/Y/Z) among the given component indices."""
|
||||
if not indices:
|
||||
return "?"
|
||||
components = np.abs(vec[indices])
|
||||
# Map to axis names — first 3 in group are X/Y/Z
|
||||
axis_names = ["X", "Y", "Z"]
|
||||
if len(components) >= 3:
|
||||
idx = int(np.argmax(components[:3]))
|
||||
return axis_names[idx]
|
||||
elif len(components) == 1:
|
||||
return axis_names[0]
|
||||
else:
|
||||
idx = int(np.argmax(components))
|
||||
return axis_names[min(idx, 2)]
|
||||
|
||||
|
||||
# -- Overconstrained detection ------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConstraintDiag:
|
||||
"""Diagnostic for a single constraint."""
|
||||
|
||||
constraint_index: int
|
||||
kind: str # "redundant" | "conflicting"
|
||||
detail: str
|
||||
|
||||
|
||||
def find_overconstrained(
|
||||
residuals: list[Expr],
|
||||
params: ParamTable,
|
||||
residual_ranges: list[tuple[int, int, int]],
|
||||
rank_tol: float = 1e-8,
|
||||
jac_exprs: "list[list[Expr]] | None" = None,
|
||||
) -> list[ConstraintDiag]:
|
||||
"""Identify redundant and conflicting constraints.
|
||||
|
||||
Algorithm (following SolvSpace's FindWhichToRemoveToFixJacobian):
|
||||
1. Build full Jacobian J, compute rank.
|
||||
2. If rank == n_residuals, not overconstrained — return empty.
|
||||
3. For each constraint: remove its rows, check if rank is preserved
|
||||
→ if so, the constraint is **redundant**.
|
||||
4. Compute left null space, project residual vector F → if a
|
||||
constraint's residuals contribute to this projection, it is
|
||||
**conflicting** (redundant + unsatisfied).
|
||||
"""
|
||||
free = params.free_names()
|
||||
n_free = len(free)
|
||||
n_res = len(residuals)
|
||||
|
||||
if n_free == 0 or n_res == 0:
|
||||
return []
|
||||
|
||||
env = params.get_env()
|
||||
|
||||
# Build Jacobian and residual vector
|
||||
J = np.empty((n_res, n_free))
|
||||
r_vec = np.empty(n_res)
|
||||
for i, r in enumerate(residuals):
|
||||
r_vec[i] = r.eval(env)
|
||||
if jac_exprs is not None:
|
||||
for i in range(n_res):
|
||||
for j in range(n_free):
|
||||
J[i, j] = jac_exprs[i][j].eval(env)
|
||||
else:
|
||||
for i, r in enumerate(residuals):
|
||||
for j, name in enumerate(free):
|
||||
J[i, j] = r.diff(name).simplify().eval(env)
|
||||
|
||||
# Full rank
|
||||
sv_full = np.linalg.svd(J, compute_uv=False)
|
||||
full_rank = int(np.sum(sv_full > rank_tol))
|
||||
|
||||
if full_rank >= n_res:
|
||||
return [] # not overconstrained
|
||||
|
||||
# Left null space: columns of U beyond rank
|
||||
U, sv, Vt = np.linalg.svd(J, full_matrices=True)
|
||||
left_null = U[:, full_rank:] # shape (n_res, n_res - rank)
|
||||
|
||||
# Project residual onto left null space
|
||||
null_residual = left_null.T @ r_vec # shape (n_res - rank,)
|
||||
residual_projection = left_null @ null_residual # back to residual space
|
||||
|
||||
diags = []
|
||||
for start, end, c_idx in residual_ranges:
|
||||
# Remove this constraint's rows and check rank
|
||||
mask = np.ones(n_res, dtype=bool)
|
||||
mask[start:end] = False
|
||||
J_reduced = J[mask]
|
||||
|
||||
if J_reduced.shape[0] == 0:
|
||||
continue
|
||||
|
||||
sv_reduced = np.linalg.svd(J_reduced, compute_uv=False)
|
||||
reduced_rank = int(np.sum(sv_reduced > rank_tol))
|
||||
|
||||
if reduced_rank >= full_rank:
|
||||
# Removing this constraint preserves rank → redundant
|
||||
# Check if it's also conflicting (contributes to unsatisfied null projection)
|
||||
constraint_proj = np.linalg.norm(residual_projection[start:end])
|
||||
if constraint_proj > rank_tol:
|
||||
kind = "conflicting"
|
||||
detail = (
|
||||
f"Constraint {c_idx} is conflicting (redundant and unsatisfied)"
|
||||
)
|
||||
else:
|
||||
kind = "redundant"
|
||||
detail = (
|
||||
f"Constraint {c_idx} is redundant (can be removed without effect)"
|
||||
)
|
||||
diags.append(
|
||||
ConstraintDiag(
|
||||
constraint_index=c_idx,
|
||||
kind=kind,
|
||||
detail=detail,
|
||||
)
|
||||
)
|
||||
|
||||
return diags
|
||||
@@ -14,11 +14,15 @@ def count_dof(
|
||||
residuals: List[Expr],
|
||||
params: ParamTable,
|
||||
rank_tol: float = 1e-8,
|
||||
jac_exprs: "List[List[Expr]] | None" = None,
|
||||
) -> int:
|
||||
"""Compute DOF = n_free_params - rank(Jacobian).
|
||||
|
||||
Evaluates the Jacobian numerically at the current parameter values
|
||||
and computes its rank via SVD.
|
||||
|
||||
When *jac_exprs* is provided, reuses the pre-built symbolic
|
||||
Jacobian instead of re-differentiating every residual.
|
||||
"""
|
||||
free = params.free_names()
|
||||
n_free = len(free)
|
||||
@@ -32,9 +36,14 @@ def count_dof(
|
||||
env = params.get_env()
|
||||
|
||||
J = np.empty((n_res, n_free))
|
||||
for i, r in enumerate(residuals):
|
||||
for j, name in enumerate(free):
|
||||
J[i, j] = r.diff(name).simplify().eval(env)
|
||||
if jac_exprs is not None:
|
||||
for i in range(n_res):
|
||||
for j in range(n_free):
|
||||
J[i, j] = jac_exprs[i][j].eval(env)
|
||||
else:
|
||||
for i, r in enumerate(residuals):
|
||||
for j, name in enumerate(free):
|
||||
J[i, j] = r.diff(name).simplify().eval(env)
|
||||
|
||||
if J.size == 0:
|
||||
return n_free
|
||||
|
||||
@@ -24,6 +24,16 @@ class Expr:
|
||||
"""Return the set of variable names in this expression."""
|
||||
raise NotImplementedError
|
||||
|
||||
def to_code(self) -> str:
|
||||
"""Emit a Python arithmetic expression string.
|
||||
|
||||
The returned string, when evaluated with a dict ``env`` mapping
|
||||
parameter names to floats (and ``_sin``, ``_cos``, ``_sqrt``
|
||||
bound to their ``math`` equivalents), produces the same result
|
||||
as ``self.eval(env)``.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
# -- operator overloads --------------------------------------------------
|
||||
|
||||
def __add__(self, other):
|
||||
@@ -90,6 +100,9 @@ class Const(Expr):
|
||||
def vars(self):
|
||||
return set()
|
||||
|
||||
def to_code(self):
|
||||
return repr(self.value)
|
||||
|
||||
def __repr__(self):
|
||||
return f"Const({self.value})"
|
||||
|
||||
@@ -118,6 +131,9 @@ class Var(Expr):
|
||||
def vars(self):
|
||||
return {self.name}
|
||||
|
||||
def to_code(self):
|
||||
return f"env[{self.name!r}]"
|
||||
|
||||
def __repr__(self):
|
||||
return f"Var({self.name!r})"
|
||||
|
||||
@@ -154,6 +170,9 @@ class Neg(Expr):
|
||||
def vars(self):
|
||||
return self.child.vars()
|
||||
|
||||
def to_code(self):
|
||||
return f"(-{self.child.to_code()})"
|
||||
|
||||
def __repr__(self):
|
||||
return f"Neg({self.child!r})"
|
||||
|
||||
@@ -180,6 +199,9 @@ class Sin(Expr):
|
||||
def vars(self):
|
||||
return self.child.vars()
|
||||
|
||||
def to_code(self):
|
||||
return f"_sin({self.child.to_code()})"
|
||||
|
||||
def __repr__(self):
|
||||
return f"Sin({self.child!r})"
|
||||
|
||||
@@ -206,6 +228,9 @@ class Cos(Expr):
|
||||
def vars(self):
|
||||
return self.child.vars()
|
||||
|
||||
def to_code(self):
|
||||
return f"_cos({self.child.to_code()})"
|
||||
|
||||
def __repr__(self):
|
||||
return f"Cos({self.child!r})"
|
||||
|
||||
@@ -232,6 +257,9 @@ class Sqrt(Expr):
|
||||
def vars(self):
|
||||
return self.child.vars()
|
||||
|
||||
def to_code(self):
|
||||
return f"_sqrt({self.child.to_code()})"
|
||||
|
||||
def __repr__(self):
|
||||
return f"Sqrt({self.child!r})"
|
||||
|
||||
@@ -266,6 +294,9 @@ class Add(Expr):
|
||||
def vars(self):
|
||||
return self.a.vars() | self.b.vars()
|
||||
|
||||
def to_code(self):
|
||||
return f"({self.a.to_code()} + {self.b.to_code()})"
|
||||
|
||||
def __repr__(self):
|
||||
return f"Add({self.a!r}, {self.b!r})"
|
||||
|
||||
@@ -297,6 +328,9 @@ class Sub(Expr):
|
||||
def vars(self):
|
||||
return self.a.vars() | self.b.vars()
|
||||
|
||||
def to_code(self):
|
||||
return f"({self.a.to_code()} - {self.b.to_code()})"
|
||||
|
||||
def __repr__(self):
|
||||
return f"Sub({self.a!r}, {self.b!r})"
|
||||
|
||||
@@ -337,6 +371,9 @@ class Mul(Expr):
|
||||
def vars(self):
|
||||
return self.a.vars() | self.b.vars()
|
||||
|
||||
def to_code(self):
|
||||
return f"({self.a.to_code()} * {self.b.to_code()})"
|
||||
|
||||
def __repr__(self):
|
||||
return f"Mul({self.a!r}, {self.b!r})"
|
||||
|
||||
@@ -372,6 +409,9 @@ class Div(Expr):
|
||||
def vars(self):
|
||||
return self.a.vars() | self.b.vars()
|
||||
|
||||
def to_code(self):
|
||||
return f"({self.a.to_code()} / {self.b.to_code()})"
|
||||
|
||||
def __repr__(self):
|
||||
return f"Div({self.a!r}, {self.b!r})"
|
||||
|
||||
@@ -414,6 +454,9 @@ class Pow(Expr):
|
||||
def vars(self):
|
||||
return self.base.vars() | self.exp.vars()
|
||||
|
||||
def to_code(self):
|
||||
return f"({self.base.to_code()} ** {self.exp.to_code()})"
|
||||
|
||||
def __repr__(self):
|
||||
return f"Pow({self.base!r}, {self.exp!r})"
|
||||
|
||||
|
||||
131
kindred_solver/geometry.py
Normal file
131
kindred_solver/geometry.py
Normal file
@@ -0,0 +1,131 @@
|
||||
"""Geometric helper functions for constraint equations.
|
||||
|
||||
Provides Expr-level vector operations and marker axis extraction.
|
||||
All functions work with Expr triples (tuples of 3 Expr nodes)
|
||||
representing 3D vectors in world coordinates.
|
||||
|
||||
Marker convention (from Types.h): the marker's Z-axis defines the
|
||||
constraint direction (hinge axis, face normal, line direction, etc.).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .entities import RigidBody
|
||||
from .expr import Const, Expr
|
||||
from .quat import quat_rotate
|
||||
|
||||
# Type alias for an Expr triple (3D vector)
|
||||
Vec3 = tuple[Expr, Expr, Expr]
|
||||
|
||||
|
||||
# -- Marker axis extraction ---------------------------------------------------
|
||||
|
||||
|
||||
def _composed_quat(
|
||||
body: RigidBody,
|
||||
marker_quat: tuple[float, float, float, float],
|
||||
) -> tuple[Expr, Expr, Expr, Expr]:
|
||||
"""Compute q_total = q_body * q_marker as Expr quaternion.
|
||||
|
||||
q_body comes from the body's Var params; q_marker is constant.
|
||||
"""
|
||||
bw, bx, by, bz = body.qw, body.qx, body.qy, body.qz
|
||||
mw, mx, my, mz = (Const(v) for v in marker_quat)
|
||||
# Hamilton product: body * marker
|
||||
rw = bw * mw - bx * mx - by * my - bz * mz
|
||||
rx = bw * mx + bx * mw + by * mz - bz * my
|
||||
ry = bw * my - bx * mz + by * mw + bz * mx
|
||||
rz = bw * mz + bx * my - by * mx + bz * mw
|
||||
return rw, rx, ry, rz
|
||||
|
||||
|
||||
def marker_z_axis(
|
||||
body: RigidBody,
|
||||
marker_quat: tuple[float, float, float, float],
|
||||
) -> Vec3:
|
||||
"""World-frame Z-axis of a marker on a body.
|
||||
|
||||
Computes rotate(q_body * q_marker, [0, 0, 1]).
|
||||
"""
|
||||
qw, qx, qy, qz = _composed_quat(body, marker_quat)
|
||||
return quat_rotate(qw, qx, qy, qz, Const(0.0), Const(0.0), Const(1.0))
|
||||
|
||||
|
||||
def marker_x_axis(
|
||||
body: RigidBody,
|
||||
marker_quat: tuple[float, float, float, float],
|
||||
) -> Vec3:
|
||||
"""World-frame X-axis of a marker on a body.
|
||||
|
||||
Computes rotate(q_body * q_marker, [1, 0, 0]).
|
||||
"""
|
||||
qw, qx, qy, qz = _composed_quat(body, marker_quat)
|
||||
return quat_rotate(qw, qx, qy, qz, Const(1.0), Const(0.0), Const(0.0))
|
||||
|
||||
|
||||
def marker_y_axis(
|
||||
body: RigidBody,
|
||||
marker_quat: tuple[float, float, float, float],
|
||||
) -> Vec3:
|
||||
"""World-frame Y-axis of a marker on a body.
|
||||
|
||||
Computes rotate(q_body * q_marker, [0, 1, 0]).
|
||||
"""
|
||||
qw, qx, qy, qz = _composed_quat(body, marker_quat)
|
||||
return quat_rotate(qw, qx, qy, qz, Const(0.0), Const(1.0), Const(0.0))
|
||||
|
||||
|
||||
# -- Vector operations on Expr triples ----------------------------------------
|
||||
|
||||
|
||||
def dot3(a: Vec3, b: Vec3) -> Expr:
|
||||
"""Dot product of two Expr triples."""
|
||||
return a[0] * b[0] + a[1] * b[1] + a[2] * b[2]
|
||||
|
||||
|
||||
def cross3(a: Vec3, b: Vec3) -> Vec3:
|
||||
"""Cross product of two Expr triples."""
|
||||
return (
|
||||
a[1] * b[2] - a[2] * b[1],
|
||||
a[2] * b[0] - a[0] * b[2],
|
||||
a[0] * b[1] - a[1] * b[0],
|
||||
)
|
||||
|
||||
|
||||
def sub3(a: Vec3, b: Vec3) -> Vec3:
|
||||
"""Vector subtraction a - b."""
|
||||
return (a[0] - b[0], a[1] - b[1], a[2] - b[2])
|
||||
|
||||
|
||||
# -- Geometric primitives -----------------------------------------------------
|
||||
|
||||
|
||||
def point_plane_distance(
|
||||
point: Vec3,
|
||||
plane_origin: Vec3,
|
||||
normal: Vec3,
|
||||
) -> Expr:
|
||||
"""Signed distance from point to plane defined by origin + normal.
|
||||
|
||||
Returns (point - plane_origin) . normal
|
||||
"""
|
||||
d = sub3(point, plane_origin)
|
||||
return dot3(d, normal)
|
||||
|
||||
|
||||
def point_line_perp_components(
|
||||
point: Vec3,
|
||||
line_origin: Vec3,
|
||||
line_dir: Vec3,
|
||||
) -> tuple[Expr, Expr, Expr]:
|
||||
"""Three perpendicular-distance components from point to line.
|
||||
|
||||
The line passes through line_origin along line_dir.
|
||||
Returns all 3 components of (point - line_origin) x line_dir,
|
||||
which are zero when the point lies on the line. Only 2 of 3 are
|
||||
independent, but using all 3 avoids a singularity when the line
|
||||
direction aligns with a coordinate axis.
|
||||
"""
|
||||
d = sub3(point, line_origin)
|
||||
cx, cy, cz = cross3(d, line_dir)
|
||||
return cx, cy, cz
|
||||
@@ -3,7 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from typing import List
|
||||
from typing import Callable, List
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -17,6 +17,10 @@ def newton_solve(
|
||||
quat_groups: List[tuple[str, str, str, str]] | None = None,
|
||||
max_iter: int = 50,
|
||||
tol: float = 1e-10,
|
||||
post_step: "Callable[[ParamTable], None] | None" = None,
|
||||
weight_vector: "np.ndarray | None" = None,
|
||||
jac_exprs: "List[List[Expr]] | None" = None,
|
||||
compiled_eval: "Callable | None" = None,
|
||||
) -> bool:
|
||||
"""Solve ``residuals == 0`` by Newton-Raphson.
|
||||
|
||||
@@ -33,6 +37,20 @@ def newton_solve(
|
||||
Maximum Newton iterations.
|
||||
tol:
|
||||
Convergence threshold on ``||r||``.
|
||||
post_step:
|
||||
Optional callback invoked after each parameter update, before
|
||||
quaternion renormalization. Used for half-space correction.
|
||||
weight_vector:
|
||||
Optional 1-D array of length ``n_free``. When provided, the
|
||||
lstsq step is column-scaled to produce the weighted
|
||||
minimum-norm solution (prefer small movements in
|
||||
high-weight parameters).
|
||||
jac_exprs:
|
||||
Pre-built symbolic Jacobian (list-of-lists of Expr). When
|
||||
provided, skips the ``diff().simplify()`` step.
|
||||
compiled_eval:
|
||||
Pre-compiled evaluation function from :mod:`codegen`. When
|
||||
provided, uses flat compiled code instead of tree-walk eval.
|
||||
|
||||
Returns True if converged within *max_iter* iterations.
|
||||
"""
|
||||
@@ -43,45 +61,72 @@ def newton_solve(
|
||||
if n_free == 0 or n_res == 0:
|
||||
return True
|
||||
|
||||
# Build symbolic Jacobian once (list-of-lists of simplified Expr)
|
||||
jac_exprs: List[List[Expr]] = []
|
||||
for r in residuals:
|
||||
row = []
|
||||
for name in free:
|
||||
row.append(r.diff(name).simplify())
|
||||
jac_exprs.append(row)
|
||||
# Build symbolic Jacobian once (or reuse pre-built)
|
||||
if jac_exprs is None:
|
||||
jac_exprs = []
|
||||
for r in residuals:
|
||||
row = []
|
||||
for name in free:
|
||||
row.append(r.diff(name).simplify())
|
||||
jac_exprs.append(row)
|
||||
|
||||
# Try compilation if not provided
|
||||
if compiled_eval is None:
|
||||
from .codegen import try_compile_system
|
||||
|
||||
compiled_eval = try_compile_system(residuals, jac_exprs, n_res, n_free)
|
||||
|
||||
# Pre-allocate arrays reused across iterations
|
||||
r_vec = np.empty(n_res)
|
||||
J = np.zeros((n_res, n_free))
|
||||
|
||||
for _it in range(max_iter):
|
||||
env = params.get_env()
|
||||
if compiled_eval is not None:
|
||||
J[:] = 0.0
|
||||
compiled_eval(params.env_ref(), r_vec, J)
|
||||
else:
|
||||
env = params.get_env()
|
||||
for i, r in enumerate(residuals):
|
||||
r_vec[i] = r.eval(env)
|
||||
for i in range(n_res):
|
||||
for j in range(n_free):
|
||||
J[i, j] = jac_exprs[i][j].eval(env)
|
||||
|
||||
# Evaluate residual vector
|
||||
r_vec = np.array([r.eval(env) for r in residuals])
|
||||
r_norm = np.linalg.norm(r_vec)
|
||||
if r_norm < tol:
|
||||
return True
|
||||
|
||||
# Evaluate Jacobian matrix
|
||||
J = np.empty((n_res, n_free))
|
||||
for i in range(n_res):
|
||||
for j in range(n_free):
|
||||
J[i, j] = jac_exprs[i][j].eval(env)
|
||||
|
||||
# Solve J @ dx = -r (least-squares handles rank-deficient)
|
||||
dx, _, _, _ = np.linalg.lstsq(J, -r_vec, rcond=None)
|
||||
if weight_vector is not None:
|
||||
# Column-scale J by W^{-1/2} for weighted minimum-norm
|
||||
w_inv_sqrt = 1.0 / np.sqrt(weight_vector)
|
||||
J_scaled = J * w_inv_sqrt[np.newaxis, :]
|
||||
dx_scaled, _, _, _ = np.linalg.lstsq(J_scaled, -r_vec, rcond=None)
|
||||
dx = dx_scaled * w_inv_sqrt
|
||||
else:
|
||||
dx, _, _, _ = np.linalg.lstsq(J, -r_vec, rcond=None)
|
||||
|
||||
# Update parameters
|
||||
x = params.get_free_vector()
|
||||
x += dx
|
||||
params.set_free_vector(x)
|
||||
|
||||
# Half-space correction (before quat renormalization)
|
||||
if post_step:
|
||||
post_step(params)
|
||||
|
||||
# Re-normalize quaternions
|
||||
if quat_groups:
|
||||
_renormalize_quats(params, quat_groups)
|
||||
|
||||
# Check final residual
|
||||
env = params.get_env()
|
||||
r_vec = np.array([r.eval(env) for r in residuals])
|
||||
return np.linalg.norm(r_vec) < tol
|
||||
if compiled_eval is not None:
|
||||
compiled_eval(params.env_ref(), r_vec, J)
|
||||
else:
|
||||
env = params.get_env()
|
||||
for i, r in enumerate(residuals):
|
||||
r_vec[i] = r.eval(env)
|
||||
return bool(np.linalg.norm(r_vec) < tol)
|
||||
|
||||
|
||||
def _renormalize_quats(
|
||||
|
||||
@@ -49,10 +49,25 @@ class ParamTable:
|
||||
if name in self._free_order:
|
||||
self._free_order.remove(name)
|
||||
|
||||
def unfix(self, name: str):
|
||||
"""Restore a fixed parameter to free status."""
|
||||
if name in self._fixed:
|
||||
self._fixed.discard(name)
|
||||
if name not in self._free_order:
|
||||
self._free_order.append(name)
|
||||
|
||||
def get_env(self) -> Dict[str, float]:
|
||||
"""Return a snapshot of all current values (for Expr.eval)."""
|
||||
return dict(self._values)
|
||||
|
||||
def env_ref(self) -> Dict[str, float]:
|
||||
"""Return a direct reference to the internal values dict.
|
||||
|
||||
Faster than :meth:`get_env` (no copy). Safe when the caller
|
||||
only reads during evaluation and mutates via :meth:`set_free_vector`.
|
||||
"""
|
||||
return self._values
|
||||
|
||||
def free_names(self) -> List[str]:
|
||||
"""Return ordered list of free (non-fixed) parameter names."""
|
||||
return list(self._free_order)
|
||||
@@ -74,3 +89,26 @@ class ParamTable:
|
||||
"""Bulk-update free parameters from a 1-D array."""
|
||||
for i, name in enumerate(self._free_order):
|
||||
self._values[name] = float(vec[i])
|
||||
|
||||
def snapshot(self) -> Dict[str, float]:
|
||||
"""Capture current values as a checkpoint."""
|
||||
return dict(self._values)
|
||||
|
||||
def restore(self, snap: Dict[str, float]):
|
||||
"""Restore parameter values from a checkpoint."""
|
||||
for name, val in snap.items():
|
||||
if name in self._values:
|
||||
self._values[name] = val
|
||||
|
||||
def movement_cost(
|
||||
self,
|
||||
start: Dict[str, float],
|
||||
weights: Dict[str, float] | None = None,
|
||||
) -> float:
|
||||
"""Weighted sum of squared displacements from start."""
|
||||
cost = 0.0
|
||||
for name in self._free_order:
|
||||
w = weights.get(name, 1.0) if weights else 1.0
|
||||
delta = self._values[name] - start.get(name, self._values[name])
|
||||
cost += delta * delta * w
|
||||
return cost
|
||||
|
||||
667
kindred_solver/preference.py
Normal file
667
kindred_solver/preference.py
Normal file
@@ -0,0 +1,667 @@
|
||||
"""Solution preference: half-space tracking and minimum-movement weighting.
|
||||
|
||||
Half-space tracking preserves the initial configuration branch across
|
||||
Newton iterations. For constraints with multiple valid solutions
|
||||
(e.g. distance can be satisfied on either side), we record which
|
||||
"half-space" the initial state lives in and correct the solver step
|
||||
if it crosses to the wrong branch.
|
||||
|
||||
Minimum-movement weighting scales the Newton/BFGS step so that
|
||||
quaternion parameters (rotation) are penalised more than translation
|
||||
parameters, yielding the physically-nearest solution.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Callable, List
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .constraints import (
|
||||
AngleConstraint,
|
||||
ConcentricConstraint,
|
||||
ConstraintBase,
|
||||
CylindricalConstraint,
|
||||
DistancePointPointConstraint,
|
||||
LineInPlaneConstraint,
|
||||
ParallelConstraint,
|
||||
PerpendicularConstraint,
|
||||
PlanarConstraint,
|
||||
PointInPlaneConstraint,
|
||||
RevoluteConstraint,
|
||||
ScrewConstraint,
|
||||
SliderConstraint,
|
||||
UniversalConstraint,
|
||||
)
|
||||
from .geometry import cross3, dot3, marker_z_axis, point_plane_distance, sub3
|
||||
from .params import ParamTable
|
||||
|
||||
|
||||
@dataclass
|
||||
class HalfSpace:
|
||||
"""Tracks which branch of a branching constraint the solution should stay in."""
|
||||
|
||||
constraint_index: int # index in ctx.constraints
|
||||
reference_sign: float # +1.0 or -1.0, captured at setup
|
||||
indicator_fn: Callable[[dict[str, float]], float] # returns signed value
|
||||
param_names: list[str] = field(default_factory=list) # params to flip
|
||||
correction_fn: Callable[[ParamTable, float], None] | None = None
|
||||
|
||||
|
||||
def compute_half_spaces(
|
||||
constraint_objs: list[ConstraintBase],
|
||||
constraint_indices: list[int],
|
||||
params: ParamTable,
|
||||
) -> list[HalfSpace]:
|
||||
"""Build half-space trackers for all branching constraints.
|
||||
|
||||
Evaluates each constraint's indicator function at the current
|
||||
parameter values to capture the reference sign.
|
||||
"""
|
||||
env = params.get_env()
|
||||
half_spaces: list[HalfSpace] = []
|
||||
|
||||
for i, obj in enumerate(constraint_objs):
|
||||
hs = _build_half_space(obj, constraint_indices[i], env, params)
|
||||
if hs is not None:
|
||||
half_spaces.append(hs)
|
||||
|
||||
return half_spaces
|
||||
|
||||
|
||||
def apply_half_space_correction(
|
||||
params: ParamTable,
|
||||
half_spaces: list[HalfSpace],
|
||||
) -> None:
|
||||
"""Check each half-space and correct if the solver crossed a branch.
|
||||
|
||||
Called as a post_step callback from newton_solve.
|
||||
"""
|
||||
if not half_spaces:
|
||||
return
|
||||
|
||||
env = params.get_env()
|
||||
for hs in half_spaces:
|
||||
current_val = hs.indicator_fn(env)
|
||||
current_sign = (
|
||||
math.copysign(1.0, current_val)
|
||||
if abs(current_val) > 1e-14
|
||||
else hs.reference_sign
|
||||
)
|
||||
if current_sign != hs.reference_sign and hs.correction_fn is not None:
|
||||
hs.correction_fn(params, current_val)
|
||||
# Re-read env after correction for subsequent half-spaces
|
||||
env = params.get_env()
|
||||
|
||||
|
||||
def _build_half_space(
|
||||
obj: ConstraintBase,
|
||||
constraint_idx: int,
|
||||
env: dict[str, float],
|
||||
params: ParamTable,
|
||||
) -> HalfSpace | None:
|
||||
"""Build a HalfSpace for a branching constraint, or None if not branching."""
|
||||
|
||||
if isinstance(obj, DistancePointPointConstraint) and obj.distance > 0:
|
||||
return _distance_half_space(obj, constraint_idx, env, params)
|
||||
|
||||
if isinstance(obj, ParallelConstraint):
|
||||
return _parallel_half_space(obj, constraint_idx, env, params)
|
||||
|
||||
if isinstance(obj, AngleConstraint):
|
||||
return _angle_half_space(obj, constraint_idx, env, params)
|
||||
|
||||
if isinstance(obj, PerpendicularConstraint):
|
||||
return _perpendicular_half_space(obj, constraint_idx, env, params)
|
||||
|
||||
if isinstance(obj, PlanarConstraint):
|
||||
return _planar_half_space(obj, constraint_idx, env, params)
|
||||
|
||||
if isinstance(obj, RevoluteConstraint):
|
||||
return _revolute_half_space(obj, constraint_idx, env, params)
|
||||
|
||||
if isinstance(obj, ConcentricConstraint):
|
||||
return _concentric_half_space(obj, constraint_idx, env, params)
|
||||
|
||||
if isinstance(obj, PointInPlaneConstraint):
|
||||
return _point_in_plane_half_space(obj, constraint_idx, env, params)
|
||||
|
||||
if isinstance(obj, LineInPlaneConstraint):
|
||||
return _line_in_plane_half_space(obj, constraint_idx, env, params)
|
||||
|
||||
if isinstance(obj, CylindricalConstraint):
|
||||
return _axis_direction_half_space(obj, constraint_idx, env)
|
||||
|
||||
if isinstance(obj, SliderConstraint):
|
||||
return _axis_direction_half_space(obj, constraint_idx, env)
|
||||
|
||||
if isinstance(obj, ScrewConstraint):
|
||||
return _axis_direction_half_space(obj, constraint_idx, env)
|
||||
|
||||
if isinstance(obj, UniversalConstraint):
|
||||
return _universal_half_space(obj, constraint_idx, env)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _distance_half_space(
|
||||
obj: DistancePointPointConstraint,
|
||||
constraint_idx: int,
|
||||
env: dict[str, float],
|
||||
params: ParamTable,
|
||||
) -> HalfSpace | None:
|
||||
"""Half-space for DistancePointPoint: track displacement direction.
|
||||
|
||||
The indicator is the dot product of the current displacement with
|
||||
the reference displacement direction. If the solver flips to the
|
||||
opposite side, we reflect the moving body's position.
|
||||
"""
|
||||
p_i, p_j = obj.world_points()
|
||||
|
||||
# Evaluate reference displacement direction
|
||||
dx = p_i[0].eval(env) - p_j[0].eval(env)
|
||||
dy = p_i[1].eval(env) - p_j[1].eval(env)
|
||||
dz = p_i[2].eval(env) - p_j[2].eval(env)
|
||||
dist = math.sqrt(dx * dx + dy * dy + dz * dz)
|
||||
|
||||
if dist < 1e-14:
|
||||
return None # points coincident, no branch to track
|
||||
|
||||
# Reference unit direction
|
||||
nx, ny, nz = dx / dist, dy / dist, dz / dist
|
||||
|
||||
# Build indicator: dot(displacement, reference_direction)
|
||||
# Use Expr evaluation for speed
|
||||
disp_x, disp_y, disp_z = p_i[0] - p_j[0], p_i[1] - p_j[1], p_i[2] - p_j[2]
|
||||
|
||||
def indicator(e: dict[str, float]) -> float:
|
||||
return disp_x.eval(e) * nx + disp_y.eval(e) * ny + disp_z.eval(e) * nz
|
||||
|
||||
ref_sign = math.copysign(1.0, indicator(env))
|
||||
|
||||
# Correction: reflect body_j position along reference direction
|
||||
# (or body_i if body_j is grounded)
|
||||
moving_body = obj.body_j if not obj.body_j.grounded else obj.body_i
|
||||
if moving_body.grounded:
|
||||
return None # both grounded, nothing to correct
|
||||
|
||||
px_name = f"{moving_body.part_id}/tx"
|
||||
py_name = f"{moving_body.part_id}/ty"
|
||||
pz_name = f"{moving_body.part_id}/tz"
|
||||
|
||||
sign_flip = -1.0 if moving_body is obj.body_j else 1.0
|
||||
|
||||
def correction(p: ParamTable, _val: float) -> None:
|
||||
# Reflect displacement: negate the component along reference direction
|
||||
e = p.get_env()
|
||||
cur_dx = disp_x.eval(e)
|
||||
cur_dy = disp_y.eval(e)
|
||||
cur_dz = disp_z.eval(e)
|
||||
# Project displacement onto reference direction
|
||||
proj = cur_dx * nx + cur_dy * ny + cur_dz * nz
|
||||
# Reflect: subtract 2*proj*n from the moving body's position
|
||||
if not p.is_fixed(px_name):
|
||||
p.set_value(px_name, p.get_value(px_name) + sign_flip * 2.0 * proj * nx)
|
||||
if not p.is_fixed(py_name):
|
||||
p.set_value(py_name, p.get_value(py_name) + sign_flip * 2.0 * proj * ny)
|
||||
if not p.is_fixed(pz_name):
|
||||
p.set_value(pz_name, p.get_value(pz_name) + sign_flip * 2.0 * proj * nz)
|
||||
|
||||
return HalfSpace(
|
||||
constraint_index=constraint_idx,
|
||||
reference_sign=ref_sign,
|
||||
indicator_fn=indicator,
|
||||
param_names=[px_name, py_name, pz_name],
|
||||
correction_fn=correction,
|
||||
)
|
||||
|
||||
|
||||
def _parallel_half_space(
|
||||
obj: ParallelConstraint,
|
||||
constraint_idx: int,
|
||||
env: dict[str, float],
|
||||
params: ParamTable,
|
||||
) -> HalfSpace:
|
||||
"""Half-space for Parallel: track same-direction vs opposite-direction.
|
||||
|
||||
Indicator: dot(z_i, z_j). Positive = same direction, negative = opposite.
|
||||
"""
|
||||
z_i = marker_z_axis(obj.body_i, obj.marker_i_quat)
|
||||
z_j = marker_z_axis(obj.body_j, obj.marker_j_quat)
|
||||
dot_expr = dot3(z_i, z_j)
|
||||
|
||||
def indicator(e: dict[str, float]) -> float:
|
||||
return dot_expr.eval(e)
|
||||
|
||||
ref_val = indicator(env)
|
||||
ref_sign = math.copysign(1.0, ref_val) if abs(ref_val) > 1e-14 else 1.0
|
||||
|
||||
# No geometric correction — just let the indicator track.
|
||||
# The Newton solver naturally handles this via the cross-product residual.
|
||||
# We only need to detect and report branch flips.
|
||||
return HalfSpace(
|
||||
constraint_index=constraint_idx,
|
||||
reference_sign=ref_sign,
|
||||
indicator_fn=indicator,
|
||||
)
|
||||
|
||||
|
||||
def _planar_half_space(
|
||||
obj: PlanarConstraint,
|
||||
constraint_idx: int,
|
||||
env: dict[str, float],
|
||||
params: ParamTable,
|
||||
) -> HalfSpace | None:
|
||||
"""Half-space for Planar: track which side of the plane the point is on
|
||||
AND which direction the normals face.
|
||||
|
||||
A Planar constraint has parallel normals (cross product = 0) plus
|
||||
point-in-plane (signed distance = 0). Both have branch ambiguity:
|
||||
the normals can be same-direction or opposite, and the point can
|
||||
approach the plane from either side. We track the signed distance
|
||||
from marker_i to the plane defined by marker_j — this captures
|
||||
the plane-side and is the primary drift mode during drag.
|
||||
"""
|
||||
# Point-in-plane signed distance as indicator
|
||||
p_i = obj.body_i.world_point(*obj.marker_i_pos)
|
||||
p_j = obj.body_j.world_point(*obj.marker_j_pos)
|
||||
z_j = marker_z_axis(obj.body_j, obj.marker_j_quat)
|
||||
d_expr = point_plane_distance(p_i, p_j, z_j)
|
||||
|
||||
d_val = d_expr.eval(env)
|
||||
|
||||
# Also track normal alignment (same as Parallel half-space)
|
||||
z_i = marker_z_axis(obj.body_i, obj.marker_i_quat)
|
||||
dot_expr = dot3(z_i, z_j)
|
||||
dot_val = dot_expr.eval(env)
|
||||
normal_ref_sign = math.copysign(1.0, dot_val) if abs(dot_val) > 1e-14 else 1.0
|
||||
|
||||
# If offset is zero and distance is near-zero, we still need the normal
|
||||
# direction indicator to prevent flipping through the plane.
|
||||
# Use the normal dot product as the primary indicator when the point
|
||||
# is already on the plane (distance ≈ 0).
|
||||
if abs(d_val) < 1e-10:
|
||||
# Point is on the plane — track normal direction instead
|
||||
def indicator(e: dict[str, float]) -> float:
|
||||
return dot_expr.eval(e)
|
||||
|
||||
ref_sign = normal_ref_sign
|
||||
else:
|
||||
# Point is off-plane — track which side
|
||||
def indicator(e: dict[str, float]) -> float:
|
||||
return d_expr.eval(e)
|
||||
|
||||
ref_sign = math.copysign(1.0, d_val)
|
||||
|
||||
# Correction: reflect the moving body's position through the plane
|
||||
moving_body = obj.body_j if not obj.body_j.grounded else obj.body_i
|
||||
if moving_body.grounded:
|
||||
return None
|
||||
|
||||
px_name = f"{moving_body.part_id}/tx"
|
||||
py_name = f"{moving_body.part_id}/ty"
|
||||
pz_name = f"{moving_body.part_id}/tz"
|
||||
|
||||
def correction(p: ParamTable, _val: float) -> None:
|
||||
e = p.get_env()
|
||||
# Recompute signed distance and normal direction
|
||||
d_cur = d_expr.eval(e)
|
||||
nx = z_j[0].eval(e)
|
||||
ny = z_j[1].eval(e)
|
||||
nz = z_j[2].eval(e)
|
||||
n_len = math.sqrt(nx * nx + ny * ny + nz * nz)
|
||||
if n_len < 1e-15:
|
||||
return
|
||||
nx, ny, nz = nx / n_len, ny / n_len, nz / n_len
|
||||
# Reflect through plane: move body by -2*d along normal
|
||||
sign = -1.0 if moving_body is obj.body_j else 1.0
|
||||
if not p.is_fixed(px_name):
|
||||
p.set_value(px_name, p.get_value(px_name) + sign * 2.0 * d_cur * nx)
|
||||
if not p.is_fixed(py_name):
|
||||
p.set_value(py_name, p.get_value(py_name) + sign * 2.0 * d_cur * ny)
|
||||
if not p.is_fixed(pz_name):
|
||||
p.set_value(pz_name, p.get_value(pz_name) + sign * 2.0 * d_cur * nz)
|
||||
|
||||
return HalfSpace(
|
||||
constraint_index=constraint_idx,
|
||||
reference_sign=ref_sign,
|
||||
indicator_fn=indicator,
|
||||
correction_fn=correction,
|
||||
)
|
||||
|
||||
|
||||
def _revolute_half_space(
|
||||
obj: RevoluteConstraint,
|
||||
constraint_idx: int,
|
||||
env: dict[str, float],
|
||||
params: ParamTable,
|
||||
) -> HalfSpace | None:
|
||||
"""Half-space for Revolute: track hinge axis direction.
|
||||
|
||||
A revolute has coincident origins + parallel Z-axes. The parallel
|
||||
axes can flip direction (same ambiguity as Parallel). Track
|
||||
dot(z_i, z_j) to prevent the axis from inverting.
|
||||
"""
|
||||
z_i = marker_z_axis(obj.body_i, obj.marker_i_quat)
|
||||
z_j = marker_z_axis(obj.body_j, obj.marker_j_quat)
|
||||
dot_expr = dot3(z_i, z_j)
|
||||
|
||||
ref_val = dot_expr.eval(env)
|
||||
ref_sign = math.copysign(1.0, ref_val) if abs(ref_val) > 1e-14 else 1.0
|
||||
|
||||
return HalfSpace(
|
||||
constraint_index=constraint_idx,
|
||||
reference_sign=ref_sign,
|
||||
indicator_fn=lambda e: dot_expr.eval(e),
|
||||
)
|
||||
|
||||
|
||||
def _concentric_half_space(
|
||||
obj: ConcentricConstraint,
|
||||
constraint_idx: int,
|
||||
env: dict[str, float],
|
||||
params: ParamTable,
|
||||
) -> HalfSpace | None:
|
||||
"""Half-space for Concentric: track axis direction.
|
||||
|
||||
Concentric has parallel axes + point-on-line. The parallel axes
|
||||
can flip direction. Track dot(z_i, z_j).
|
||||
"""
|
||||
z_i = marker_z_axis(obj.body_i, obj.marker_i_quat)
|
||||
z_j = marker_z_axis(obj.body_j, obj.marker_j_quat)
|
||||
dot_expr = dot3(z_i, z_j)
|
||||
|
||||
ref_val = dot_expr.eval(env)
|
||||
ref_sign = math.copysign(1.0, ref_val) if abs(ref_val) > 1e-14 else 1.0
|
||||
|
||||
return HalfSpace(
|
||||
constraint_index=constraint_idx,
|
||||
reference_sign=ref_sign,
|
||||
indicator_fn=lambda e: dot_expr.eval(e),
|
||||
)
|
||||
|
||||
|
||||
def _point_in_plane_half_space(
|
||||
obj: PointInPlaneConstraint,
|
||||
constraint_idx: int,
|
||||
env: dict[str, float],
|
||||
params: ParamTable,
|
||||
) -> HalfSpace | None:
|
||||
"""Half-space for PointInPlane: track which side of the plane.
|
||||
|
||||
The signed distance to the plane can be satisfied from either side.
|
||||
Track which side the initial configuration is on.
|
||||
"""
|
||||
p_i = obj.body_i.world_point(*obj.marker_i_pos)
|
||||
p_j = obj.body_j.world_point(*obj.marker_j_pos)
|
||||
n_j = marker_z_axis(obj.body_j, obj.marker_j_quat)
|
||||
d_expr = point_plane_distance(p_i, p_j, n_j)
|
||||
|
||||
d_val = d_expr.eval(env)
|
||||
if abs(d_val) < 1e-10:
|
||||
return None # already on the plane, no branch to track
|
||||
|
||||
ref_sign = math.copysign(1.0, d_val)
|
||||
|
||||
moving_body = obj.body_j if not obj.body_j.grounded else obj.body_i
|
||||
if moving_body.grounded:
|
||||
return None
|
||||
|
||||
px_name = f"{moving_body.part_id}/tx"
|
||||
py_name = f"{moving_body.part_id}/ty"
|
||||
pz_name = f"{moving_body.part_id}/tz"
|
||||
|
||||
def correction(p: ParamTable, _val: float) -> None:
|
||||
e = p.get_env()
|
||||
d_cur = d_expr.eval(e)
|
||||
nx = n_j[0].eval(e)
|
||||
ny = n_j[1].eval(e)
|
||||
nz = n_j[2].eval(e)
|
||||
n_len = math.sqrt(nx * nx + ny * ny + nz * nz)
|
||||
if n_len < 1e-15:
|
||||
return
|
||||
nx, ny, nz = nx / n_len, ny / n_len, nz / n_len
|
||||
sign = -1.0 if moving_body is obj.body_j else 1.0
|
||||
if not p.is_fixed(px_name):
|
||||
p.set_value(px_name, p.get_value(px_name) + sign * 2.0 * d_cur * nx)
|
||||
if not p.is_fixed(py_name):
|
||||
p.set_value(py_name, p.get_value(py_name) + sign * 2.0 * d_cur * ny)
|
||||
if not p.is_fixed(pz_name):
|
||||
p.set_value(pz_name, p.get_value(pz_name) + sign * 2.0 * d_cur * nz)
|
||||
|
||||
return HalfSpace(
|
||||
constraint_index=constraint_idx,
|
||||
reference_sign=ref_sign,
|
||||
indicator_fn=lambda e: d_expr.eval(e),
|
||||
correction_fn=correction,
|
||||
)
|
||||
|
||||
|
||||
def _line_in_plane_half_space(
|
||||
obj: LineInPlaneConstraint,
|
||||
constraint_idx: int,
|
||||
env: dict[str, float],
|
||||
params: ParamTable,
|
||||
) -> HalfSpace | None:
|
||||
"""Half-space for LineInPlane: track which side of the plane.
|
||||
|
||||
Same plane-side ambiguity as PointInPlane.
|
||||
"""
|
||||
p_i = obj.body_i.world_point(*obj.marker_i_pos)
|
||||
p_j = obj.body_j.world_point(*obj.marker_j_pos)
|
||||
n_j = marker_z_axis(obj.body_j, obj.marker_j_quat)
|
||||
d_expr = point_plane_distance(p_i, p_j, n_j)
|
||||
|
||||
d_val = d_expr.eval(env)
|
||||
if abs(d_val) < 1e-10:
|
||||
return None
|
||||
|
||||
ref_sign = math.copysign(1.0, d_val)
|
||||
|
||||
moving_body = obj.body_j if not obj.body_j.grounded else obj.body_i
|
||||
if moving_body.grounded:
|
||||
return None
|
||||
|
||||
px_name = f"{moving_body.part_id}/tx"
|
||||
py_name = f"{moving_body.part_id}/ty"
|
||||
pz_name = f"{moving_body.part_id}/tz"
|
||||
|
||||
def correction(p: ParamTable, _val: float) -> None:
|
||||
e = p.get_env()
|
||||
d_cur = d_expr.eval(e)
|
||||
nx = n_j[0].eval(e)
|
||||
ny = n_j[1].eval(e)
|
||||
nz = n_j[2].eval(e)
|
||||
n_len = math.sqrt(nx * nx + ny * ny + nz * nz)
|
||||
if n_len < 1e-15:
|
||||
return
|
||||
nx, ny, nz = nx / n_len, ny / n_len, nz / n_len
|
||||
sign = -1.0 if moving_body is obj.body_j else 1.0
|
||||
if not p.is_fixed(px_name):
|
||||
p.set_value(px_name, p.get_value(px_name) + sign * 2.0 * d_cur * nx)
|
||||
if not p.is_fixed(py_name):
|
||||
p.set_value(py_name, p.get_value(py_name) + sign * 2.0 * d_cur * ny)
|
||||
if not p.is_fixed(pz_name):
|
||||
p.set_value(pz_name, p.get_value(pz_name) + sign * 2.0 * d_cur * nz)
|
||||
|
||||
return HalfSpace(
|
||||
constraint_index=constraint_idx,
|
||||
reference_sign=ref_sign,
|
||||
indicator_fn=lambda e: d_expr.eval(e),
|
||||
correction_fn=correction,
|
||||
)
|
||||
|
||||
|
||||
def _axis_direction_half_space(
|
||||
obj,
|
||||
constraint_idx: int,
|
||||
env: dict[str, float],
|
||||
) -> HalfSpace | None:
|
||||
"""Half-space for any constraint with parallel Z-axes (Cylindrical, Slider, Screw).
|
||||
|
||||
Tracks dot(z_i, z_j) to prevent axis inversion.
|
||||
"""
|
||||
z_i = marker_z_axis(obj.body_i, obj.marker_i_quat)
|
||||
z_j = marker_z_axis(obj.body_j, obj.marker_j_quat)
|
||||
dot_expr = dot3(z_i, z_j)
|
||||
|
||||
ref_val = dot_expr.eval(env)
|
||||
ref_sign = math.copysign(1.0, ref_val) if abs(ref_val) > 1e-14 else 1.0
|
||||
|
||||
return HalfSpace(
|
||||
constraint_index=constraint_idx,
|
||||
reference_sign=ref_sign,
|
||||
indicator_fn=lambda e: dot_expr.eval(e),
|
||||
)
|
||||
|
||||
|
||||
def _universal_half_space(
|
||||
obj: UniversalConstraint,
|
||||
constraint_idx: int,
|
||||
env: dict[str, float],
|
||||
) -> HalfSpace | None:
|
||||
"""Half-space for Universal: track which quadrant of perpendicularity.
|
||||
|
||||
Universal has dot(z_i, z_j) = 0 (perpendicular). The cross product
|
||||
sign distinguishes which "side" of perpendicular.
|
||||
"""
|
||||
z_i = marker_z_axis(obj.body_i, obj.marker_i_quat)
|
||||
z_j = marker_z_axis(obj.body_j, obj.marker_j_quat)
|
||||
cx, cy, cz = cross3(z_i, z_j)
|
||||
|
||||
cx_val = cx.eval(env)
|
||||
cy_val = cy.eval(env)
|
||||
cz_val = cz.eval(env)
|
||||
|
||||
components = [
|
||||
(abs(cx_val), cx, cx_val),
|
||||
(abs(cy_val), cy, cy_val),
|
||||
(abs(cz_val), cz, cz_val),
|
||||
]
|
||||
_, best_expr, best_val = max(components, key=lambda t: t[0])
|
||||
|
||||
if abs(best_val) < 1e-14:
|
||||
return None
|
||||
|
||||
ref_sign = math.copysign(1.0, best_val)
|
||||
|
||||
return HalfSpace(
|
||||
constraint_index=constraint_idx,
|
||||
reference_sign=ref_sign,
|
||||
indicator_fn=lambda e: best_expr.eval(e),
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Minimum-movement weighting
|
||||
# ============================================================================
|
||||
|
||||
# Scale factor so that a 1-radian rotation is penalised as much as a
|
||||
# (180/pi)-unit translation. This makes the weighted minimum-norm
|
||||
# step prefer translating over rotating for the same residual reduction.
|
||||
QUAT_WEIGHT = (180.0 / math.pi) ** 2 # ~3283
|
||||
|
||||
|
||||
def build_weight_vector(params: ParamTable) -> np.ndarray:
|
||||
"""Build diagonal weight vector: 1.0 for translation, QUAT_WEIGHT for quaternion.
|
||||
|
||||
Returns a 1-D array of length ``params.n_free()``.
|
||||
"""
|
||||
free = params.free_names()
|
||||
w = np.ones(len(free))
|
||||
quat_suffixes = ("/qw", "/qx", "/qy", "/qz")
|
||||
for i, name in enumerate(free):
|
||||
if any(name.endswith(s) for s in quat_suffixes):
|
||||
w[i] = QUAT_WEIGHT
|
||||
return w
|
||||
|
||||
|
||||
def _angle_half_space(
|
||||
obj: AngleConstraint,
|
||||
constraint_idx: int,
|
||||
env: dict[str, float],
|
||||
params: ParamTable,
|
||||
) -> HalfSpace | None:
|
||||
"""Half-space for Angle: track sign of sin(angle) via cross product.
|
||||
|
||||
For angle constraints, the dot product is fixed (= cos(angle)),
|
||||
but sin can be +/-. We track the cross product magnitude sign.
|
||||
"""
|
||||
if abs(obj.angle) < 1e-14 or abs(obj.angle - math.pi) < 1e-14:
|
||||
return None # 0 or 180 degrees — no branch ambiguity
|
||||
|
||||
z_i = marker_z_axis(obj.body_i, obj.marker_i_quat)
|
||||
z_j = marker_z_axis(obj.body_j, obj.marker_j_quat)
|
||||
cx, cy, cz = cross3(z_i, z_j)
|
||||
|
||||
# Use the magnitude of the cross product's z-component as indicator
|
||||
# (or whichever component is largest at setup time)
|
||||
cx_val = cx.eval(env)
|
||||
cy_val = cy.eval(env)
|
||||
cz_val = cz.eval(env)
|
||||
|
||||
# Pick the dominant cross product component
|
||||
components = [
|
||||
(abs(cx_val), cx, cx_val),
|
||||
(abs(cy_val), cy, cy_val),
|
||||
(abs(cz_val), cz, cz_val),
|
||||
]
|
||||
_, best_expr, best_val = max(components, key=lambda t: t[0])
|
||||
|
||||
if abs(best_val) < 1e-14:
|
||||
return None
|
||||
|
||||
def indicator(e: dict[str, float]) -> float:
|
||||
return best_expr.eval(e)
|
||||
|
||||
ref_sign = math.copysign(1.0, best_val)
|
||||
|
||||
return HalfSpace(
|
||||
constraint_index=constraint_idx,
|
||||
reference_sign=ref_sign,
|
||||
indicator_fn=indicator,
|
||||
)
|
||||
|
||||
|
||||
def _perpendicular_half_space(
|
||||
obj: PerpendicularConstraint,
|
||||
constraint_idx: int,
|
||||
env: dict[str, float],
|
||||
params: ParamTable,
|
||||
) -> HalfSpace | None:
|
||||
"""Half-space for Perpendicular: track which quadrant.
|
||||
|
||||
The dot product is constrained to 0, but the cross product sign
|
||||
distinguishes which "side" of perpendicular.
|
||||
"""
|
||||
z_i = marker_z_axis(obj.body_i, obj.marker_i_quat)
|
||||
z_j = marker_z_axis(obj.body_j, obj.marker_j_quat)
|
||||
cx, cy, cz = cross3(z_i, z_j)
|
||||
|
||||
# Pick the dominant cross product component
|
||||
cx_val = cx.eval(env)
|
||||
cy_val = cy.eval(env)
|
||||
cz_val = cz.eval(env)
|
||||
|
||||
components = [
|
||||
(abs(cx_val), cx, cx_val),
|
||||
(abs(cy_val), cy, cy_val),
|
||||
(abs(cz_val), cz, cz_val),
|
||||
]
|
||||
_, best_expr, best_val = max(components, key=lambda t: t[0])
|
||||
|
||||
if abs(best_val) < 1e-14:
|
||||
return None
|
||||
|
||||
def indicator(e: dict[str, float]) -> float:
|
||||
return best_expr.eval(e)
|
||||
|
||||
ref_sign = math.copysign(1.0, best_val)
|
||||
|
||||
return HalfSpace(
|
||||
constraint_index=constraint_idx,
|
||||
reference_sign=ref_sign,
|
||||
indicator_fn=indicator,
|
||||
)
|
||||
File diff suppressed because it is too large
Load Diff
355
tests/console_test_phase5.py
Normal file
355
tests/console_test_phase5.py
Normal file
@@ -0,0 +1,355 @@
|
||||
"""
|
||||
Phase 5 in-client console tests.
|
||||
|
||||
Paste into the FreeCAD Python console (or run via: exec(open(...).read())).
|
||||
Tests the full Assembly -> KindredSolver pipeline without the unittest harness.
|
||||
|
||||
Expected output: all lines print PASS. Any FAIL indicates a regression.
|
||||
"""
|
||||
|
||||
import FreeCAD as App
|
||||
import JointObject
|
||||
import kcsolve
|
||||
|
||||
_pref = App.ParamGet("User parameter:BaseApp/Preferences/Mod/Assembly")
|
||||
_orig_solver = _pref.GetString("Solver", "")
|
||||
|
||||
_results = []
|
||||
|
||||
|
||||
def _report(name, passed, detail=""):
|
||||
status = "PASS" if passed else "FAIL"
|
||||
msg = f" [{status}] {name}"
|
||||
if detail:
|
||||
msg += f" — {detail}"
|
||||
print(msg)
|
||||
_results.append((name, passed))
|
||||
|
||||
|
||||
def _new_doc(name="Phase5Test"):
|
||||
if App.ActiveDocument and App.ActiveDocument.Name == name:
|
||||
App.closeDocument(name)
|
||||
App.newDocument(name)
|
||||
App.setActiveDocument(name)
|
||||
return App.ActiveDocument
|
||||
|
||||
|
||||
def _cleanup(doc):
|
||||
App.closeDocument(doc.Name)
|
||||
|
||||
|
||||
def _make_assembly(doc):
|
||||
asm = doc.addObject("Assembly::AssemblyObject", "Assembly")
|
||||
asm.resetSolver()
|
||||
jg = asm.newObject("Assembly::JointGroup", "Joints")
|
||||
return asm, jg
|
||||
|
||||
|
||||
def _make_box(asm, x=0, y=0, z=0, size=10):
|
||||
box = asm.newObject("Part::Box", "Box")
|
||||
box.Length = size
|
||||
box.Width = size
|
||||
box.Height = size
|
||||
box.Placement = App.Placement(App.Vector(x, y, z), App.Rotation())
|
||||
return box
|
||||
|
||||
|
||||
def _ground(jg, obj):
|
||||
gnd = jg.newObject("App::FeaturePython", "GroundedJoint")
|
||||
JointObject.GroundedJoint(gnd, obj)
|
||||
return gnd
|
||||
|
||||
|
||||
def _make_joint(jg, joint_type, ref1, ref2):
|
||||
joint = jg.newObject("App::FeaturePython", "Joint")
|
||||
JointObject.Joint(joint, joint_type)
|
||||
refs = [[ref1[0], ref1[1]], [ref2[0], ref2[1]]]
|
||||
joint.Proxy.setJointConnectors(joint, refs)
|
||||
return joint
|
||||
|
||||
|
||||
# ── Test 1: Registry ────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_solver_registry():
|
||||
"""Verify kindred solver is registered and available."""
|
||||
names = kcsolve.available()
|
||||
_report(
|
||||
"registry: kindred in available()", "kindred" in names, f"available={names}"
|
||||
)
|
||||
|
||||
solver = kcsolve.load("kindred")
|
||||
_report("registry: load('kindred') succeeds", solver is not None)
|
||||
_report(
|
||||
"registry: solver name",
|
||||
solver.name() == "Kindred (Newton-Raphson)",
|
||||
f"got '{solver.name()}'",
|
||||
)
|
||||
|
||||
joints = solver.supported_joints()
|
||||
_report(
|
||||
"registry: supported_joints non-empty",
|
||||
len(joints) > 0,
|
||||
f"{len(joints)} joint types",
|
||||
)
|
||||
|
||||
|
||||
# ── Test 2: Preference switching ────────────────────────────────────
|
||||
|
||||
|
||||
def test_preference_switching():
|
||||
"""Verify solver preference controls which backend is used."""
|
||||
doc = _new_doc("PrefTest")
|
||||
try:
|
||||
# Set to kindred
|
||||
_pref.SetString("Solver", "kindred")
|
||||
asm, jg = _make_assembly(doc)
|
||||
|
||||
box1 = _make_box(asm, 0, 0, 0)
|
||||
box2 = _make_box(asm, 50, 0, 0)
|
||||
_ground(box1)
|
||||
_make_joint(jg, 0, [box1, ["Face6", "Vertex7"]], [box2, ["Face6", "Vertex7"]])
|
||||
|
||||
result = asm.solve()
|
||||
_report(
|
||||
"pref: kindred solve succeeds", result == 0, f"solve() returned {result}"
|
||||
)
|
||||
|
||||
# Switch back to ondsel
|
||||
_pref.SetString("Solver", "ondsel")
|
||||
asm.resetSolver()
|
||||
result2 = asm.solve()
|
||||
_report(
|
||||
"pref: ondsel solve succeeds after switch",
|
||||
result2 == 0,
|
||||
f"solve() returned {result2}",
|
||||
)
|
||||
finally:
|
||||
_cleanup(doc)
|
||||
|
||||
|
||||
# ── Test 3: Fixed joint ─────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_fixed_joint():
|
||||
"""Two boxes + ground + fixed joint -> placements match."""
|
||||
_pref.SetString("Solver", "kindred")
|
||||
doc = _new_doc("FixedTest")
|
||||
try:
|
||||
asm, jg = _make_assembly(doc)
|
||||
box1 = _make_box(asm, 10, 20, 30)
|
||||
box2 = _make_box(asm, 40, 50, 60)
|
||||
_ground(box2)
|
||||
_make_joint(jg, 0, [box2, ["Face6", "Vertex7"]], [box1, ["Face6", "Vertex7"]])
|
||||
|
||||
same = box1.Placement.isSame(box2.Placement, 1e-6)
|
||||
_report(
|
||||
"fixed: box1 matches box2 placement",
|
||||
same,
|
||||
f"box1={box1.Placement.Base}, box2={box2.Placement.Base}",
|
||||
)
|
||||
finally:
|
||||
_cleanup(doc)
|
||||
|
||||
|
||||
# ── Test 4: Revolute joint + DOF ─────────────────────────────────────
|
||||
|
||||
|
||||
def test_revolute_dof():
|
||||
"""Revolute joint -> solve succeeds, DOF = 1."""
|
||||
_pref.SetString("Solver", "kindred")
|
||||
doc = _new_doc("RevoluteTest")
|
||||
try:
|
||||
asm, jg = _make_assembly(doc)
|
||||
box1 = _make_box(asm, 0, 0, 0)
|
||||
box2 = _make_box(asm, 100, 0, 0)
|
||||
_ground(box1)
|
||||
_make_joint(jg, 1, [box1, ["Face6", "Vertex7"]], [box2, ["Face6", "Vertex7"]])
|
||||
|
||||
result = asm.solve()
|
||||
_report("revolute: solve succeeds", result == 0, f"solve() returned {result}")
|
||||
|
||||
dof = asm.getLastDoF()
|
||||
_report("revolute: DOF = 1", dof == 1, f"DOF = {dof}")
|
||||
finally:
|
||||
_cleanup(doc)
|
||||
|
||||
|
||||
# ── Test 5: No ground ───────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_no_ground():
|
||||
"""No grounded parts -> returns -6."""
|
||||
_pref.SetString("Solver", "kindred")
|
||||
doc = _new_doc("NoGroundTest")
|
||||
try:
|
||||
asm, jg = _make_assembly(doc)
|
||||
box1 = _make_box(asm, 0, 0, 0)
|
||||
box2 = _make_box(asm, 50, 0, 0)
|
||||
|
||||
joint = jg.newObject("App::FeaturePython", "Joint")
|
||||
JointObject.Joint(joint, 0)
|
||||
refs = [[box1, ["Face6", "Vertex7"]], [box2, ["Face6", "Vertex7"]]]
|
||||
joint.Proxy.setJointConnectors(joint, refs)
|
||||
|
||||
result = asm.solve()
|
||||
_report("no-ground: returns -6", result == -6, f"solve() returned {result}")
|
||||
finally:
|
||||
_cleanup(doc)
|
||||
|
||||
|
||||
# ── Test 6: Solve stability ─────────────────────────────────────────
|
||||
|
||||
|
||||
def test_stability():
|
||||
"""Solving twice gives identical placements."""
|
||||
_pref.SetString("Solver", "kindred")
|
||||
doc = _new_doc("StabilityTest")
|
||||
try:
|
||||
asm, jg = _make_assembly(doc)
|
||||
box1 = _make_box(asm, 10, 20, 30)
|
||||
box2 = _make_box(asm, 40, 50, 60)
|
||||
_ground(box2)
|
||||
_make_joint(jg, 0, [box2, ["Face6", "Vertex7"]], [box1, ["Face6", "Vertex7"]])
|
||||
|
||||
asm.solve()
|
||||
plc1 = App.Placement(box1.Placement)
|
||||
asm.solve()
|
||||
plc2 = box1.Placement
|
||||
|
||||
same = plc1.isSame(plc2, 1e-6)
|
||||
_report("stability: two solves identical", same)
|
||||
finally:
|
||||
_cleanup(doc)
|
||||
|
||||
|
||||
# ── Test 7: Standalone solver API ────────────────────────────────────
|
||||
|
||||
|
||||
def test_standalone_api():
|
||||
"""Use kcsolve types directly without FreeCAD Assembly objects."""
|
||||
solver = kcsolve.load("kindred")
|
||||
|
||||
# Two parts: one grounded, one free
|
||||
p1 = kcsolve.Part()
|
||||
p1.id = "base"
|
||||
p1.placement = kcsolve.Transform.identity()
|
||||
p1.grounded = True
|
||||
|
||||
p2 = kcsolve.Part()
|
||||
p2.id = "arm"
|
||||
p2.placement = kcsolve.Transform()
|
||||
p2.placement.position = [100.0, 0.0, 0.0]
|
||||
p2.placement.quaternion = [1.0, 0.0, 0.0, 0.0]
|
||||
p2.grounded = False
|
||||
|
||||
# Fixed joint
|
||||
c = kcsolve.Constraint()
|
||||
c.id = "fix1"
|
||||
c.part_i = "base"
|
||||
c.marker_i = kcsolve.Transform.identity()
|
||||
c.part_j = "arm"
|
||||
c.marker_j = kcsolve.Transform.identity()
|
||||
c.type = kcsolve.BaseJointKind.Fixed
|
||||
|
||||
ctx = kcsolve.SolveContext()
|
||||
ctx.parts = [p1, p2]
|
||||
ctx.constraints = [c]
|
||||
|
||||
result = solver.solve(ctx)
|
||||
_report(
|
||||
"standalone: solve status",
|
||||
result.status == kcsolve.SolveStatus.Success,
|
||||
f"status={result.status}",
|
||||
)
|
||||
_report("standalone: DOF = 0", result.dof == 0, f"dof={result.dof}")
|
||||
|
||||
# Check that arm moved to origin
|
||||
for pr in result.placements:
|
||||
if pr.id == "arm":
|
||||
dist = sum(x**2 for x in pr.placement.position) ** 0.5
|
||||
_report("standalone: arm at origin", dist < 1e-6, f"distance={dist:.2e}")
|
||||
break
|
||||
|
||||
|
||||
# ── Test 8: Diagnose API ────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_diagnose():
|
||||
"""Diagnose overconstrained system via standalone API."""
|
||||
solver = kcsolve.load("kindred")
|
||||
|
||||
p1 = kcsolve.Part()
|
||||
p1.id = "base"
|
||||
p1.placement = kcsolve.Transform.identity()
|
||||
p1.grounded = True
|
||||
|
||||
p2 = kcsolve.Part()
|
||||
p2.id = "arm"
|
||||
p2.placement = kcsolve.Transform()
|
||||
p2.placement.position = [50.0, 0.0, 0.0]
|
||||
p2.placement.quaternion = [1.0, 0.0, 0.0, 0.0]
|
||||
p2.grounded = False
|
||||
|
||||
# Two fixed joints = overconstrained
|
||||
c1 = kcsolve.Constraint()
|
||||
c1.id = "fix1"
|
||||
c1.part_i = "base"
|
||||
c1.marker_i = kcsolve.Transform.identity()
|
||||
c1.part_j = "arm"
|
||||
c1.marker_j = kcsolve.Transform.identity()
|
||||
c1.type = kcsolve.BaseJointKind.Fixed
|
||||
|
||||
c2 = kcsolve.Constraint()
|
||||
c2.id = "fix2"
|
||||
c2.part_i = "base"
|
||||
c2.marker_i = kcsolve.Transform.identity()
|
||||
c2.part_j = "arm"
|
||||
c2.marker_j = kcsolve.Transform.identity()
|
||||
c2.type = kcsolve.BaseJointKind.Fixed
|
||||
|
||||
ctx = kcsolve.SolveContext()
|
||||
ctx.parts = [p1, p2]
|
||||
ctx.constraints = [c1, c2]
|
||||
|
||||
diags = solver.diagnose(ctx)
|
||||
_report(
|
||||
"diagnose: returns diagnostics", len(diags) > 0, f"{len(diags)} diagnostic(s)"
|
||||
)
|
||||
if diags:
|
||||
kinds = [d.kind for d in diags]
|
||||
_report(
|
||||
"diagnose: found redundant",
|
||||
kcsolve.DiagnosticKind.Redundant in kinds,
|
||||
f"kinds={[str(k) for k in kinds]}",
|
||||
)
|
||||
|
||||
|
||||
# ── Run all ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def run_all():
|
||||
print("\n=== Phase 5 Console Tests ===\n")
|
||||
|
||||
test_solver_registry()
|
||||
test_preference_switching()
|
||||
test_fixed_joint()
|
||||
test_revolute_dof()
|
||||
test_no_ground()
|
||||
test_stability()
|
||||
test_standalone_api()
|
||||
test_diagnose()
|
||||
|
||||
# Restore original preference
|
||||
_pref.SetString("Solver", _orig_solver)
|
||||
|
||||
# Summary
|
||||
passed = sum(1 for _, p in _results if p)
|
||||
total = len(_results)
|
||||
print(f"\n=== {passed}/{total} passed ===\n")
|
||||
if passed < total:
|
||||
failed = [name for name, p in _results if not p]
|
||||
print(f"FAILED: {', '.join(failed)}")
|
||||
|
||||
|
||||
run_all()
|
||||
70
tests/test_bfgs.py
Normal file
70
tests/test_bfgs.py
Normal file
@@ -0,0 +1,70 @@
|
||||
"""Tests for the BFGS fallback solver."""
|
||||
|
||||
import math
|
||||
|
||||
import pytest
|
||||
from kindred_solver.bfgs import bfgs_solve
|
||||
from kindred_solver.expr import Const, Var
|
||||
from kindred_solver.params import ParamTable
|
||||
|
||||
|
||||
class TestBFGSBasic:
|
||||
def test_single_linear(self):
|
||||
"""Solve x - 3 = 0."""
|
||||
pt = ParamTable()
|
||||
x = pt.add("x", 0.0)
|
||||
assert bfgs_solve([x - Const(3.0)], pt) is True
|
||||
assert abs(pt.get_value("x") - 3.0) < 1e-8
|
||||
|
||||
def test_single_quadratic(self):
|
||||
"""Solve x^2 - 4 = 0 from x=1 → x=2."""
|
||||
pt = ParamTable()
|
||||
x = pt.add("x", 1.0)
|
||||
assert bfgs_solve([x * x - Const(4.0)], pt) is True
|
||||
assert abs(pt.get_value("x") - 2.0) < 1e-8
|
||||
|
||||
def test_two_variables(self):
|
||||
"""Solve x + y = 5, x - y = 1."""
|
||||
pt = ParamTable()
|
||||
x = pt.add("x", 0.0)
|
||||
y = pt.add("y", 0.0)
|
||||
assert bfgs_solve([x + y - Const(5.0), x - y - Const(1.0)], pt) is True
|
||||
assert abs(pt.get_value("x") - 3.0) < 1e-8
|
||||
assert abs(pt.get_value("y") - 2.0) < 1e-8
|
||||
|
||||
def test_empty_system(self):
|
||||
pt = ParamTable()
|
||||
assert bfgs_solve([], pt) is True
|
||||
|
||||
def test_with_quat_renorm(self):
|
||||
"""Quaternion re-normalization during BFGS."""
|
||||
pt = ParamTable()
|
||||
qw = pt.add("qw", 0.9)
|
||||
qx = pt.add("qx", 0.1)
|
||||
qy = pt.add("qy", 0.1)
|
||||
qz = pt.add("qz", 0.1)
|
||||
r = qw * qw + qx * qx + qy * qy + qz * qz - Const(1.0)
|
||||
groups = [("qw", "qx", "qy", "qz")]
|
||||
assert bfgs_solve([r], pt, quat_groups=groups) is True
|
||||
w, x, y, z = (pt.get_value(n) for n in ["qw", "qx", "qy", "qz"])
|
||||
norm = math.sqrt(w**2 + x**2 + y**2 + z**2)
|
||||
assert abs(norm - 1.0) < 1e-8
|
||||
|
||||
|
||||
class TestBFGSGeometric:
|
||||
def test_distance_constraint(self):
|
||||
"""x^2 - 25 = 0 from x=3 → x=5."""
|
||||
pt = ParamTable()
|
||||
x = pt.add("x", 3.0)
|
||||
assert bfgs_solve([x * x - Const(25.0)], pt) is True
|
||||
assert abs(pt.get_value("x") - 5.0) < 1e-8
|
||||
|
||||
def test_difficult_initial_guess(self):
|
||||
"""BFGS should handle worse initial guesses than Newton."""
|
||||
pt = ParamTable()
|
||||
x = pt.add("x", 100.0)
|
||||
y = pt.add("y", -50.0)
|
||||
residuals = [x + y - Const(5.0), x - y - Const(1.0)]
|
||||
assert bfgs_solve(residuals, pt) is True
|
||||
assert abs(pt.get_value("x") - 3.0) < 1e-6
|
||||
assert abs(pt.get_value("y") - 2.0) < 1e-6
|
||||
357
tests/test_codegen.py
Normal file
357
tests/test_codegen.py
Normal file
@@ -0,0 +1,357 @@
|
||||
"""Tests for the codegen module — CSE, compilation, and compiled evaluation."""
|
||||
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
from kindred_solver.codegen import (
|
||||
_build_cse,
|
||||
_find_nonzero_entries,
|
||||
compile_system,
|
||||
try_compile_system,
|
||||
)
|
||||
from kindred_solver.expr import (
|
||||
ZERO,
|
||||
Add,
|
||||
Const,
|
||||
Cos,
|
||||
Div,
|
||||
Mul,
|
||||
Neg,
|
||||
Pow,
|
||||
Sin,
|
||||
Sqrt,
|
||||
Sub,
|
||||
Var,
|
||||
)
|
||||
from kindred_solver.newton import newton_solve
|
||||
from kindred_solver.params import ParamTable
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# to_code() — round-trip correctness for each Expr type
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestToCode:
|
||||
"""Verify that eval(expr.to_code()) == expr.eval(env) for each node."""
|
||||
|
||||
NS = {"_sin": math.sin, "_cos": math.cos, "_sqrt": math.sqrt}
|
||||
|
||||
def _check(self, expr, env):
|
||||
code = expr.to_code()
|
||||
ns = dict(self.NS)
|
||||
ns["env"] = env
|
||||
compiled = eval(code, ns)
|
||||
expected = expr.eval(env)
|
||||
assert abs(compiled - expected) < 1e-15, (
|
||||
f"{code} = {compiled}, expected {expected}"
|
||||
)
|
||||
|
||||
def test_const(self):
|
||||
self._check(Const(3.14), {})
|
||||
|
||||
def test_const_negative(self):
|
||||
self._check(Const(-2.5), {})
|
||||
|
||||
def test_const_zero(self):
|
||||
self._check(Const(0.0), {})
|
||||
|
||||
def test_var(self):
|
||||
self._check(Var("x"), {"x": 7.0})
|
||||
|
||||
def test_neg(self):
|
||||
self._check(Neg(Var("x")), {"x": 3.0})
|
||||
|
||||
def test_add(self):
|
||||
self._check(Add(Var("x"), Const(2.0)), {"x": 5.0})
|
||||
|
||||
def test_sub(self):
|
||||
self._check(Sub(Var("x"), Var("y")), {"x": 5.0, "y": 3.0})
|
||||
|
||||
def test_mul(self):
|
||||
self._check(Mul(Var("x"), Const(3.0)), {"x": 4.0})
|
||||
|
||||
def test_div(self):
|
||||
self._check(Div(Var("x"), Const(2.0)), {"x": 6.0})
|
||||
|
||||
def test_pow(self):
|
||||
self._check(Pow(Var("x"), Const(3.0)), {"x": 2.0})
|
||||
|
||||
def test_sin(self):
|
||||
self._check(Sin(Var("x")), {"x": 1.0})
|
||||
|
||||
def test_cos(self):
|
||||
self._check(Cos(Var("x")), {"x": 1.0})
|
||||
|
||||
def test_sqrt(self):
|
||||
self._check(Sqrt(Var("x")), {"x": 9.0})
|
||||
|
||||
def test_nested(self):
|
||||
"""Complex nested expression."""
|
||||
x, y = Var("x"), Var("y")
|
||||
expr = Add(Mul(Sin(x), Cos(y)), Sqrt(Sub(x, Neg(y))))
|
||||
self._check(expr, {"x": 2.0, "y": 1.0})
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CSE
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCSE:
|
||||
def test_no_sharing(self):
|
||||
"""Distinct expressions produce no CSE temps."""
|
||||
a = Var("x") + Const(1.0)
|
||||
b = Var("y") + Const(2.0)
|
||||
id_to_temp, temps = _build_cse([a, b])
|
||||
assert len(temps) == 0
|
||||
|
||||
def test_shared_subtree(self):
|
||||
"""Same node object used in two places is extracted."""
|
||||
x = Var("x")
|
||||
shared = x * Const(2.0) # single Mul node
|
||||
a = shared + Const(1.0)
|
||||
b = shared + Const(3.0)
|
||||
id_to_temp, temps = _build_cse([a, b])
|
||||
assert len(temps) >= 1
|
||||
# The shared Mul node should be a temp
|
||||
assert id(shared) in id_to_temp
|
||||
|
||||
def test_leaf_nodes_not_extracted(self):
|
||||
"""Const and Var nodes are never extracted as temps."""
|
||||
x = Var("x")
|
||||
c = Const(5.0)
|
||||
a = x + c
|
||||
b = x + c
|
||||
id_to_temp, temps = _build_cse([a, b])
|
||||
for _, expr in temps:
|
||||
assert not isinstance(expr, (Const, Var))
|
||||
|
||||
def test_dependency_order(self):
|
||||
"""Temps are in dependency order (dependencies first)."""
|
||||
x = Var("x")
|
||||
inner = x * Const(2.0)
|
||||
outer = inner + inner # uses inner twice
|
||||
wrapper_a = outer * Const(3.0)
|
||||
wrapper_b = outer * Const(4.0)
|
||||
id_to_temp, temps = _build_cse([wrapper_a, wrapper_b])
|
||||
# If both inner and outer are temps, inner must come first
|
||||
temp_names = [name for name, _ in temps]
|
||||
temp_ids = [id(expr) for _, expr in temps]
|
||||
if id(inner) in set(id_to_temp) and id(outer) in set(id_to_temp):
|
||||
inner_idx = temp_ids.index(id(inner))
|
||||
outer_idx = temp_ids.index(id(outer))
|
||||
assert inner_idx < outer_idx
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Sparsity detection
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSparsity:
|
||||
def test_zero_entries_skipped(self):
|
||||
nz = _find_nonzero_entries(
|
||||
[
|
||||
[Const(0.0), Var("x"), Const(0.0)],
|
||||
[Const(1.0), Const(0.0), Var("y")],
|
||||
]
|
||||
)
|
||||
assert nz == [(0, 1), (1, 0), (1, 2)]
|
||||
|
||||
def test_all_nonzero(self):
|
||||
nz = _find_nonzero_entries(
|
||||
[
|
||||
[Var("x"), Const(1.0)],
|
||||
]
|
||||
)
|
||||
assert nz == [(0, 0), (0, 1)]
|
||||
|
||||
def test_all_zero(self):
|
||||
nz = _find_nonzero_entries(
|
||||
[
|
||||
[Const(0.0), Const(0.0)],
|
||||
]
|
||||
)
|
||||
assert nz == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Full compilation pipeline
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCompileSystem:
|
||||
def test_simple_linear(self):
|
||||
"""Compile and evaluate a trivial system: r = x - 3, J = [[1]]."""
|
||||
x = Var("x")
|
||||
residuals = [x - Const(3.0)]
|
||||
jac_exprs = [[Const(1.0)]] # d(x-3)/dx = 1
|
||||
|
||||
fn = compile_system(residuals, jac_exprs, 1, 1)
|
||||
|
||||
env = {"x": 5.0}
|
||||
r_vec = np.empty(1)
|
||||
J = np.zeros((1, 1))
|
||||
fn(env, r_vec, J)
|
||||
|
||||
assert abs(r_vec[0] - 2.0) < 1e-15 # 5 - 3 = 2
|
||||
assert abs(J[0, 0] - 1.0) < 1e-15
|
||||
|
||||
def test_two_variable_system(self):
|
||||
"""Compile: r0 = x + y - 5, r1 = x - y - 1."""
|
||||
x, y = Var("x"), Var("y")
|
||||
residuals = [x + y - Const(5.0), x - y - Const(1.0)]
|
||||
jac_exprs = [
|
||||
[Const(1.0), Const(1.0)], # d(r0)/dx, d(r0)/dy
|
||||
[Const(1.0), Const(-1.0)], # d(r1)/dx, d(r1)/dy
|
||||
]
|
||||
|
||||
fn = compile_system(residuals, jac_exprs, 2, 2)
|
||||
|
||||
env = {"x": 3.0, "y": 2.0}
|
||||
r_vec = np.empty(2)
|
||||
J = np.zeros((2, 2))
|
||||
fn(env, r_vec, J)
|
||||
|
||||
assert abs(r_vec[0] - 0.0) < 1e-15
|
||||
assert abs(r_vec[1] - 0.0) < 1e-15
|
||||
assert abs(J[0, 0] - 1.0) < 1e-15
|
||||
assert abs(J[0, 1] - 1.0) < 1e-15
|
||||
assert abs(J[1, 0] - 1.0) < 1e-15
|
||||
assert abs(J[1, 1] - (-1.0)) < 1e-15
|
||||
|
||||
def test_sparse_jacobian(self):
|
||||
"""Zero Jacobian entries remain zero after compiled evaluation."""
|
||||
x = Var("x")
|
||||
y = Var("y")
|
||||
# r0 depends on x only, r1 depends on y only
|
||||
residuals = [x - Const(1.0), y - Const(2.0)]
|
||||
jac_exprs = [
|
||||
[Const(1.0), Const(0.0)],
|
||||
[Const(0.0), Const(1.0)],
|
||||
]
|
||||
|
||||
fn = compile_system(residuals, jac_exprs, 2, 2)
|
||||
|
||||
env = {"x": 3.0, "y": 4.0}
|
||||
r_vec = np.empty(2)
|
||||
J = np.zeros((2, 2))
|
||||
fn(env, r_vec, J)
|
||||
|
||||
assert abs(J[0, 1]) < 1e-15 # should remain zero
|
||||
assert abs(J[1, 0]) < 1e-15 # should remain zero
|
||||
assert abs(J[0, 0] - 1.0) < 1e-15
|
||||
assert abs(J[1, 1] - 1.0) < 1e-15
|
||||
|
||||
def test_trig_functions(self):
|
||||
"""Compiled evaluation handles Sin/Cos/Sqrt."""
|
||||
x = Var("x")
|
||||
residuals = [Sin(x), Cos(x), Sqrt(x)]
|
||||
jac_exprs = [
|
||||
[Cos(x)],
|
||||
[Neg(Sin(x))],
|
||||
[Div(Const(1.0), Mul(Const(2.0), Sqrt(x)))],
|
||||
]
|
||||
|
||||
fn = compile_system(residuals, jac_exprs, 3, 1)
|
||||
|
||||
env = {"x": 1.0}
|
||||
r_vec = np.empty(3)
|
||||
J = np.zeros((3, 1))
|
||||
fn(env, r_vec, J)
|
||||
|
||||
assert abs(r_vec[0] - math.sin(1.0)) < 1e-15
|
||||
assert abs(r_vec[1] - math.cos(1.0)) < 1e-15
|
||||
assert abs(r_vec[2] - math.sqrt(1.0)) < 1e-15
|
||||
assert abs(J[0, 0] - math.cos(1.0)) < 1e-15
|
||||
assert abs(J[1, 0] - (-math.sin(1.0))) < 1e-15
|
||||
assert abs(J[2, 0] - (1.0 / (2.0 * math.sqrt(1.0)))) < 1e-15
|
||||
|
||||
def test_matches_tree_walk(self):
|
||||
"""Compiled eval produces identical results to tree-walk eval."""
|
||||
pt = ParamTable()
|
||||
x = pt.add("x", 2.0)
|
||||
y = pt.add("y", 3.0)
|
||||
|
||||
residuals = [x * y - Const(6.0), x * x + y - Const(7.0)]
|
||||
free = pt.free_names()
|
||||
|
||||
jac_exprs = [[r.diff(name).simplify() for name in free] for r in residuals]
|
||||
|
||||
fn = compile_system(residuals, jac_exprs, 2, 2)
|
||||
|
||||
# Tree-walk eval
|
||||
env = pt.get_env()
|
||||
r_tree = np.array([r.eval(env) for r in residuals])
|
||||
J_tree = np.empty((2, 2))
|
||||
for i in range(2):
|
||||
for j in range(2):
|
||||
J_tree[i, j] = jac_exprs[i][j].eval(env)
|
||||
|
||||
# Compiled eval
|
||||
r_comp = np.empty(2)
|
||||
J_comp = np.zeros((2, 2))
|
||||
fn(pt.env_ref(), r_comp, J_comp)
|
||||
|
||||
np.testing.assert_allclose(r_comp, r_tree, atol=1e-15)
|
||||
np.testing.assert_allclose(J_comp, J_tree, atol=1e-15)
|
||||
|
||||
|
||||
class TestTryCompile:
|
||||
def test_returns_callable(self):
|
||||
x = Var("x")
|
||||
fn = try_compile_system([x], [[Const(1.0)]], 1, 1)
|
||||
assert fn is not None
|
||||
|
||||
def test_empty_system(self):
|
||||
"""Empty system returns None (nothing to compile)."""
|
||||
fn = try_compile_system([], [], 0, 0)
|
||||
# Empty system is handled by the solver before codegen is reached,
|
||||
# so returning None is acceptable.
|
||||
assert fn is None or callable(fn)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Integration: Newton with compiled eval matches tree-walk
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCompiledNewton:
|
||||
def test_single_linear(self):
|
||||
"""Solve x - 3 = 0 with compiled eval."""
|
||||
pt = ParamTable()
|
||||
x = pt.add("x", 0.0)
|
||||
residuals = [x - Const(3.0)]
|
||||
assert newton_solve(residuals, pt) is True
|
||||
assert abs(pt.get_value("x") - 3.0) < 1e-10
|
||||
|
||||
def test_two_variables(self):
|
||||
"""Solve x + y = 5, x - y = 1 with compiled eval."""
|
||||
pt = ParamTable()
|
||||
x = pt.add("x", 0.0)
|
||||
y = pt.add("y", 0.0)
|
||||
residuals = [x + y - Const(5.0), x - y - Const(1.0)]
|
||||
assert newton_solve(residuals, pt) is True
|
||||
assert abs(pt.get_value("x") - 3.0) < 1e-10
|
||||
assert abs(pt.get_value("y") - 2.0) < 1e-10
|
||||
|
||||
def test_quadratic(self):
|
||||
"""Solve x^2 - 4 = 0 starting from x=1."""
|
||||
pt = ParamTable()
|
||||
x = pt.add("x", 1.0)
|
||||
residuals = [x * x - Const(4.0)]
|
||||
assert newton_solve(residuals, pt) is True
|
||||
assert abs(pt.get_value("x") - 2.0) < 1e-10
|
||||
|
||||
def test_nonlinear_system(self):
|
||||
"""Compiled eval converges for a nonlinear system: xy=6, x+y=5."""
|
||||
pt = ParamTable()
|
||||
x = pt.add("x", 2.0)
|
||||
y = pt.add("y", 3.5)
|
||||
residuals = [x * y - Const(6.0), x + y - Const(5.0)]
|
||||
assert newton_solve(residuals, pt, max_iter=100) is True
|
||||
# Solutions are (2, 3) or (3, 2) — check they satisfy both equations
|
||||
xv, yv = pt.get_value("x"), pt.get_value("y")
|
||||
assert abs(xv * yv - 6.0) < 1e-10
|
||||
assert abs(xv + yv - 5.0) < 1e-10
|
||||
481
tests/test_constraints_phase2.py
Normal file
481
tests/test_constraints_phase2.py
Normal file
@@ -0,0 +1,481 @@
|
||||
"""Tests for Phase 2 constraint residual generation."""
|
||||
|
||||
import math
|
||||
|
||||
import pytest
|
||||
from kindred_solver.constraints import (
|
||||
AngleConstraint,
|
||||
BallConstraint,
|
||||
CamConstraint,
|
||||
ConcentricConstraint,
|
||||
CylindricalConstraint,
|
||||
DistanceCylSphConstraint,
|
||||
GearConstraint,
|
||||
LineInPlaneConstraint,
|
||||
ParallelConstraint,
|
||||
PerpendicularConstraint,
|
||||
PlanarConstraint,
|
||||
PointInPlaneConstraint,
|
||||
PointOnLineConstraint,
|
||||
RackPinionConstraint,
|
||||
RevoluteConstraint,
|
||||
ScrewConstraint,
|
||||
SliderConstraint,
|
||||
SlotConstraint,
|
||||
TangentConstraint,
|
||||
UniversalConstraint,
|
||||
)
|
||||
from kindred_solver.entities import RigidBody
|
||||
from kindred_solver.params import ParamTable
|
||||
|
||||
ID_QUAT = (1.0, 0.0, 0.0, 0.0)
|
||||
# 90-deg about Y: Z-axis of body rotates to point along X
|
||||
_c = math.cos(math.pi / 4)
|
||||
_s = math.sin(math.pi / 4)
|
||||
ROT_90Y = (_c, 0.0, _s, 0.0)
|
||||
ROT_90Z = (_c, 0.0, 0.0, _s)
|
||||
|
||||
|
||||
# ── Point constraints ────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestPointOnLine:
|
||||
def test_on_line(self):
|
||||
"""Point at (0,0,5) is on Z-axis line through origin."""
|
||||
pt = ParamTable()
|
||||
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0), grounded=True)
|
||||
b2 = RigidBody("b", pt, (0, 0, 5), (1, 0, 0, 0))
|
||||
c = PointOnLineConstraint(b2, (0, 0, 0), ID_QUAT, b1, (0, 0, 0), ID_QUAT)
|
||||
env = pt.get_env()
|
||||
for r in c.residuals():
|
||||
assert abs(r.eval(env)) < 1e-10
|
||||
|
||||
def test_off_line(self):
|
||||
"""Point at (3,0,5) is NOT on Z-axis line through origin."""
|
||||
pt = ParamTable()
|
||||
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0), grounded=True)
|
||||
b2 = RigidBody("b", pt, (3, 0, 5), (1, 0, 0, 0))
|
||||
c = PointOnLineConstraint(b2, (0, 0, 0), ID_QUAT, b1, (0, 0, 0), ID_QUAT)
|
||||
env = pt.get_env()
|
||||
vals = [r.eval(env) for r in c.residuals()]
|
||||
assert any(abs(v) > 0.1 for v in vals)
|
||||
|
||||
def test_residual_count(self):
|
||||
pt = ParamTable()
|
||||
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0))
|
||||
b2 = RigidBody("b", pt, (0, 0, 0), (1, 0, 0, 0))
|
||||
c = PointOnLineConstraint(b1, (0, 0, 0), ID_QUAT, b2, (0, 0, 0), ID_QUAT)
|
||||
assert len(c.residuals()) == 3
|
||||
|
||||
|
||||
class TestPointInPlane:
|
||||
def test_in_plane(self):
|
||||
"""Point at (3,4,0) is in XY plane through origin."""
|
||||
pt = ParamTable()
|
||||
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0), grounded=True)
|
||||
b2 = RigidBody("b", pt, (3, 4, 0), (1, 0, 0, 0))
|
||||
c = PointInPlaneConstraint(b2, (0, 0, 0), ID_QUAT, b1, (0, 0, 0), ID_QUAT)
|
||||
env = pt.get_env()
|
||||
assert abs(c.residuals()[0].eval(env)) < 1e-10
|
||||
|
||||
def test_above_plane(self):
|
||||
"""Point at (0,0,7) is 7 above XY plane."""
|
||||
pt = ParamTable()
|
||||
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0), grounded=True)
|
||||
b2 = RigidBody("b", pt, (0, 0, 7), (1, 0, 0, 0))
|
||||
c = PointInPlaneConstraint(b2, (0, 0, 0), ID_QUAT, b1, (0, 0, 0), ID_QUAT)
|
||||
env = pt.get_env()
|
||||
assert abs(c.residuals()[0].eval(env) - 7.0) < 1e-10
|
||||
|
||||
def test_with_offset(self):
|
||||
"""Point at (0,0,5) with offset=5 → residual 0."""
|
||||
pt = ParamTable()
|
||||
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0), grounded=True)
|
||||
b2 = RigidBody("b", pt, (0, 0, 5), (1, 0, 0, 0))
|
||||
c = PointInPlaneConstraint(
|
||||
b2, (0, 0, 0), ID_QUAT, b1, (0, 0, 0), ID_QUAT, offset=5.0
|
||||
)
|
||||
env = pt.get_env()
|
||||
assert abs(c.residuals()[0].eval(env)) < 1e-10
|
||||
|
||||
def test_residual_count(self):
|
||||
pt = ParamTable()
|
||||
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0))
|
||||
b2 = RigidBody("b", pt, (0, 0, 0), (1, 0, 0, 0))
|
||||
c = PointInPlaneConstraint(b1, (0, 0, 0), ID_QUAT, b2, (0, 0, 0), ID_QUAT)
|
||||
assert len(c.residuals()) == 1
|
||||
|
||||
|
||||
# ── Orientation constraints ──────────────────────────────────────────
|
||||
|
||||
|
||||
class TestParallel:
|
||||
def test_parallel_same(self):
|
||||
"""Both bodies with identity rotation → Z-axes parallel → residuals 0."""
|
||||
pt = ParamTable()
|
||||
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0), grounded=True)
|
||||
b2 = RigidBody("b", pt, (5, 0, 0), (1, 0, 0, 0))
|
||||
c = ParallelConstraint(b1, ID_QUAT, b2, ID_QUAT)
|
||||
env = pt.get_env()
|
||||
for r in c.residuals():
|
||||
assert abs(r.eval(env)) < 1e-10
|
||||
|
||||
def test_not_parallel(self):
|
||||
"""One body rotated 90-deg about Y → Z-axes perpendicular."""
|
||||
pt = ParamTable()
|
||||
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0), grounded=True)
|
||||
b2 = RigidBody("b", pt, (5, 0, 0), ROT_90Y)
|
||||
c = ParallelConstraint(b1, ID_QUAT, b2, ID_QUAT)
|
||||
env = pt.get_env()
|
||||
vals = [r.eval(env) for r in c.residuals()]
|
||||
assert any(abs(v) > 0.1 for v in vals)
|
||||
|
||||
def test_residual_count(self):
|
||||
pt = ParamTable()
|
||||
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0))
|
||||
b2 = RigidBody("b", pt, (0, 0, 0), (1, 0, 0, 0))
|
||||
c = ParallelConstraint(b1, ID_QUAT, b2, ID_QUAT)
|
||||
assert len(c.residuals()) == 3
|
||||
|
||||
|
||||
class TestPerpendicular:
|
||||
def test_perpendicular(self):
|
||||
"""One body rotated 90-deg about Y → Z-axes perpendicular."""
|
||||
pt = ParamTable()
|
||||
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0), grounded=True)
|
||||
b2 = RigidBody("b", pt, (0, 0, 0), ROT_90Y)
|
||||
c = PerpendicularConstraint(b1, ID_QUAT, b2, ID_QUAT)
|
||||
env = pt.get_env()
|
||||
assert abs(c.residuals()[0].eval(env)) < 1e-10
|
||||
|
||||
def test_not_perpendicular(self):
|
||||
"""Same orientation → not perpendicular."""
|
||||
pt = ParamTable()
|
||||
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0), grounded=True)
|
||||
b2 = RigidBody("b", pt, (0, 0, 0), (1, 0, 0, 0))
|
||||
c = PerpendicularConstraint(b1, ID_QUAT, b2, ID_QUAT)
|
||||
env = pt.get_env()
|
||||
# dot(z,z) = 1 ≠ 0
|
||||
assert abs(c.residuals()[0].eval(env) - 1.0) < 1e-10
|
||||
|
||||
def test_residual_count(self):
|
||||
pt = ParamTable()
|
||||
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0))
|
||||
b2 = RigidBody("b", pt, (0, 0, 0), (1, 0, 0, 0))
|
||||
c = PerpendicularConstraint(b1, ID_QUAT, b2, ID_QUAT)
|
||||
assert len(c.residuals()) == 1
|
||||
|
||||
|
||||
class TestAngle:
|
||||
def test_90_degrees(self):
|
||||
"""90-deg angle between Z-axes rotated 90-deg about Y."""
|
||||
pt = ParamTable()
|
||||
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0), grounded=True)
|
||||
b2 = RigidBody("b", pt, (0, 0, 0), ROT_90Y)
|
||||
c = AngleConstraint(b1, ID_QUAT, b2, ID_QUAT, math.pi / 2)
|
||||
env = pt.get_env()
|
||||
assert abs(c.residuals()[0].eval(env)) < 1e-10
|
||||
|
||||
def test_0_degrees(self):
|
||||
"""0-deg angle, same orientation → cos(0)=1, dot=1 → residual 0."""
|
||||
pt = ParamTable()
|
||||
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0), grounded=True)
|
||||
b2 = RigidBody("b", pt, (0, 0, 0), (1, 0, 0, 0))
|
||||
c = AngleConstraint(b1, ID_QUAT, b2, ID_QUAT, 0.0)
|
||||
env = pt.get_env()
|
||||
assert abs(c.residuals()[0].eval(env)) < 1e-10
|
||||
|
||||
def test_residual_count(self):
|
||||
pt = ParamTable()
|
||||
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0))
|
||||
b2 = RigidBody("b", pt, (0, 0, 0), (1, 0, 0, 0))
|
||||
c = AngleConstraint(b1, ID_QUAT, b2, ID_QUAT, 1.0)
|
||||
assert len(c.residuals()) == 1
|
||||
|
||||
|
||||
# ── Axis/surface constraints ─────────────────────────────────────────
|
||||
|
||||
|
||||
class TestConcentric:
|
||||
def test_coaxial(self):
|
||||
"""Both on Z-axis → coaxial → residuals 0."""
|
||||
pt = ParamTable()
|
||||
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0), grounded=True)
|
||||
b2 = RigidBody("b", pt, (0, 0, 5), (1, 0, 0, 0))
|
||||
c = ConcentricConstraint(b1, (0, 0, 0), ID_QUAT, b2, (0, 0, 0), ID_QUAT)
|
||||
env = pt.get_env()
|
||||
for r in c.residuals():
|
||||
assert abs(r.eval(env)) < 1e-10
|
||||
|
||||
def test_not_coaxial(self):
|
||||
"""Offset in X → not coaxial."""
|
||||
pt = ParamTable()
|
||||
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0), grounded=True)
|
||||
b2 = RigidBody("b", pt, (5, 0, 0), (1, 0, 0, 0))
|
||||
c = ConcentricConstraint(b1, (0, 0, 0), ID_QUAT, b2, (0, 0, 0), ID_QUAT)
|
||||
env = pt.get_env()
|
||||
vals = [r.eval(env) for r in c.residuals()]
|
||||
assert any(abs(v) > 0.1 for v in vals)
|
||||
|
||||
def test_residual_count(self):
|
||||
pt = ParamTable()
|
||||
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0))
|
||||
b2 = RigidBody("b", pt, (0, 0, 0), (1, 0, 0, 0))
|
||||
c = ConcentricConstraint(b1, (0, 0, 0), ID_QUAT, b2, (0, 0, 0), ID_QUAT)
|
||||
assert len(c.residuals()) == 6
|
||||
|
||||
|
||||
class TestTangent:
|
||||
def test_touching(self):
|
||||
"""Marker origins at same point → tangent."""
|
||||
pt = ParamTable()
|
||||
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0), grounded=True)
|
||||
b2 = RigidBody("b", pt, (0, 0, 0), (1, 0, 0, 0))
|
||||
c = TangentConstraint(b1, (0, 0, 0), ID_QUAT, b2, (0, 0, 0), ID_QUAT)
|
||||
env = pt.get_env()
|
||||
assert abs(c.residuals()[0].eval(env)) < 1e-10
|
||||
|
||||
def test_separated(self):
|
||||
"""Separated along normal → non-zero residual."""
|
||||
pt = ParamTable()
|
||||
b1 = RigidBody("a", pt, (0, 0, 5), (1, 0, 0, 0), grounded=True)
|
||||
b2 = RigidBody("b", pt, (0, 0, 0), (1, 0, 0, 0))
|
||||
c = TangentConstraint(b1, (0, 0, 0), ID_QUAT, b2, (0, 0, 0), ID_QUAT)
|
||||
env = pt.get_env()
|
||||
assert abs(c.residuals()[0].eval(env) - 5.0) < 1e-10
|
||||
|
||||
def test_residual_count(self):
|
||||
pt = ParamTable()
|
||||
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0))
|
||||
b2 = RigidBody("b", pt, (0, 0, 0), (1, 0, 0, 0))
|
||||
c = TangentConstraint(b1, (0, 0, 0), ID_QUAT, b2, (0, 0, 0), ID_QUAT)
|
||||
assert len(c.residuals()) == 1
|
||||
|
||||
|
||||
class TestPlanar:
|
||||
def test_coplanar(self):
|
||||
"""Same plane, same orientation → all residuals 0."""
|
||||
pt = ParamTable()
|
||||
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0), grounded=True)
|
||||
b2 = RigidBody("b", pt, (5, 3, 0), (1, 0, 0, 0))
|
||||
c = PlanarConstraint(b1, (0, 0, 0), ID_QUAT, b2, (0, 0, 0), ID_QUAT)
|
||||
env = pt.get_env()
|
||||
for r in c.residuals():
|
||||
assert abs(r.eval(env)) < 1e-10
|
||||
|
||||
def test_with_offset(self):
|
||||
"""b_i at z=5, b_j at origin, normal=Z, offset=5.
|
||||
Signed distance = (p_i - p_j).n = 5, offset=5 → 5-5 = 0."""
|
||||
pt = ParamTable()
|
||||
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0), grounded=True)
|
||||
b2 = RigidBody("b", pt, (0, 0, 5), (1, 0, 0, 0))
|
||||
c = PlanarConstraint(b2, (0, 0, 0), ID_QUAT, b1, (0, 0, 0), ID_QUAT, offset=5.0)
|
||||
env = pt.get_env()
|
||||
for r in c.residuals():
|
||||
assert abs(r.eval(env)) < 1e-10
|
||||
|
||||
def test_residual_count(self):
|
||||
pt = ParamTable()
|
||||
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0))
|
||||
b2 = RigidBody("b", pt, (0, 0, 0), (1, 0, 0, 0))
|
||||
c = PlanarConstraint(b1, (0, 0, 0), ID_QUAT, b2, (0, 0, 0), ID_QUAT)
|
||||
assert len(c.residuals()) == 4
|
||||
|
||||
|
||||
class TestLineInPlane:
|
||||
def test_in_plane(self):
|
||||
"""Line along X in XY plane → residuals 0."""
|
||||
pt = ParamTable()
|
||||
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0), grounded=True)
|
||||
# b2 has Z-axis = (1,0,0) via 90-deg rotation about Y
|
||||
b2 = RigidBody("b", pt, (5, 0, 0), ROT_90Y)
|
||||
# Line = b2's Z-axis (which is world X), plane = b1's XY plane (normal=Z)
|
||||
c = LineInPlaneConstraint(b2, (0, 0, 0), ID_QUAT, b1, (0, 0, 0), ID_QUAT)
|
||||
env = pt.get_env()
|
||||
for r in c.residuals():
|
||||
assert abs(r.eval(env)) < 1e-10
|
||||
|
||||
def test_residual_count(self):
|
||||
pt = ParamTable()
|
||||
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0))
|
||||
b2 = RigidBody("b", pt, (0, 0, 0), (1, 0, 0, 0))
|
||||
c = LineInPlaneConstraint(b1, (0, 0, 0), ID_QUAT, b2, (0, 0, 0), ID_QUAT)
|
||||
assert len(c.residuals()) == 2
|
||||
|
||||
|
||||
# ── Kinematic joints ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestBall:
|
||||
def test_same_as_coincident(self):
|
||||
pt = ParamTable()
|
||||
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0), grounded=True)
|
||||
b2 = RigidBody("b", pt, (0, 0, 0), (1, 0, 0, 0))
|
||||
c = BallConstraint(b1, (0, 0, 0), b2, (0, 0, 0))
|
||||
env = pt.get_env()
|
||||
for r in c.residuals():
|
||||
assert abs(r.eval(env)) < 1e-10
|
||||
assert len(c.residuals()) == 3
|
||||
|
||||
|
||||
class TestRevolute:
|
||||
def test_satisfied(self):
|
||||
"""Same position, same Z-axis → satisfied."""
|
||||
pt = ParamTable()
|
||||
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0), grounded=True)
|
||||
b2 = RigidBody("b", pt, (0, 0, 0), ROT_90Z) # rotated about Z — still parallel
|
||||
c = RevoluteConstraint(b1, (0, 0, 0), ID_QUAT, b2, (0, 0, 0), ID_QUAT)
|
||||
env = pt.get_env()
|
||||
for r in c.residuals():
|
||||
assert abs(r.eval(env)) < 1e-10
|
||||
|
||||
def test_residual_count(self):
|
||||
pt = ParamTable()
|
||||
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0))
|
||||
b2 = RigidBody("b", pt, (0, 0, 0), (1, 0, 0, 0))
|
||||
c = RevoluteConstraint(b1, (0, 0, 0), ID_QUAT, b2, (0, 0, 0), ID_QUAT)
|
||||
assert len(c.residuals()) == 6
|
||||
|
||||
|
||||
class TestCylindrical:
|
||||
def test_on_axis(self):
|
||||
"""Same axis, displaced along Z → satisfied."""
|
||||
pt = ParamTable()
|
||||
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0), grounded=True)
|
||||
b2 = RigidBody("b", pt, (0, 0, 10), (1, 0, 0, 0))
|
||||
c = CylindricalConstraint(b1, (0, 0, 0), ID_QUAT, b2, (0, 0, 0), ID_QUAT)
|
||||
env = pt.get_env()
|
||||
for r in c.residuals():
|
||||
assert abs(r.eval(env)) < 1e-10
|
||||
|
||||
def test_residual_count(self):
|
||||
pt = ParamTable()
|
||||
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0))
|
||||
b2 = RigidBody("b", pt, (0, 0, 0), (1, 0, 0, 0))
|
||||
c = CylindricalConstraint(b1, (0, 0, 0), ID_QUAT, b2, (0, 0, 0), ID_QUAT)
|
||||
assert len(c.residuals()) == 6
|
||||
|
||||
|
||||
class TestSlider:
|
||||
def test_aligned(self):
|
||||
"""Same axis, no twist, displaced along Z → satisfied."""
|
||||
pt = ParamTable()
|
||||
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0), grounded=True)
|
||||
b2 = RigidBody("b", pt, (0, 0, 10), (1, 0, 0, 0))
|
||||
c = SliderConstraint(b1, (0, 0, 0), ID_QUAT, b2, (0, 0, 0), ID_QUAT)
|
||||
env = pt.get_env()
|
||||
for r in c.residuals():
|
||||
assert abs(r.eval(env)) < 1e-10
|
||||
|
||||
def test_twisted(self):
|
||||
"""Rotated about Z → twist residual non-zero."""
|
||||
pt = ParamTable()
|
||||
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0), grounded=True)
|
||||
b2 = RigidBody("b", pt, (0, 0, 0), ROT_90Z)
|
||||
c = SliderConstraint(b1, (0, 0, 0), ID_QUAT, b2, (0, 0, 0), ID_QUAT)
|
||||
env = pt.get_env()
|
||||
vals = [r.eval(env) for r in c.residuals()]
|
||||
# First 6 should be ~0 (parallel + on-line), but twist residual should be ~1
|
||||
assert abs(vals[6]) > 0.5
|
||||
|
||||
def test_residual_count(self):
|
||||
pt = ParamTable()
|
||||
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0))
|
||||
b2 = RigidBody("b", pt, (0, 0, 0), (1, 0, 0, 0))
|
||||
c = SliderConstraint(b1, (0, 0, 0), ID_QUAT, b2, (0, 0, 0), ID_QUAT)
|
||||
assert len(c.residuals()) == 7
|
||||
|
||||
|
||||
class TestUniversal:
|
||||
def test_satisfied(self):
|
||||
"""Same origin, perpendicular Z-axes."""
|
||||
pt = ParamTable()
|
||||
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0), grounded=True)
|
||||
b2 = RigidBody("b", pt, (0, 0, 0), ROT_90Y)
|
||||
c = UniversalConstraint(b1, (0, 0, 0), ID_QUAT, b2, (0, 0, 0), ID_QUAT)
|
||||
env = pt.get_env()
|
||||
for r in c.residuals():
|
||||
assert abs(r.eval(env)) < 1e-10
|
||||
|
||||
def test_residual_count(self):
|
||||
pt = ParamTable()
|
||||
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0))
|
||||
b2 = RigidBody("b", pt, (0, 0, 0), (1, 0, 0, 0))
|
||||
c = UniversalConstraint(b1, (0, 0, 0), ID_QUAT, b2, (0, 0, 0), ID_QUAT)
|
||||
assert len(c.residuals()) == 4
|
||||
|
||||
|
||||
class TestScrew:
|
||||
def test_residual_count(self):
|
||||
pt = ParamTable()
|
||||
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0))
|
||||
b2 = RigidBody("b", pt, (0, 0, 0), (1, 0, 0, 0))
|
||||
c = ScrewConstraint(b1, (0, 0, 0), ID_QUAT, b2, (0, 0, 0), ID_QUAT, pitch=10.0)
|
||||
assert len(c.residuals()) == 7
|
||||
|
||||
def test_zero_displacement_zero_rotation(self):
|
||||
"""Both at origin with identity rotation → all residuals 0."""
|
||||
pt = ParamTable()
|
||||
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0), grounded=True)
|
||||
b2 = RigidBody("b", pt, (0, 0, 0), (1, 0, 0, 0))
|
||||
c = ScrewConstraint(b1, (0, 0, 0), ID_QUAT, b2, (0, 0, 0), ID_QUAT, pitch=10.0)
|
||||
env = pt.get_env()
|
||||
for r in c.residuals():
|
||||
assert abs(r.eval(env)) < 1e-10
|
||||
|
||||
|
||||
# ── Mechanical constraints ───────────────────────────────────────────
|
||||
|
||||
|
||||
class TestGear:
|
||||
def test_both_at_rest(self):
|
||||
"""Both at identity rotation → residual 0."""
|
||||
pt = ParamTable()
|
||||
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0), grounded=True)
|
||||
b2 = RigidBody("b", pt, (0, 0, 0), (1, 0, 0, 0))
|
||||
c = GearConstraint(b1, ID_QUAT, b2, ID_QUAT, 1.0, 1.0)
|
||||
env = pt.get_env()
|
||||
assert abs(c.residuals()[0].eval(env)) < 1e-10
|
||||
|
||||
def test_residual_count(self):
|
||||
pt = ParamTable()
|
||||
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0))
|
||||
b2 = RigidBody("b", pt, (0, 0, 0), (1, 0, 0, 0))
|
||||
c = GearConstraint(b1, ID_QUAT, b2, ID_QUAT, 1.0, 2.0)
|
||||
assert len(c.residuals()) == 1
|
||||
|
||||
|
||||
class TestRackPinion:
|
||||
def test_at_rest(self):
|
||||
"""Both at rest → residual 0."""
|
||||
pt = ParamTable()
|
||||
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0), grounded=True)
|
||||
b2 = RigidBody("b", pt, (0, 0, 0), (1, 0, 0, 0))
|
||||
c = RackPinionConstraint(
|
||||
b1, (0, 0, 0), ID_QUAT, b2, (0, 0, 0), ID_QUAT, pitch_radius=5.0
|
||||
)
|
||||
env = pt.get_env()
|
||||
assert abs(c.residuals()[0].eval(env)) < 1e-10
|
||||
|
||||
def test_residual_count(self):
|
||||
pt = ParamTable()
|
||||
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0))
|
||||
b2 = RigidBody("b", pt, (0, 0, 0), (1, 0, 0, 0))
|
||||
c = RackPinionConstraint(
|
||||
b1, (0, 0, 0), ID_QUAT, b2, (0, 0, 0), ID_QUAT, pitch_radius=1.0
|
||||
)
|
||||
assert len(c.residuals()) == 1
|
||||
|
||||
|
||||
# ── Stubs ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestStubs:
|
||||
def test_cam(self):
|
||||
assert CamConstraint().residuals() == []
|
||||
|
||||
def test_slot(self):
|
||||
assert SlotConstraint().residuals() == []
|
||||
|
||||
def test_distance_cyl_sph(self):
|
||||
assert DistanceCylSphConstraint().residuals() == []
|
||||
1052
tests/test_decompose.py
Normal file
1052
tests/test_decompose.py
Normal file
File diff suppressed because it is too large
Load Diff
296
tests/test_diagnostics.py
Normal file
296
tests/test_diagnostics.py
Normal file
@@ -0,0 +1,296 @@
|
||||
"""Tests for per-entity DOF diagnostics and overconstrained detection."""
|
||||
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
from kindred_solver.constraints import (
|
||||
CoincidentConstraint,
|
||||
CylindricalConstraint,
|
||||
DistancePointPointConstraint,
|
||||
FixedConstraint,
|
||||
ParallelConstraint,
|
||||
RevoluteConstraint,
|
||||
)
|
||||
from kindred_solver.diagnostics import (
|
||||
ConstraintDiag,
|
||||
EntityDOF,
|
||||
find_overconstrained,
|
||||
per_entity_dof,
|
||||
)
|
||||
from kindred_solver.entities import RigidBody
|
||||
from kindred_solver.params import ParamTable
|
||||
|
||||
|
||||
def _make_two_bodies(
|
||||
params,
|
||||
pos_a=(0, 0, 0),
|
||||
pos_b=(5, 0, 0),
|
||||
quat_a=(1, 0, 0, 0),
|
||||
quat_b=(1, 0, 0, 0),
|
||||
ground_a=True,
|
||||
ground_b=False,
|
||||
):
|
||||
body_a = RigidBody(
|
||||
"a", params, position=pos_a, quaternion=quat_a, grounded=ground_a
|
||||
)
|
||||
body_b = RigidBody(
|
||||
"b", params, position=pos_b, quaternion=quat_b, grounded=ground_b
|
||||
)
|
||||
return body_a, body_b
|
||||
|
||||
|
||||
def _build_residuals_and_ranges(constraint_objs, bodies, params):
|
||||
"""Build residuals list, quat norms, and residual_ranges."""
|
||||
all_residuals = []
|
||||
residual_ranges = []
|
||||
row = 0
|
||||
for i, obj in enumerate(constraint_objs):
|
||||
r = obj.residuals()
|
||||
n = len(r)
|
||||
residual_ranges.append((row, row + n, i))
|
||||
all_residuals.extend(r)
|
||||
row += n
|
||||
|
||||
for body in bodies.values():
|
||||
if not body.grounded:
|
||||
all_residuals.append(body.quat_norm_residual())
|
||||
|
||||
return all_residuals, residual_ranges
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Per-entity DOF tests
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestPerEntityDOF:
|
||||
"""Per-entity DOF computation."""
|
||||
|
||||
def test_unconstrained_body_6dof(self):
|
||||
"""Unconstrained non-grounded body has 6 DOF."""
|
||||
params = ParamTable()
|
||||
body = RigidBody(
|
||||
"b", params, position=(0, 0, 0), quaternion=(1, 0, 0, 0), grounded=False
|
||||
)
|
||||
bodies = {"b": body}
|
||||
|
||||
# Only quat norm constraint
|
||||
residuals = [body.quat_norm_residual()]
|
||||
|
||||
result = per_entity_dof(residuals, params, bodies)
|
||||
assert len(result) == 1
|
||||
assert result[0].entity_id == "b"
|
||||
assert result[0].remaining_dof == 6
|
||||
assert len(result[0].free_motions) == 6
|
||||
|
||||
def test_fixed_body_0dof(self):
|
||||
"""Body welded to ground has 0 DOF."""
|
||||
params = ParamTable()
|
||||
body_a, body_b = _make_two_bodies(params)
|
||||
bodies = {"a": body_a, "b": body_b}
|
||||
|
||||
c = FixedConstraint(
|
||||
body_a,
|
||||
(0, 0, 0),
|
||||
(1, 0, 0, 0),
|
||||
body_b,
|
||||
(0, 0, 0),
|
||||
(1, 0, 0, 0),
|
||||
)
|
||||
residuals, _ = _build_residuals_and_ranges([c], bodies, params)
|
||||
|
||||
result = per_entity_dof(residuals, params, bodies)
|
||||
# Only non-grounded body (b) reported
|
||||
assert len(result) == 1
|
||||
assert result[0].entity_id == "b"
|
||||
assert result[0].remaining_dof == 0
|
||||
assert len(result[0].free_motions) == 0
|
||||
|
||||
def test_revolute_1dof(self):
|
||||
"""Revolute joint leaves 1 DOF (rotation about Z)."""
|
||||
params = ParamTable()
|
||||
body_a, body_b = _make_two_bodies(params, pos_b=(0, 0, 0))
|
||||
bodies = {"a": body_a, "b": body_b}
|
||||
|
||||
c = RevoluteConstraint(
|
||||
body_a,
|
||||
(0, 0, 0),
|
||||
(1, 0, 0, 0),
|
||||
body_b,
|
||||
(0, 0, 0),
|
||||
(1, 0, 0, 0),
|
||||
)
|
||||
residuals, _ = _build_residuals_and_ranges([c], bodies, params)
|
||||
|
||||
result = per_entity_dof(residuals, params, bodies)
|
||||
assert len(result) == 1
|
||||
assert result[0].remaining_dof == 1
|
||||
# Should have one free motion that mentions rotation
|
||||
assert len(result[0].free_motions) == 1
|
||||
assert "rotation" in result[0].free_motions[0].lower()
|
||||
|
||||
def test_cylindrical_2dof(self):
|
||||
"""Cylindrical joint leaves 2 DOF (rotation about Z + translation along Z)."""
|
||||
params = ParamTable()
|
||||
body_a, body_b = _make_two_bodies(params, pos_b=(0, 0, 0))
|
||||
bodies = {"a": body_a, "b": body_b}
|
||||
|
||||
c = CylindricalConstraint(
|
||||
body_a,
|
||||
(0, 0, 0),
|
||||
(1, 0, 0, 0),
|
||||
body_b,
|
||||
(0, 0, 0),
|
||||
(1, 0, 0, 0),
|
||||
)
|
||||
residuals, _ = _build_residuals_and_ranges([c], bodies, params)
|
||||
|
||||
result = per_entity_dof(residuals, params, bodies)
|
||||
assert len(result) == 1
|
||||
assert result[0].remaining_dof == 2
|
||||
assert len(result[0].free_motions) == 2
|
||||
|
||||
def test_coincident_3dof(self):
|
||||
"""Coincident (ball) joint leaves 3 DOF (3 rotations)."""
|
||||
params = ParamTable()
|
||||
body_a, body_b = _make_two_bodies(params, pos_b=(0, 0, 0))
|
||||
bodies = {"a": body_a, "b": body_b}
|
||||
|
||||
c = CoincidentConstraint(body_a, (0, 0, 0), body_b, (0, 0, 0))
|
||||
residuals, _ = _build_residuals_and_ranges([c], bodies, params)
|
||||
|
||||
result = per_entity_dof(residuals, params, bodies)
|
||||
assert len(result) == 1
|
||||
assert result[0].remaining_dof == 3
|
||||
# All 3 should be rotations
|
||||
for motion in result[0].free_motions:
|
||||
assert "rotation" in motion.lower()
|
||||
|
||||
def test_no_constraints_6dof(self):
|
||||
"""No residuals at all gives 6 DOF."""
|
||||
params = ParamTable()
|
||||
body = RigidBody(
|
||||
"b", params, position=(0, 0, 0), quaternion=(1, 0, 0, 0), grounded=False
|
||||
)
|
||||
bodies = {"b": body}
|
||||
|
||||
result = per_entity_dof([], params, bodies)
|
||||
assert len(result) == 1
|
||||
assert result[0].remaining_dof == 6
|
||||
|
||||
def test_grounded_body_excluded(self):
|
||||
"""Grounded bodies are not reported."""
|
||||
params = ParamTable()
|
||||
body_a, body_b = _make_two_bodies(params)
|
||||
bodies = {"a": body_a, "b": body_b}
|
||||
|
||||
residuals = [body_b.quat_norm_residual()]
|
||||
result = per_entity_dof(residuals, params, bodies)
|
||||
|
||||
entity_ids = [r.entity_id for r in result]
|
||||
assert "a" not in entity_ids # grounded
|
||||
assert "b" in entity_ids
|
||||
|
||||
def test_multiple_bodies(self):
|
||||
"""Two free bodies: each gets its own DOF report."""
|
||||
params = ParamTable()
|
||||
body_g = RigidBody(
|
||||
"g", params, position=(0, 0, 0), quaternion=(1, 0, 0, 0), grounded=True
|
||||
)
|
||||
body_b = RigidBody(
|
||||
"b", params, position=(5, 0, 0), quaternion=(1, 0, 0, 0), grounded=False
|
||||
)
|
||||
body_c = RigidBody(
|
||||
"c", params, position=(10, 0, 0), quaternion=(1, 0, 0, 0), grounded=False
|
||||
)
|
||||
bodies = {"g": body_g, "b": body_b, "c": body_c}
|
||||
|
||||
# Fix b to ground, leave c unconstrained
|
||||
c_fix = FixedConstraint(
|
||||
body_g,
|
||||
(0, 0, 0),
|
||||
(1, 0, 0, 0),
|
||||
body_b,
|
||||
(0, 0, 0),
|
||||
(1, 0, 0, 0),
|
||||
)
|
||||
residuals, _ = _build_residuals_and_ranges([c_fix], bodies, params)
|
||||
|
||||
result = per_entity_dof(residuals, params, bodies)
|
||||
result_map = {r.entity_id: r for r in result}
|
||||
|
||||
assert result_map["b"].remaining_dof == 0
|
||||
assert result_map["c"].remaining_dof == 6
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Overconstrained detection tests
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestFindOverconstrained:
|
||||
"""Redundant and conflicting constraint detection."""
|
||||
|
||||
def test_well_constrained_no_diagnostics(self):
|
||||
"""Well-constrained system produces no diagnostics."""
|
||||
params = ParamTable()
|
||||
body_a, body_b = _make_two_bodies(params, pos_b=(0, 0, 0))
|
||||
bodies = {"a": body_a, "b": body_b}
|
||||
|
||||
c = FixedConstraint(
|
||||
body_a,
|
||||
(0, 0, 0),
|
||||
(1, 0, 0, 0),
|
||||
body_b,
|
||||
(0, 0, 0),
|
||||
(1, 0, 0, 0),
|
||||
)
|
||||
residuals, ranges = _build_residuals_and_ranges([c], bodies, params)
|
||||
|
||||
diags = find_overconstrained(residuals, params, ranges)
|
||||
assert len(diags) == 0
|
||||
|
||||
def test_duplicate_coincident_redundant(self):
|
||||
"""Duplicate coincident constraint is flagged as redundant."""
|
||||
params = ParamTable()
|
||||
body_a, body_b = _make_two_bodies(params, pos_b=(0, 0, 0))
|
||||
bodies = {"a": body_a, "b": body_b}
|
||||
|
||||
c1 = CoincidentConstraint(body_a, (0, 0, 0), body_b, (0, 0, 0))
|
||||
c2 = CoincidentConstraint(body_a, (0, 0, 0), body_b, (0, 0, 0))
|
||||
residuals, ranges = _build_residuals_and_ranges([c1, c2], bodies, params)
|
||||
|
||||
diags = find_overconstrained(residuals, params, ranges)
|
||||
assert len(diags) > 0
|
||||
# At least one should be redundant
|
||||
kinds = {d.kind for d in diags}
|
||||
assert "redundant" in kinds
|
||||
|
||||
def test_conflicting_distance(self):
|
||||
"""Distance constraint that can't be satisfied is flagged as conflicting."""
|
||||
params = ParamTable()
|
||||
body_a, body_b = _make_two_bodies(params, pos_b=(0, 0, 0))
|
||||
bodies = {"a": body_a, "b": body_b}
|
||||
|
||||
# Coincident forces distance=0, but distance constraint says 50
|
||||
c1 = CoincidentConstraint(body_a, (0, 0, 0), body_b, (0, 0, 0))
|
||||
c2 = DistancePointPointConstraint(
|
||||
body_a,
|
||||
(0, 0, 0),
|
||||
body_b,
|
||||
(0, 0, 0),
|
||||
distance=50.0,
|
||||
)
|
||||
residuals, ranges = _build_residuals_and_ranges([c1, c2], bodies, params)
|
||||
|
||||
diags = find_overconstrained(residuals, params, ranges)
|
||||
assert len(diags) > 0
|
||||
kinds = {d.kind for d in diags}
|
||||
assert "conflicting" in kinds
|
||||
|
||||
def test_empty_system_no_diagnostics(self):
|
||||
"""Empty system has no diagnostics."""
|
||||
params = ParamTable()
|
||||
diags = find_overconstrained([], params, [])
|
||||
assert len(diags) == 0
|
||||
253
tests/test_drag.py
Normal file
253
tests/test_drag.py
Normal file
@@ -0,0 +1,253 @@
|
||||
"""Regression tests for interactive drag.
|
||||
|
||||
These tests exercise the drag protocol at the solver-internals level,
|
||||
verifying that constraints remain enforced across drag steps when the
|
||||
pre-pass has been applied to cached residuals.
|
||||
|
||||
Bug scenario: single_equation_pass runs during pre_drag, analytically
|
||||
solving variables from upstream constraints and baking their values as
|
||||
constants into downstream residual expressions. When a drag step
|
||||
changes those variables, the cached residuals use stale constants and
|
||||
downstream constraints (e.g. Planar distance=0) stop being enforced.
|
||||
|
||||
Fix: skip single_equation_pass in the drag path. Only substitution_pass
|
||||
(which replaces genuinely grounded parameters) is safe to cache.
|
||||
"""
|
||||
|
||||
import math
|
||||
|
||||
import pytest
|
||||
from kindred_solver.constraints import (
|
||||
CoincidentConstraint,
|
||||
PlanarConstraint,
|
||||
RevoluteConstraint,
|
||||
)
|
||||
from kindred_solver.entities import RigidBody
|
||||
from kindred_solver.newton import newton_solve
|
||||
from kindred_solver.params import ParamTable
|
||||
from kindred_solver.prepass import single_equation_pass, substitution_pass
|
||||
|
||||
ID_QUAT = (1, 0, 0, 0)
|
||||
|
||||
|
||||
def _build_residuals(bodies, constraint_objs):
|
||||
"""Build raw residual list + quat groups (no prepass)."""
|
||||
all_residuals = []
|
||||
for c in constraint_objs:
|
||||
all_residuals.extend(c.residuals())
|
||||
|
||||
quat_groups = []
|
||||
for body in bodies:
|
||||
if not body.grounded:
|
||||
all_residuals.append(body.quat_norm_residual())
|
||||
quat_groups.append(body.quat_param_names())
|
||||
|
||||
return all_residuals, quat_groups
|
||||
|
||||
|
||||
def _eval_raw_residuals(bodies, constraint_objs, params):
|
||||
"""Evaluate original constraint residuals at current param values.
|
||||
|
||||
Returns the max absolute residual value — the ground truth for
|
||||
whether constraints are satisfied regardless of prepass state.
|
||||
"""
|
||||
raw, _ = _build_residuals(bodies, constraint_objs)
|
||||
env = params.get_env()
|
||||
return max(abs(r.eval(env)) for r in raw)
|
||||
|
||||
|
||||
class TestPrepassDragRegression:
|
||||
"""single_equation_pass bakes stale values that break drag.
|
||||
|
||||
Setup: ground --Revolute--> arm --Planar(d=0)--> plate
|
||||
|
||||
The Revolute pins arm's origin to ground (fixes arm/tx, arm/ty,
|
||||
arm/tz via single_equation_pass). The Planar keeps plate coplanar
|
||||
with arm. After prepass, the Planar residual has arm's position
|
||||
baked as Const(0.0).
|
||||
|
||||
During drag: arm/tz is set to 5.0. Because arm/tz is marked fixed
|
||||
by prepass, Newton can't correct it, AND the Planar residual still
|
||||
uses Const(0.0) instead of the live value 5.0. The Revolute
|
||||
constraint (arm at origin) is silently violated.
|
||||
"""
|
||||
|
||||
def _setup(self):
|
||||
pt = ParamTable()
|
||||
ground = RigidBody("g", pt, (0, 0, 0), ID_QUAT, grounded=True)
|
||||
arm = RigidBody("arm", pt, (10, 0, 0), ID_QUAT)
|
||||
plate = RigidBody("plate", pt, (10, 5, 0), ID_QUAT)
|
||||
|
||||
constraints = [
|
||||
RevoluteConstraint(ground, (0, 0, 0), ID_QUAT, arm, (0, 0, 0), ID_QUAT),
|
||||
PlanarConstraint(arm, (0, 0, 0), ID_QUAT, plate, (0, 0, 0), ID_QUAT, offset=0.0),
|
||||
]
|
||||
bodies = [ground, arm, plate]
|
||||
return pt, bodies, constraints
|
||||
|
||||
def test_bug_stale_constants_after_single_equation_pass(self):
|
||||
"""Document the bug: prepass bakes arm/tz=0, drag breaks constraints."""
|
||||
pt, bodies, constraints = self._setup()
|
||||
raw_residuals, quat_groups = _build_residuals(bodies, constraints)
|
||||
|
||||
# Simulate OLD pre_drag: substitution + single_equation_pass
|
||||
residuals = substitution_pass(raw_residuals, pt)
|
||||
residuals = single_equation_pass(residuals, pt)
|
||||
|
||||
ok = newton_solve(residuals, pt, quat_groups=quat_groups, max_iter=100, tol=1e-10)
|
||||
assert ok
|
||||
|
||||
# Verify prepass fixed arm's position params
|
||||
assert pt.is_fixed("arm/tx")
|
||||
assert pt.is_fixed("arm/ty")
|
||||
assert pt.is_fixed("arm/tz")
|
||||
|
||||
# Simulate drag: move arm up (set_value, as drag_step does)
|
||||
pt.set_value("arm/tz", 5.0)
|
||||
pt.set_value("plate/tz", 5.0) # initial guess near drag
|
||||
|
||||
ok = newton_solve(residuals, pt, quat_groups=quat_groups, max_iter=100, tol=1e-10)
|
||||
# Solver "converges" on the stale cached residuals
|
||||
assert ok
|
||||
|
||||
# But the TRUE constraints are violated: arm should be at z=0
|
||||
# (Revolute pins it to ground) yet it's at z=5
|
||||
max_err = _eval_raw_residuals(bodies, constraints, pt)
|
||||
assert max_err > 1.0, (
|
||||
f"Expected large raw residual violation, got {max_err:.6e}. "
|
||||
"The bug should cause the Revolute z-residual to be ~5.0"
|
||||
)
|
||||
|
||||
def test_fix_no_single_equation_pass_for_drag(self):
|
||||
"""With the fix: skip single_equation_pass, constraints hold."""
|
||||
pt, bodies, constraints = self._setup()
|
||||
raw_residuals, quat_groups = _build_residuals(bodies, constraints)
|
||||
|
||||
# Simulate FIXED pre_drag: substitution only
|
||||
residuals = substitution_pass(raw_residuals, pt)
|
||||
|
||||
ok = newton_solve(residuals, pt, quat_groups=quat_groups, max_iter=100, tol=1e-10)
|
||||
assert ok
|
||||
|
||||
# arm/tz should NOT be fixed
|
||||
assert not pt.is_fixed("arm/tz")
|
||||
|
||||
# Simulate drag: move arm up
|
||||
pt.set_value("arm/tz", 5.0)
|
||||
pt.set_value("plate/tz", 5.0)
|
||||
|
||||
ok = newton_solve(residuals, pt, quat_groups=quat_groups, max_iter=100, tol=1e-10)
|
||||
assert ok
|
||||
|
||||
# Newton pulls arm back to z=0 (Revolute enforced) and plate follows
|
||||
max_err = _eval_raw_residuals(bodies, constraints, pt)
|
||||
assert max_err < 1e-8, f"Raw residual violation {max_err:.6e} — constraints not satisfied"
|
||||
|
||||
|
||||
class TestCoincidentPlanarDragRegression:
|
||||
"""Coincident upstream + Planar downstream — same bug class.
|
||||
|
||||
ground --Coincident--> bracket --Planar(d=0)--> plate
|
||||
|
||||
Coincident fixes bracket/tx,ty,tz. After prepass, the Planar
|
||||
residual has bracket's position baked. Drag moves bracket;
|
||||
the Planar uses stale constants.
|
||||
"""
|
||||
|
||||
def _setup(self):
|
||||
pt = ParamTable()
|
||||
ground = RigidBody("g", pt, (0, 0, 0), ID_QUAT, grounded=True)
|
||||
bracket = RigidBody("bracket", pt, (0, 0, 0), ID_QUAT)
|
||||
plate = RigidBody("plate", pt, (10, 5, 0), ID_QUAT)
|
||||
|
||||
constraints = [
|
||||
CoincidentConstraint(ground, (0, 0, 0), bracket, (0, 0, 0)),
|
||||
PlanarConstraint(bracket, (0, 0, 0), ID_QUAT, plate, (0, 0, 0), ID_QUAT, offset=0.0),
|
||||
]
|
||||
bodies = [ground, bracket, plate]
|
||||
return pt, bodies, constraints
|
||||
|
||||
def test_bug_coincident_planar(self):
|
||||
"""Prepass fixes bracket/tz, Planar uses stale constant during drag."""
|
||||
pt, bodies, constraints = self._setup()
|
||||
raw, qg = _build_residuals(bodies, constraints)
|
||||
|
||||
residuals = substitution_pass(raw, pt)
|
||||
residuals = single_equation_pass(residuals, pt)
|
||||
|
||||
ok = newton_solve(residuals, pt, quat_groups=qg, max_iter=100, tol=1e-10)
|
||||
assert ok
|
||||
assert pt.is_fixed("bracket/tz")
|
||||
|
||||
# Drag bracket up
|
||||
pt.set_value("bracket/tz", 5.0)
|
||||
pt.set_value("plate/tz", 5.0)
|
||||
|
||||
ok = newton_solve(residuals, pt, quat_groups=qg, max_iter=100, tol=1e-10)
|
||||
assert ok
|
||||
|
||||
# True constraints violated
|
||||
max_err = _eval_raw_residuals(bodies, constraints, pt)
|
||||
assert max_err > 1.0, f"Expected raw violation from stale prepass, got {max_err:.6e}"
|
||||
|
||||
def test_fix_coincident_planar(self):
|
||||
"""With the fix: constraints satisfied after drag."""
|
||||
pt, bodies, constraints = self._setup()
|
||||
raw, qg = _build_residuals(bodies, constraints)
|
||||
|
||||
residuals = substitution_pass(raw, pt)
|
||||
# No single_equation_pass
|
||||
|
||||
ok = newton_solve(residuals, pt, quat_groups=qg, max_iter=100, tol=1e-10)
|
||||
assert ok
|
||||
assert not pt.is_fixed("bracket/tz")
|
||||
|
||||
# Drag bracket up
|
||||
pt.set_value("bracket/tz", 5.0)
|
||||
pt.set_value("plate/tz", 5.0)
|
||||
|
||||
ok = newton_solve(residuals, pt, quat_groups=qg, max_iter=100, tol=1e-10)
|
||||
assert ok
|
||||
|
||||
max_err = _eval_raw_residuals(bodies, constraints, pt)
|
||||
assert max_err < 1e-8, f"Raw residual violation {max_err:.6e} — constraints not satisfied"
|
||||
|
||||
|
||||
class TestDragDoesNotBreakStaticSolve:
|
||||
"""Verify that the static solve path (with single_equation_pass) still works.
|
||||
|
||||
The fix only affects pre_drag — the static solve() path continues to
|
||||
use single_equation_pass for faster convergence.
|
||||
"""
|
||||
|
||||
def test_static_solve_still_uses_prepass(self):
|
||||
"""Static solve with single_equation_pass converges correctly."""
|
||||
pt = ParamTable()
|
||||
ground = RigidBody("g", pt, (0, 0, 0), ID_QUAT, grounded=True)
|
||||
arm = RigidBody("arm", pt, (10, 0, 0), ID_QUAT)
|
||||
plate = RigidBody("plate", pt, (10, 5, 8), ID_QUAT)
|
||||
|
||||
constraints = [
|
||||
RevoluteConstraint(ground, (0, 0, 0), ID_QUAT, arm, (0, 0, 0), ID_QUAT),
|
||||
PlanarConstraint(arm, (0, 0, 0), ID_QUAT, plate, (0, 0, 0), ID_QUAT, offset=0.0),
|
||||
]
|
||||
bodies = [ground, arm, plate]
|
||||
raw, qg = _build_residuals(bodies, constraints)
|
||||
|
||||
# Full prepass (static solve path)
|
||||
residuals = substitution_pass(raw, pt)
|
||||
residuals = single_equation_pass(residuals, pt)
|
||||
|
||||
ok = newton_solve(residuals, pt, quat_groups=qg, max_iter=100, tol=1e-10)
|
||||
assert ok
|
||||
|
||||
# All raw constraints satisfied
|
||||
max_err = _eval_raw_residuals(bodies, constraints, pt)
|
||||
assert max_err < 1e-8
|
||||
|
||||
# arm at origin (Revolute), plate coplanar (z=0)
|
||||
env = pt.get_env()
|
||||
assert abs(env["arm/tx"]) < 1e-8
|
||||
assert abs(env["arm/ty"]) < 1e-8
|
||||
assert abs(env["arm/tz"]) < 1e-8
|
||||
assert abs(env["plate/tz"]) < 1e-8
|
||||
189
tests/test_geometry.py
Normal file
189
tests/test_geometry.py
Normal file
@@ -0,0 +1,189 @@
|
||||
"""Tests for geometry helpers."""
|
||||
|
||||
import math
|
||||
|
||||
import pytest
|
||||
from kindred_solver.entities import RigidBody
|
||||
from kindred_solver.expr import Const, Var
|
||||
from kindred_solver.geometry import (
|
||||
cross3,
|
||||
dot3,
|
||||
marker_x_axis,
|
||||
marker_y_axis,
|
||||
marker_z_axis,
|
||||
point_line_perp_components,
|
||||
point_plane_distance,
|
||||
sub3,
|
||||
)
|
||||
from kindred_solver.params import ParamTable
|
||||
|
||||
IDENTITY_QUAT = (1.0, 0.0, 0.0, 0.0)
|
||||
# 90-deg about Z: (cos45, 0, 0, sin45)
|
||||
_c = math.cos(math.pi / 4)
|
||||
_s = math.sin(math.pi / 4)
|
||||
ROT_90Z_QUAT = (_c, 0.0, 0.0, _s)
|
||||
|
||||
|
||||
class TestDot3:
|
||||
def test_parallel(self):
|
||||
a = (Const(1.0), Const(0.0), Const(0.0))
|
||||
b = (Const(1.0), Const(0.0), Const(0.0))
|
||||
assert abs(dot3(a, b).eval({}) - 1.0) < 1e-10
|
||||
|
||||
def test_perpendicular(self):
|
||||
a = (Const(1.0), Const(0.0), Const(0.0))
|
||||
b = (Const(0.0), Const(1.0), Const(0.0))
|
||||
assert abs(dot3(a, b).eval({})) < 1e-10
|
||||
|
||||
def test_general(self):
|
||||
a = (Const(1.0), Const(2.0), Const(3.0))
|
||||
b = (Const(4.0), Const(5.0), Const(6.0))
|
||||
# 1*4 + 2*5 + 3*6 = 32
|
||||
assert abs(dot3(a, b).eval({}) - 32.0) < 1e-10
|
||||
|
||||
|
||||
class TestCross3:
|
||||
def test_x_cross_y(self):
|
||||
x = (Const(1.0), Const(0.0), Const(0.0))
|
||||
y = (Const(0.0), Const(1.0), Const(0.0))
|
||||
cx, cy, cz = cross3(x, y)
|
||||
assert abs(cx.eval({})) < 1e-10
|
||||
assert abs(cy.eval({})) < 1e-10
|
||||
assert abs(cz.eval({}) - 1.0) < 1e-10
|
||||
|
||||
def test_parallel_is_zero(self):
|
||||
a = (Const(2.0), Const(3.0), Const(4.0))
|
||||
b = (Const(4.0), Const(6.0), Const(8.0))
|
||||
cx, cy, cz = cross3(a, b)
|
||||
assert abs(cx.eval({})) < 1e-10
|
||||
assert abs(cy.eval({})) < 1e-10
|
||||
assert abs(cz.eval({})) < 1e-10
|
||||
|
||||
|
||||
class TestSub3:
|
||||
def test_basic(self):
|
||||
a = (Const(5.0), Const(3.0), Const(1.0))
|
||||
b = (Const(1.0), Const(2.0), Const(3.0))
|
||||
dx, dy, dz = sub3(a, b)
|
||||
assert abs(dx.eval({}) - 4.0) < 1e-10
|
||||
assert abs(dy.eval({}) - 1.0) < 1e-10
|
||||
assert abs(dz.eval({}) - (-2.0)) < 1e-10
|
||||
|
||||
|
||||
class TestMarkerAxes:
|
||||
def test_identity_z(self):
|
||||
"""Identity body + identity marker → Z = (0,0,1)."""
|
||||
pt = ParamTable()
|
||||
body = RigidBody("p", pt, (0, 0, 0), (1, 0, 0, 0))
|
||||
zx, zy, zz = marker_z_axis(body, IDENTITY_QUAT)
|
||||
env = pt.get_env()
|
||||
assert abs(zx.eval(env)) < 1e-10
|
||||
assert abs(zy.eval(env)) < 1e-10
|
||||
assert abs(zz.eval(env) - 1.0) < 1e-10
|
||||
|
||||
def test_identity_x(self):
|
||||
"""Identity body + identity marker → X = (1,0,0)."""
|
||||
pt = ParamTable()
|
||||
body = RigidBody("p", pt, (0, 0, 0), (1, 0, 0, 0))
|
||||
xx, xy, xz = marker_x_axis(body, IDENTITY_QUAT)
|
||||
env = pt.get_env()
|
||||
assert abs(xx.eval(env) - 1.0) < 1e-10
|
||||
assert abs(xy.eval(env)) < 1e-10
|
||||
assert abs(xz.eval(env)) < 1e-10
|
||||
|
||||
def test_identity_y(self):
|
||||
"""Identity body + identity marker → Y = (0,1,0)."""
|
||||
pt = ParamTable()
|
||||
body = RigidBody("p", pt, (0, 0, 0), (1, 0, 0, 0))
|
||||
yx, yy, yz = marker_y_axis(body, IDENTITY_QUAT)
|
||||
env = pt.get_env()
|
||||
assert abs(yx.eval(env)) < 1e-10
|
||||
assert abs(yy.eval(env) - 1.0) < 1e-10
|
||||
assert abs(yz.eval(env)) < 1e-10
|
||||
|
||||
def test_rotated_body_z(self):
|
||||
"""Body rotated 90-deg about Z → Z-axis still (0,0,1)."""
|
||||
pt = ParamTable()
|
||||
body = RigidBody("p", pt, (0, 0, 0), ROT_90Z_QUAT)
|
||||
zx, zy, zz = marker_z_axis(body, IDENTITY_QUAT)
|
||||
env = pt.get_env()
|
||||
assert abs(zx.eval(env)) < 1e-10
|
||||
assert abs(zy.eval(env)) < 1e-10
|
||||
assert abs(zz.eval(env) - 1.0) < 1e-10
|
||||
|
||||
def test_rotated_body_x(self):
|
||||
"""Body rotated 90-deg about Z → X-axis becomes (0,1,0)."""
|
||||
pt = ParamTable()
|
||||
body = RigidBody("p", pt, (0, 0, 0), ROT_90Z_QUAT)
|
||||
xx, xy, xz = marker_x_axis(body, IDENTITY_QUAT)
|
||||
env = pt.get_env()
|
||||
assert abs(xx.eval(env)) < 1e-10
|
||||
assert abs(xy.eval(env) - 1.0) < 1e-10
|
||||
assert abs(xz.eval(env)) < 1e-10
|
||||
|
||||
def test_marker_rotation(self):
|
||||
"""Identity body + marker rotated 90-deg about Z → Z still (0,0,1)."""
|
||||
pt = ParamTable()
|
||||
body = RigidBody("p", pt, (0, 0, 0), (1, 0, 0, 0))
|
||||
zx, zy, zz = marker_z_axis(body, ROT_90Z_QUAT)
|
||||
env = pt.get_env()
|
||||
assert abs(zx.eval(env)) < 1e-10
|
||||
assert abs(zy.eval(env)) < 1e-10
|
||||
assert abs(zz.eval(env) - 1.0) < 1e-10
|
||||
|
||||
def test_marker_rotation_x_axis(self):
|
||||
"""Identity body + marker rotated 90-deg about Z → X becomes (0,1,0)."""
|
||||
pt = ParamTable()
|
||||
body = RigidBody("p", pt, (0, 0, 0), (1, 0, 0, 0))
|
||||
xx, xy, xz = marker_x_axis(body, ROT_90Z_QUAT)
|
||||
env = pt.get_env()
|
||||
assert abs(xx.eval(env)) < 1e-10
|
||||
assert abs(xy.eval(env) - 1.0) < 1e-10
|
||||
assert abs(xz.eval(env)) < 1e-10
|
||||
|
||||
def test_differentiable(self):
|
||||
"""Marker axes are differentiable w.r.t. body quat params."""
|
||||
pt = ParamTable()
|
||||
body = RigidBody("p", pt, (0, 0, 0), (1, 0, 0, 0))
|
||||
zx, zy, zz = marker_z_axis(body, IDENTITY_QUAT)
|
||||
# Should not raise
|
||||
dzx = zx.diff("p/qz").simplify()
|
||||
env = pt.get_env()
|
||||
dzx.eval(env) # Should be evaluable
|
||||
|
||||
|
||||
class TestPointPlaneDistance:
|
||||
def test_on_plane(self):
|
||||
pt = (Const(1.0), Const(2.0), Const(0.0))
|
||||
origin = (Const(0.0), Const(0.0), Const(0.0))
|
||||
normal = (Const(0.0), Const(0.0), Const(1.0))
|
||||
d = point_plane_distance(pt, origin, normal)
|
||||
assert abs(d.eval({})) < 1e-10
|
||||
|
||||
def test_above_plane(self):
|
||||
pt = (Const(1.0), Const(2.0), Const(5.0))
|
||||
origin = (Const(0.0), Const(0.0), Const(0.0))
|
||||
normal = (Const(0.0), Const(0.0), Const(1.0))
|
||||
d = point_plane_distance(pt, origin, normal)
|
||||
assert abs(d.eval({}) - 5.0) < 1e-10
|
||||
|
||||
|
||||
class TestPointLinePerp:
|
||||
def test_on_line(self):
|
||||
pt = (Const(0.0), Const(0.0), Const(5.0))
|
||||
origin = (Const(0.0), Const(0.0), Const(0.0))
|
||||
direction = (Const(0.0), Const(0.0), Const(1.0))
|
||||
cx, cy, cz = point_line_perp_components(pt, origin, direction)
|
||||
assert abs(cx.eval({})) < 1e-10
|
||||
assert abs(cy.eval({})) < 1e-10
|
||||
assert abs(cz.eval({})) < 1e-10
|
||||
|
||||
def test_off_line(self):
|
||||
pt = (Const(3.0), Const(0.0), Const(0.0))
|
||||
origin = (Const(0.0), Const(0.0), Const(0.0))
|
||||
direction = (Const(0.0), Const(0.0), Const(1.0))
|
||||
cx, cy, cz = point_line_perp_components(pt, origin, direction)
|
||||
# d = (3,0,0), dir = (0,0,1), d x dir = (0*1-0*0, 0*0-3*1, 3*0-0*0) = (0,-3,0)
|
||||
assert abs(cx.eval({})) < 1e-10
|
||||
assert abs(cy.eval({}) - (-3.0)) < 1e-10
|
||||
assert abs(cz.eval({})) < 1e-10
|
||||
614
tests/test_joints.py
Normal file
614
tests/test_joints.py
Normal file
@@ -0,0 +1,614 @@
|
||||
"""Integration tests for kinematic joint constraints.
|
||||
|
||||
These tests exercise the full solve pipeline (constraint → residuals →
|
||||
pre-pass → Newton / BFGS) for multi-body systems with various joint types.
|
||||
"""
|
||||
|
||||
import math
|
||||
|
||||
import pytest
|
||||
from kindred_solver.constraints import (
|
||||
BallConstraint,
|
||||
CoincidentConstraint,
|
||||
CylindricalConstraint,
|
||||
GearConstraint,
|
||||
ParallelConstraint,
|
||||
PerpendicularConstraint,
|
||||
PlanarConstraint,
|
||||
PointInPlaneConstraint,
|
||||
PointOnLineConstraint,
|
||||
RackPinionConstraint,
|
||||
RevoluteConstraint,
|
||||
ScrewConstraint,
|
||||
SliderConstraint,
|
||||
UniversalConstraint,
|
||||
)
|
||||
from kindred_solver.dof import count_dof
|
||||
from kindred_solver.entities import RigidBody
|
||||
from kindred_solver.newton import newton_solve
|
||||
from kindred_solver.params import ParamTable
|
||||
from kindred_solver.prepass import single_equation_pass, substitution_pass
|
||||
|
||||
ID_QUAT = (1, 0, 0, 0)
|
||||
# 90° about Z: (cos(45°), 0, 0, sin(45°))
|
||||
c45 = math.cos(math.pi / 4)
|
||||
s45 = math.sin(math.pi / 4)
|
||||
ROT_90Z = (c45, 0, 0, s45)
|
||||
# 90° about Y
|
||||
ROT_90Y = (c45, 0, s45, 0)
|
||||
# 90° about X
|
||||
ROT_90X = (c45, s45, 0, 0)
|
||||
|
||||
|
||||
def _solve(bodies, constraint_objs):
|
||||
"""Run the full solve pipeline. Returns (converged, params, bodies)."""
|
||||
pt = bodies[0].tx # all bodies share the same ParamTable via Var._name
|
||||
# Actually, we need the ParamTable object. Get it from the first body.
|
||||
# The Var objects store names, but we need the table. We'll reconstruct.
|
||||
# Better approach: caller passes pt.
|
||||
|
||||
raise NotImplementedError("Use _solve_with_pt instead")
|
||||
|
||||
|
||||
def _solve_with_pt(pt, bodies, constraint_objs):
|
||||
"""Run the full solve pipeline with explicit ParamTable."""
|
||||
all_residuals = []
|
||||
for c in constraint_objs:
|
||||
all_residuals.extend(c.residuals())
|
||||
|
||||
quat_groups = []
|
||||
for body in bodies:
|
||||
if not body.grounded:
|
||||
all_residuals.append(body.quat_norm_residual())
|
||||
quat_groups.append(body.quat_param_names())
|
||||
|
||||
all_residuals = substitution_pass(all_residuals, pt)
|
||||
all_residuals = single_equation_pass(all_residuals, pt)
|
||||
|
||||
converged = newton_solve(
|
||||
all_residuals, pt, quat_groups=quat_groups, max_iter=100, tol=1e-10
|
||||
)
|
||||
return converged, all_residuals
|
||||
|
||||
|
||||
def _dof(pt, bodies, constraint_objs):
|
||||
"""Count DOF for a system."""
|
||||
all_residuals = []
|
||||
for c in constraint_objs:
|
||||
all_residuals.extend(c.residuals())
|
||||
for body in bodies:
|
||||
if not body.grounded:
|
||||
all_residuals.append(body.quat_norm_residual())
|
||||
all_residuals = substitution_pass(all_residuals, pt)
|
||||
return count_dof(all_residuals, pt)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Single-joint DOF counting tests
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestJointDOF:
|
||||
"""Verify each joint type removes the expected number of DOF.
|
||||
|
||||
Setup: ground body + 1 free body (6 DOF) with a single joint.
|
||||
"""
|
||||
|
||||
def _setup(self, pos_b=(0, 0, 0), quat_b=ID_QUAT):
|
||||
pt = ParamTable()
|
||||
a = RigidBody("a", pt, (0, 0, 0), ID_QUAT, grounded=True)
|
||||
b = RigidBody("b", pt, pos_b, quat_b)
|
||||
return pt, a, b
|
||||
|
||||
def test_ball_3dof(self):
|
||||
"""Ball joint: 6 - 3 = 3 DOF (3 rotation)."""
|
||||
pt, a, b = self._setup()
|
||||
constraints = [BallConstraint(a, (0, 0, 0), b, (0, 0, 0))]
|
||||
assert _dof(pt, [a, b], constraints) == 3
|
||||
|
||||
def test_revolute_1dof(self):
|
||||
"""Revolute: 6 - 5 = 1 DOF (rotation about Z)."""
|
||||
pt, a, b = self._setup()
|
||||
constraints = [RevoluteConstraint(a, (0, 0, 0), ID_QUAT, b, (0, 0, 0), ID_QUAT)]
|
||||
assert _dof(pt, [a, b], constraints) == 1
|
||||
|
||||
def test_cylindrical_2dof(self):
|
||||
"""Cylindrical: 6 - 4 = 2 DOF (rotation + translation along Z)."""
|
||||
pt, a, b = self._setup()
|
||||
constraints = [
|
||||
CylindricalConstraint(a, (0, 0, 0), ID_QUAT, b, (0, 0, 0), ID_QUAT)
|
||||
]
|
||||
assert _dof(pt, [a, b], constraints) == 2
|
||||
|
||||
def test_slider_1dof(self):
|
||||
"""Slider: 6 - 5 = 1 DOF (translation along Z)."""
|
||||
pt, a, b = self._setup()
|
||||
constraints = [SliderConstraint(a, (0, 0, 0), ID_QUAT, b, (0, 0, 0), ID_QUAT)]
|
||||
assert _dof(pt, [a, b], constraints) == 1
|
||||
|
||||
def test_universal_2dof(self):
|
||||
"""Universal: 6 - 4 = 2 DOF (rotation about each body's Z)."""
|
||||
pt, a, b = self._setup(quat_b=ROT_90X)
|
||||
constraints = [
|
||||
UniversalConstraint(a, (0, 0, 0), ID_QUAT, b, (0, 0, 0), ID_QUAT)
|
||||
]
|
||||
assert _dof(pt, [a, b], constraints) == 2
|
||||
|
||||
def test_screw_1dof(self):
|
||||
"""Screw: 6 - 5 = 1 DOF (helical motion)."""
|
||||
pt, a, b = self._setup()
|
||||
constraints = [
|
||||
ScrewConstraint(a, (0, 0, 0), ID_QUAT, b, (0, 0, 0), ID_QUAT, pitch=10.0)
|
||||
]
|
||||
assert _dof(pt, [a, b], constraints) == 1
|
||||
|
||||
def test_parallel_4dof(self):
|
||||
"""Parallel: 6 - 2 = 4 DOF."""
|
||||
pt, a, b = self._setup()
|
||||
constraints = [ParallelConstraint(a, ID_QUAT, b, ID_QUAT)]
|
||||
assert _dof(pt, [a, b], constraints) == 4
|
||||
|
||||
def test_perpendicular_5dof(self):
|
||||
"""Perpendicular: 6 - 1 = 5 DOF."""
|
||||
pt, a, b = self._setup(quat_b=ROT_90X)
|
||||
constraints = [PerpendicularConstraint(a, ID_QUAT, b, ID_QUAT)]
|
||||
assert _dof(pt, [a, b], constraints) == 5
|
||||
|
||||
def test_point_on_line_4dof(self):
|
||||
"""PointOnLine: 6 - 2 = 4 DOF."""
|
||||
pt, a, b = self._setup()
|
||||
constraints = [
|
||||
PointOnLineConstraint(a, (0, 0, 0), ID_QUAT, b, (0, 0, 0), ID_QUAT)
|
||||
]
|
||||
assert _dof(pt, [a, b], constraints) == 4
|
||||
|
||||
def test_point_in_plane_5dof(self):
|
||||
"""PointInPlane: 6 - 1 = 5 DOF."""
|
||||
pt, a, b = self._setup()
|
||||
constraints = [
|
||||
PointInPlaneConstraint(a, (0, 0, 0), ID_QUAT, b, (0, 0, 0), ID_QUAT)
|
||||
]
|
||||
assert _dof(pt, [a, b], constraints) == 5
|
||||
|
||||
def test_planar_3dof(self):
|
||||
"""Planar: 6 - 3 = 3 DOF (2 translation in plane + 1 rotation about normal)."""
|
||||
pt, a, b = self._setup()
|
||||
constraints = [PlanarConstraint(a, (0, 0, 0), ID_QUAT, b, (0, 0, 0), ID_QUAT)]
|
||||
assert _dof(pt, [a, b], constraints) == 3
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Solve convergence tests — single joints from displaced initial conditions
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestJointSolve:
|
||||
"""Newton converges to a valid configuration from displaced starting points."""
|
||||
|
||||
def test_revolute_displaced(self):
|
||||
"""Revolute joint: body B starts displaced, should converge to hinge position."""
|
||||
pt = ParamTable()
|
||||
a = RigidBody("a", pt, (0, 0, 0), ID_QUAT, grounded=True)
|
||||
b = RigidBody("b", pt, (3, 4, 5), ID_QUAT) # displaced
|
||||
|
||||
constraints = [RevoluteConstraint(a, (0, 0, 0), ID_QUAT, b, (0, 0, 0), ID_QUAT)]
|
||||
converged, _ = _solve_with_pt(pt, [a, b], constraints)
|
||||
assert converged
|
||||
|
||||
env = pt.get_env()
|
||||
pos = b.extract_position(env)
|
||||
# Coincident origins → position should be at origin
|
||||
assert abs(pos[0]) < 1e-8
|
||||
assert abs(pos[1]) < 1e-8
|
||||
assert abs(pos[2]) < 1e-8
|
||||
|
||||
def test_cylindrical_displaced(self):
|
||||
"""Cylindrical joint: body B can slide along Z but must be on axis."""
|
||||
pt = ParamTable()
|
||||
a = RigidBody("a", pt, (0, 0, 0), ID_QUAT, grounded=True)
|
||||
b = RigidBody("b", pt, (3, 4, 7), ID_QUAT) # off-axis
|
||||
|
||||
constraints = [
|
||||
CylindricalConstraint(a, (0, 0, 0), ID_QUAT, b, (0, 0, 0), ID_QUAT)
|
||||
]
|
||||
converged, _ = _solve_with_pt(pt, [a, b], constraints)
|
||||
assert converged
|
||||
|
||||
env = pt.get_env()
|
||||
pos = b.extract_position(env)
|
||||
# X and Y should be zero (on axis), Z can be anything
|
||||
assert abs(pos[0]) < 1e-8
|
||||
assert abs(pos[1]) < 1e-8
|
||||
|
||||
def test_slider_displaced(self):
|
||||
"""Slider: body B can translate along Z only."""
|
||||
pt = ParamTable()
|
||||
a = RigidBody("a", pt, (0, 0, 0), ID_QUAT, grounded=True)
|
||||
b = RigidBody("b", pt, (2, 3, 5), ID_QUAT) # displaced
|
||||
|
||||
constraints = [SliderConstraint(a, (0, 0, 0), ID_QUAT, b, (0, 0, 0), ID_QUAT)]
|
||||
converged, _ = _solve_with_pt(pt, [a, b], constraints)
|
||||
assert converged
|
||||
|
||||
env = pt.get_env()
|
||||
pos = b.extract_position(env)
|
||||
# X and Y should be zero (on axis), Z free
|
||||
assert abs(pos[0]) < 1e-8
|
||||
assert abs(pos[1]) < 1e-8
|
||||
|
||||
def test_ball_displaced(self):
|
||||
"""Ball joint: body B moves so marker origins coincide.
|
||||
|
||||
Ball has 3 rotation DOF free, so we can only verify the
|
||||
world-frame marker points match, not the body position directly.
|
||||
"""
|
||||
pt = ParamTable()
|
||||
a = RigidBody("a", pt, (0, 0, 0), ID_QUAT, grounded=True)
|
||||
b = RigidBody("b", pt, (5, 5, 5), ID_QUAT)
|
||||
|
||||
constraints = [BallConstraint(a, (1, 0, 0), b, (-1, 0, 0))]
|
||||
converged, _ = _solve_with_pt(pt, [a, b], constraints)
|
||||
assert converged
|
||||
|
||||
env = pt.get_env()
|
||||
# Verify marker world points match
|
||||
wp_a = a.world_point(1, 0, 0)
|
||||
wp_b = b.world_point(-1, 0, 0)
|
||||
for ea, eb in zip(wp_a, wp_b):
|
||||
assert abs(ea.eval(env) - eb.eval(env)) < 1e-8
|
||||
|
||||
def test_universal_displaced(self):
|
||||
"""Universal joint: coincident origins + perpendicular Z-axes."""
|
||||
pt = ParamTable()
|
||||
a = RigidBody("a", pt, (0, 0, 0), ID_QUAT, grounded=True)
|
||||
# Start B with Z-axis along X (90° about Y) — perpendicular to A's Z
|
||||
b = RigidBody("b", pt, (3, 4, 5), ROT_90Y)
|
||||
|
||||
constraints = [
|
||||
UniversalConstraint(a, (0, 0, 0), ID_QUAT, b, (0, 0, 0), ID_QUAT)
|
||||
]
|
||||
converged, _ = _solve_with_pt(pt, [a, b], constraints)
|
||||
assert converged
|
||||
|
||||
env = pt.get_env()
|
||||
pos = b.extract_position(env)
|
||||
assert abs(pos[0]) < 1e-8
|
||||
assert abs(pos[1]) < 1e-8
|
||||
assert abs(pos[2]) < 1e-8
|
||||
|
||||
def test_point_on_line_solve(self):
|
||||
"""Point on line: body B's marker origin constrained to line along Z.
|
||||
|
||||
Under-constrained system (4 DOF remain), so we verify the constraint
|
||||
residuals are satisfied rather than expecting specific positions.
|
||||
"""
|
||||
pt = ParamTable()
|
||||
a = RigidBody("a", pt, (0, 0, 0), ID_QUAT, grounded=True)
|
||||
b = RigidBody("b", pt, (5, 3, 7), ID_QUAT)
|
||||
|
||||
constraints = [
|
||||
PointOnLineConstraint(a, (0, 0, 0), ID_QUAT, b, (0, 0, 0), ID_QUAT)
|
||||
]
|
||||
converged, residuals = _solve_with_pt(pt, [a, b], constraints)
|
||||
assert converged
|
||||
|
||||
env = pt.get_env()
|
||||
for r in residuals:
|
||||
assert abs(r.eval(env)) < 1e-8
|
||||
|
||||
def test_point_in_plane_solve(self):
|
||||
"""Point in plane: body B's marker origin at z=0 plane.
|
||||
|
||||
Under-constrained (5 DOF remain), so verify residuals.
|
||||
"""
|
||||
pt = ParamTable()
|
||||
a = RigidBody("a", pt, (0, 0, 0), ID_QUAT, grounded=True)
|
||||
b = RigidBody("b", pt, (3, 4, 8), ID_QUAT)
|
||||
|
||||
constraints = [
|
||||
PointInPlaneConstraint(a, (0, 0, 0), ID_QUAT, b, (0, 0, 0), ID_QUAT)
|
||||
]
|
||||
converged, residuals = _solve_with_pt(pt, [a, b], constraints)
|
||||
assert converged
|
||||
|
||||
env = pt.get_env()
|
||||
for r in residuals:
|
||||
assert abs(r.eval(env)) < 1e-8
|
||||
|
||||
def test_planar_solve(self):
|
||||
"""Planar: coplanar faces — parallel normals + point in plane."""
|
||||
pt = ParamTable()
|
||||
a = RigidBody("a", pt, (0, 0, 0), ID_QUAT, grounded=True)
|
||||
# Start B tilted and displaced
|
||||
b = RigidBody("b", pt, (3, 4, 8), ID_QUAT)
|
||||
|
||||
constraints = [PlanarConstraint(a, (0, 0, 0), ID_QUAT, b, (0, 0, 0), ID_QUAT)]
|
||||
converged, _ = _solve_with_pt(pt, [a, b], constraints)
|
||||
assert converged
|
||||
|
||||
env = pt.get_env()
|
||||
pos = b.extract_position(env)
|
||||
# Z must be zero (in plane), X and Y free
|
||||
assert abs(pos[2]) < 1e-8
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Multi-body integration tests
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestFourBarLinkage:
|
||||
"""Four-bar linkage: 4 bodies, 4 revolute joints.
|
||||
|
||||
In 3D with Z-axis revolutes, this yields 2 DOF: the expected planar
|
||||
motion plus an out-of-plane fold. A truly planar mechanism would
|
||||
add Planar constraints on each link to eliminate the fold DOF.
|
||||
"""
|
||||
|
||||
def test_four_bar_dof(self):
|
||||
"""Four-bar linkage in 3D has 2 DOF (planar + fold)."""
|
||||
pt = ParamTable()
|
||||
ground = RigidBody("g", pt, (0, 0, 0), ID_QUAT, grounded=True)
|
||||
link1 = RigidBody("l1", pt, (2, 0, 0), ID_QUAT)
|
||||
link2 = RigidBody("l2", pt, (5, 3, 0), ID_QUAT)
|
||||
link3 = RigidBody("l3", pt, (8, 0, 0), ID_QUAT)
|
||||
|
||||
constraints = [
|
||||
RevoluteConstraint(ground, (0, 0, 0), ID_QUAT, link1, (0, 0, 0), ID_QUAT),
|
||||
RevoluteConstraint(link1, (4, 0, 0), ID_QUAT, link2, (0, 0, 0), ID_QUAT),
|
||||
RevoluteConstraint(link2, (6, 0, 0), ID_QUAT, link3, (0, 0, 0), ID_QUAT),
|
||||
RevoluteConstraint(link3, (4, 0, 0), ID_QUAT, ground, (10, 0, 0), ID_QUAT),
|
||||
]
|
||||
|
||||
bodies = [ground, link1, link2, link3]
|
||||
dof = _dof(pt, bodies, constraints)
|
||||
assert dof == 2
|
||||
|
||||
def test_four_bar_solves(self):
|
||||
"""Four-bar linkage converges from displaced initial conditions."""
|
||||
pt = ParamTable()
|
||||
ground = RigidBody("g", pt, (0, 0, 0), ID_QUAT, grounded=True)
|
||||
# Initial positions slightly displaced from valid config
|
||||
link1 = RigidBody("l1", pt, (2, 1, 0), ID_QUAT)
|
||||
link2 = RigidBody("l2", pt, (5, 4, 0), ID_QUAT)
|
||||
link3 = RigidBody("l3", pt, (8, 1, 0), ID_QUAT)
|
||||
|
||||
constraints = [
|
||||
RevoluteConstraint(ground, (0, 0, 0), ID_QUAT, link1, (0, 0, 0), ID_QUAT),
|
||||
RevoluteConstraint(link1, (4, 0, 0), ID_QUAT, link2, (0, 0, 0), ID_QUAT),
|
||||
RevoluteConstraint(link2, (6, 0, 0), ID_QUAT, link3, (0, 0, 0), ID_QUAT),
|
||||
RevoluteConstraint(link3, (4, 0, 0), ID_QUAT, ground, (10, 0, 0), ID_QUAT),
|
||||
]
|
||||
|
||||
bodies = [ground, link1, link2, link3]
|
||||
converged, residuals = _solve_with_pt(pt, bodies, constraints)
|
||||
assert converged
|
||||
|
||||
# Verify all revolute constraints are satisfied
|
||||
env = pt.get_env()
|
||||
for r in residuals:
|
||||
assert abs(r.eval(env)) < 1e-8
|
||||
|
||||
|
||||
class TestSliderCrank:
|
||||
"""Slider-crank mechanism: crank + connecting rod + piston.
|
||||
|
||||
ground --[Revolute]-- crank --[Revolute]-- rod --[Revolute]-- piston --[Slider]-- ground
|
||||
|
||||
Using Slider (not Cylindrical) for the piston to also lock rotation,
|
||||
making it a true prismatic joint. In 3D, out-of-plane folding adds
|
||||
extra DOF beyond the planar 1-DOF.
|
||||
|
||||
3 free bodies × 6 = 18 DOF
|
||||
Revolute(5) + Revolute(5) + Revolute(5) + Slider(5) = 20
|
||||
But many constraints share bodies, so effective rank < 20.
|
||||
In 3D: 3 DOF (planar crank + 2 fold modes).
|
||||
"""
|
||||
|
||||
def test_slider_crank_dof(self):
|
||||
pt = ParamTable()
|
||||
ground = RigidBody("g", pt, (0, 0, 0), ID_QUAT, grounded=True)
|
||||
crank = RigidBody("crank", pt, (1, 0, 0), ID_QUAT)
|
||||
rod = RigidBody("rod", pt, (3, 0, 0), ID_QUAT)
|
||||
piston = RigidBody("piston", pt, (5, 0, 0), ID_QUAT)
|
||||
|
||||
constraints = [
|
||||
RevoluteConstraint(ground, (0, 0, 0), ID_QUAT, crank, (0, 0, 0), ID_QUAT),
|
||||
RevoluteConstraint(crank, (2, 0, 0), ID_QUAT, rod, (0, 0, 0), ID_QUAT),
|
||||
RevoluteConstraint(rod, (4, 0, 0), ID_QUAT, piston, (0, 0, 0), ID_QUAT),
|
||||
SliderConstraint(piston, (0, 0, 0), ROT_90Y, ground, (0, 0, 0), ROT_90Y),
|
||||
]
|
||||
|
||||
bodies = [ground, crank, rod, piston]
|
||||
dof = _dof(pt, bodies, constraints)
|
||||
# With full 3-component cross products, the redundant constraint rows
|
||||
# eliminate the out-of-plane fold modes, giving the correct 1 DOF
|
||||
# (crank rotation only).
|
||||
assert dof == 1
|
||||
|
||||
def test_slider_crank_solves(self):
|
||||
"""Slider-crank converges from displaced state."""
|
||||
pt = ParamTable()
|
||||
ground = RigidBody("g", pt, (0, 0, 0), ID_QUAT, grounded=True)
|
||||
crank = RigidBody("crank", pt, (1, 0.5, 0), ID_QUAT)
|
||||
rod = RigidBody("rod", pt, (3, 1, 0), ID_QUAT)
|
||||
piston = RigidBody("piston", pt, (5, 0.5, 0), ID_QUAT)
|
||||
|
||||
constraints = [
|
||||
RevoluteConstraint(ground, (0, 0, 0), ID_QUAT, crank, (0, 0, 0), ID_QUAT),
|
||||
RevoluteConstraint(crank, (2, 0, 0), ID_QUAT, rod, (0, 0, 0), ID_QUAT),
|
||||
RevoluteConstraint(rod, (4, 0, 0), ID_QUAT, piston, (0, 0, 0), ID_QUAT),
|
||||
SliderConstraint(piston, (0, 0, 0), ROT_90Y, ground, (0, 0, 0), ROT_90Y),
|
||||
]
|
||||
|
||||
bodies = [ground, crank, rod, piston]
|
||||
converged, residuals = _solve_with_pt(pt, bodies, constraints)
|
||||
assert converged
|
||||
|
||||
env = pt.get_env()
|
||||
for r in residuals:
|
||||
assert abs(r.eval(env)) < 1e-8
|
||||
|
||||
|
||||
class TestRevoluteChain:
|
||||
"""Chain of revolute joints: ground → body1 → body2.
|
||||
|
||||
Each revolute removes 5 DOF. Two free bodies = 12 DOF.
|
||||
2 revolutes = 10 constraints + 2 quat norms = 12.
|
||||
Expected: 2 DOF (one rotation per hinge).
|
||||
"""
|
||||
|
||||
def test_chain_dof(self):
|
||||
pt = ParamTable()
|
||||
ground = RigidBody("g", pt, (0, 0, 0), ID_QUAT, grounded=True)
|
||||
b1 = RigidBody("b1", pt, (3, 0, 0), ID_QUAT)
|
||||
b2 = RigidBody("b2", pt, (6, 0, 0), ID_QUAT)
|
||||
|
||||
constraints = [
|
||||
RevoluteConstraint(ground, (0, 0, 0), ID_QUAT, b1, (0, 0, 0), ID_QUAT),
|
||||
RevoluteConstraint(b1, (3, 0, 0), ID_QUAT, b2, (0, 0, 0), ID_QUAT),
|
||||
]
|
||||
|
||||
assert _dof(pt, [ground, b1, b2], constraints) == 2
|
||||
|
||||
def test_chain_solves(self):
|
||||
pt = ParamTable()
|
||||
ground = RigidBody("g", pt, (0, 0, 0), ID_QUAT, grounded=True)
|
||||
b1 = RigidBody("b1", pt, (3, 2, 0), ID_QUAT)
|
||||
b2 = RigidBody("b2", pt, (6, 3, 0), ID_QUAT)
|
||||
|
||||
constraints = [
|
||||
RevoluteConstraint(ground, (0, 0, 0), ID_QUAT, b1, (0, 0, 0), ID_QUAT),
|
||||
RevoluteConstraint(b1, (3, 0, 0), ID_QUAT, b2, (0, 0, 0), ID_QUAT),
|
||||
]
|
||||
|
||||
converged, residuals = _solve_with_pt(pt, [ground, b1, b2], constraints)
|
||||
assert converged
|
||||
|
||||
env = pt.get_env()
|
||||
# b1 origin at ground hinge point (0,0,0)
|
||||
pos1 = b1.extract_position(env)
|
||||
assert abs(pos1[0]) < 1e-8
|
||||
assert abs(pos1[1]) < 1e-8
|
||||
assert abs(pos1[2]) < 1e-8
|
||||
|
||||
|
||||
class TestSliderOnRail:
|
||||
"""Slider constraint: body translates along ground Z-axis only.
|
||||
|
||||
1 free body, 1 slider = 6 - 5 = 1 DOF.
|
||||
"""
|
||||
|
||||
def test_slider_on_rail(self):
|
||||
pt = ParamTable()
|
||||
ground = RigidBody("g", pt, (0, 0, 0), ID_QUAT, grounded=True)
|
||||
block = RigidBody("block", pt, (3, 4, 5), ID_QUAT)
|
||||
|
||||
constraints = [
|
||||
SliderConstraint(ground, (0, 0, 0), ID_QUAT, block, (0, 0, 0), ID_QUAT)
|
||||
]
|
||||
|
||||
converged, _ = _solve_with_pt(pt, [ground, block], constraints)
|
||||
assert converged
|
||||
|
||||
env = pt.get_env()
|
||||
pos = block.extract_position(env)
|
||||
# X, Y must be zero; Z is free
|
||||
assert abs(pos[0]) < 1e-8
|
||||
assert abs(pos[1]) < 1e-8
|
||||
# Z should remain near initial value (minimum-norm solution)
|
||||
|
||||
# Check orientation unchanged (no twist)
|
||||
quat = block.extract_quaternion(env)
|
||||
assert abs(quat[0] - 1.0) < 1e-6
|
||||
assert abs(quat[1]) < 1e-6
|
||||
assert abs(quat[2]) < 1e-6
|
||||
assert abs(quat[3]) < 1e-6
|
||||
|
||||
|
||||
class TestPlanarOnTable:
|
||||
"""Planar constraint: body slides on XY plane.
|
||||
|
||||
1 free body, 1 planar = 6 - 3 = 3 DOF.
|
||||
"""
|
||||
|
||||
def test_planar_on_table(self):
|
||||
pt = ParamTable()
|
||||
ground = RigidBody("g", pt, (0, 0, 0), ID_QUAT, grounded=True)
|
||||
block = RigidBody("block", pt, (3, 4, 5), ID_QUAT)
|
||||
|
||||
constraints = [
|
||||
PlanarConstraint(ground, (0, 0, 0), ID_QUAT, block, (0, 0, 0), ID_QUAT)
|
||||
]
|
||||
|
||||
converged, _ = _solve_with_pt(pt, [ground, block], constraints)
|
||||
assert converged
|
||||
|
||||
env = pt.get_env()
|
||||
pos = block.extract_position(env)
|
||||
# Z must be zero, X and Y are free
|
||||
assert abs(pos[2]) < 1e-8
|
||||
|
||||
|
||||
class TestPlanarWithOffset:
|
||||
"""Planar with offset: body floats at z=3 above ground."""
|
||||
|
||||
def test_planar_offset(self):
|
||||
pt = ParamTable()
|
||||
ground = RigidBody("g", pt, (0, 0, 0), ID_QUAT, grounded=True)
|
||||
block = RigidBody("block", pt, (1, 2, 5), ID_QUAT)
|
||||
|
||||
# PlanarConstraint residual: (p_i - p_j) . z_j - offset = 0
|
||||
# body_i=block, body_j=ground: (block_z - 0) * 1 - offset = 0
|
||||
# For block at z=3: offset = 3
|
||||
constraints = [
|
||||
PlanarConstraint(
|
||||
block, (0, 0, 0), ID_QUAT, ground, (0, 0, 0), ID_QUAT, offset=3.0
|
||||
)
|
||||
]
|
||||
|
||||
converged, _ = _solve_with_pt(pt, [ground, block], constraints)
|
||||
assert converged
|
||||
|
||||
env = pt.get_env()
|
||||
pos = block.extract_position(env)
|
||||
assert abs(pos[2] - 3.0) < 1e-8
|
||||
|
||||
|
||||
class TestMixedConstraints:
|
||||
"""System with mixed constraint types."""
|
||||
|
||||
def test_revolute_plus_parallel(self):
|
||||
"""Two free bodies: revolute between ground and b1, parallel between b1 and b2.
|
||||
|
||||
b1: 6 DOF - 5 (revolute) = 1 DOF
|
||||
b2: 6 DOF - 2 (parallel) = 4 DOF
|
||||
Total: 5 DOF
|
||||
"""
|
||||
pt = ParamTable()
|
||||
ground = RigidBody("g", pt, (0, 0, 0), ID_QUAT, grounded=True)
|
||||
b1 = RigidBody("b1", pt, (0, 0, 0), ID_QUAT)
|
||||
b2 = RigidBody("b2", pt, (5, 0, 0), ID_QUAT)
|
||||
|
||||
constraints = [
|
||||
RevoluteConstraint(ground, (0, 0, 0), ID_QUAT, b1, (0, 0, 0), ID_QUAT),
|
||||
ParallelConstraint(b1, ID_QUAT, b2, ID_QUAT),
|
||||
]
|
||||
|
||||
assert _dof(pt, [ground, b1, b2], constraints) == 5
|
||||
|
||||
def test_coincident_plus_perpendicular(self):
|
||||
"""Coincident + perpendicular = ball + 1 angle constraint.
|
||||
|
||||
6 - 3 (coincident) - 1 (perpendicular) = 2 DOF.
|
||||
"""
|
||||
pt = ParamTable()
|
||||
ground = RigidBody("g", pt, (0, 0, 0), ID_QUAT, grounded=True)
|
||||
b = RigidBody("b", pt, (0, 0, 0), ROT_90X)
|
||||
|
||||
constraints = [
|
||||
CoincidentConstraint(ground, (0, 0, 0), b, (0, 0, 0)),
|
||||
PerpendicularConstraint(ground, ID_QUAT, b, ID_QUAT),
|
||||
]
|
||||
|
||||
assert _dof(pt, [ground, b], constraints) == 2
|
||||
@@ -65,3 +65,84 @@ class TestParamTable:
|
||||
pt.add("b", 0.0, fixed=True)
|
||||
pt.add("c", 0.0)
|
||||
assert pt.n_free() == 2
|
||||
|
||||
def test_unfix(self):
|
||||
pt = ParamTable()
|
||||
pt.add("a", 1.0)
|
||||
pt.add("b", 2.0)
|
||||
pt.fix("a")
|
||||
assert pt.is_fixed("a")
|
||||
assert "a" not in pt.free_names()
|
||||
|
||||
pt.unfix("a")
|
||||
assert not pt.is_fixed("a")
|
||||
assert "a" in pt.free_names()
|
||||
assert pt.n_free() == 2
|
||||
|
||||
def test_fix_unfix_roundtrip(self):
|
||||
"""Fix then unfix preserves value and makes param free again."""
|
||||
pt = ParamTable()
|
||||
pt.add("x", 5.0)
|
||||
pt.add("y", 3.0)
|
||||
pt.fix("x")
|
||||
pt.set_value("x", 10.0)
|
||||
pt.unfix("x")
|
||||
assert pt.get_value("x") == 10.0
|
||||
assert "x" in pt.free_names()
|
||||
# x moves to end of free list
|
||||
assert pt.free_names() == ["y", "x"]
|
||||
|
||||
def test_unfix_noop_if_already_free(self):
|
||||
"""Unfixing a free parameter is a no-op."""
|
||||
pt = ParamTable()
|
||||
pt.add("a", 1.0)
|
||||
pt.unfix("a")
|
||||
assert pt.free_names() == ["a"]
|
||||
assert pt.n_free() == 1
|
||||
|
||||
def test_snapshot_restore_roundtrip(self):
|
||||
"""Snapshot captures values; restore brings them back."""
|
||||
pt = ParamTable()
|
||||
pt.add("x", 1.0)
|
||||
pt.add("y", 2.0)
|
||||
pt.add("z", 3.0, fixed=True)
|
||||
snap = pt.snapshot()
|
||||
pt.set_value("x", 99.0)
|
||||
pt.set_value("y", 88.0)
|
||||
pt.set_value("z", 77.0)
|
||||
pt.restore(snap)
|
||||
assert pt.get_value("x") == 1.0
|
||||
assert pt.get_value("y") == 2.0
|
||||
assert pt.get_value("z") == 3.0
|
||||
|
||||
def test_snapshot_is_independent_copy(self):
|
||||
"""Mutating snapshot dict does not affect the table."""
|
||||
pt = ParamTable()
|
||||
pt.add("a", 5.0)
|
||||
snap = pt.snapshot()
|
||||
snap["a"] = 999.0
|
||||
assert pt.get_value("a") == 5.0
|
||||
|
||||
def test_movement_cost_no_weights(self):
|
||||
"""Movement cost is sum of squared displacements for free params."""
|
||||
pt = ParamTable()
|
||||
pt.add("x", 0.0)
|
||||
pt.add("y", 0.0)
|
||||
pt.add("z", 0.0, fixed=True)
|
||||
snap = pt.snapshot()
|
||||
pt.set_value("x", 3.0)
|
||||
pt.set_value("y", 4.0)
|
||||
pt.set_value("z", 100.0) # fixed — ignored
|
||||
assert pt.movement_cost(snap) == pytest.approx(25.0)
|
||||
|
||||
def test_movement_cost_with_weights(self):
|
||||
"""Weighted movement cost scales each displacement."""
|
||||
pt = ParamTable()
|
||||
pt.add("a", 0.0)
|
||||
pt.add("b", 0.0)
|
||||
snap = pt.snapshot()
|
||||
pt.set_value("a", 1.0)
|
||||
pt.set_value("b", 1.0)
|
||||
weights = {"a": 4.0, "b": 9.0}
|
||||
# cost = 1^2*4 + 1^2*9 = 13
|
||||
assert pt.movement_cost(snap, weights) == pytest.approx(13.0)
|
||||
|
||||
384
tests/test_preference.py
Normal file
384
tests/test_preference.py
Normal file
@@ -0,0 +1,384 @@
|
||||
"""Tests for solution preference: half-space tracking and corrections."""
|
||||
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
from kindred_solver.constraints import (
|
||||
AngleConstraint,
|
||||
DistancePointPointConstraint,
|
||||
ParallelConstraint,
|
||||
PerpendicularConstraint,
|
||||
)
|
||||
from kindred_solver.entities import RigidBody
|
||||
from kindred_solver.newton import newton_solve
|
||||
from kindred_solver.params import ParamTable
|
||||
from kindred_solver.preference import (
|
||||
apply_half_space_correction,
|
||||
compute_half_spaces,
|
||||
)
|
||||
|
||||
|
||||
def _make_two_bodies(
|
||||
params,
|
||||
pos_a=(0, 0, 0),
|
||||
pos_b=(5, 0, 0),
|
||||
quat_a=(1, 0, 0, 0),
|
||||
quat_b=(1, 0, 0, 0),
|
||||
ground_a=True,
|
||||
ground_b=False,
|
||||
):
|
||||
"""Create two bodies with given positions/orientations."""
|
||||
body_a = RigidBody(
|
||||
"a", params, position=pos_a, quaternion=quat_a, grounded=ground_a
|
||||
)
|
||||
body_b = RigidBody(
|
||||
"b", params, position=pos_b, quaternion=quat_b, grounded=ground_b
|
||||
)
|
||||
return body_a, body_b
|
||||
|
||||
|
||||
class TestDistanceHalfSpace:
|
||||
"""Half-space tracking for DistancePointPoint constraint."""
|
||||
|
||||
def test_positive_x_stays_positive(self):
|
||||
"""Body starting at +X should stay at +X after solve."""
|
||||
params = ParamTable()
|
||||
body_a, body_b = _make_two_bodies(params, pos_b=(3, 0, 0))
|
||||
c = DistancePointPointConstraint(
|
||||
body_a,
|
||||
(0, 0, 0),
|
||||
body_b,
|
||||
(0, 0, 0),
|
||||
distance=5.0,
|
||||
)
|
||||
hs = compute_half_spaces([c], [0], params)
|
||||
assert len(hs) == 1
|
||||
|
||||
# Solve with half-space correction
|
||||
residuals = c.residuals()
|
||||
residuals.append(body_b.quat_norm_residual())
|
||||
quat_groups = [body_b.quat_param_names()]
|
||||
|
||||
def post_step(p):
|
||||
apply_half_space_correction(p, hs)
|
||||
|
||||
converged = newton_solve(
|
||||
residuals,
|
||||
params,
|
||||
quat_groups=quat_groups,
|
||||
post_step=post_step,
|
||||
)
|
||||
assert converged
|
||||
env = params.get_env()
|
||||
# Body b should be at +X (x > 0), not -X
|
||||
bx = env["b/tx"]
|
||||
assert bx > 0, f"Expected positive X, got {bx}"
|
||||
# Distance should be 5
|
||||
dist = math.sqrt(bx**2 + env["b/ty"] ** 2 + env["b/tz"] ** 2)
|
||||
assert dist == pytest.approx(5.0, abs=1e-8)
|
||||
|
||||
def test_negative_x_stays_negative(self):
|
||||
"""Body starting at -X should stay at -X after solve."""
|
||||
params = ParamTable()
|
||||
body_a, body_b = _make_two_bodies(params, pos_b=(-3, 0, 0))
|
||||
c = DistancePointPointConstraint(
|
||||
body_a,
|
||||
(0, 0, 0),
|
||||
body_b,
|
||||
(0, 0, 0),
|
||||
distance=5.0,
|
||||
)
|
||||
hs = compute_half_spaces([c], [0], params)
|
||||
assert len(hs) == 1
|
||||
|
||||
residuals = c.residuals()
|
||||
residuals.append(body_b.quat_norm_residual())
|
||||
quat_groups = [body_b.quat_param_names()]
|
||||
|
||||
def post_step(p):
|
||||
apply_half_space_correction(p, hs)
|
||||
|
||||
converged = newton_solve(
|
||||
residuals,
|
||||
params,
|
||||
quat_groups=quat_groups,
|
||||
post_step=post_step,
|
||||
)
|
||||
assert converged
|
||||
env = params.get_env()
|
||||
bx = env["b/tx"]
|
||||
assert bx < 0, f"Expected negative X, got {bx}"
|
||||
|
||||
def test_zero_distance_no_halfspace(self):
|
||||
"""Zero distance constraint has no branch ambiguity."""
|
||||
params = ParamTable()
|
||||
body_a, body_b = _make_two_bodies(params, pos_b=(3, 0, 0))
|
||||
c = DistancePointPointConstraint(
|
||||
body_a,
|
||||
(0, 0, 0),
|
||||
body_b,
|
||||
(0, 0, 0),
|
||||
distance=0.0,
|
||||
)
|
||||
hs = compute_half_spaces([c], [0], params)
|
||||
assert len(hs) == 0
|
||||
|
||||
|
||||
class TestParallelHalfSpace:
|
||||
"""Half-space tracking for Parallel constraint."""
|
||||
|
||||
def test_same_direction_tracked(self):
|
||||
"""Same-direction parallel: positive reference sign."""
|
||||
params = ParamTable()
|
||||
body_a, body_b = _make_two_bodies(params)
|
||||
c = ParallelConstraint(body_a, (1, 0, 0, 0), body_b, (1, 0, 0, 0))
|
||||
hs = compute_half_spaces([c], [0], params)
|
||||
assert len(hs) == 1
|
||||
assert hs[0].reference_sign == 1.0
|
||||
|
||||
def test_opposite_direction_tracked(self):
|
||||
"""Opposite-direction parallel: negative reference sign."""
|
||||
params = ParamTable()
|
||||
# Rotate body_b by 180 degrees about X: Z-axis flips
|
||||
q_flip = (0, 1, 0, 0) # 180 deg about X
|
||||
body_a, body_b = _make_two_bodies(params, quat_b=q_flip)
|
||||
c = ParallelConstraint(body_a, (1, 0, 0, 0), body_b, (1, 0, 0, 0))
|
||||
hs = compute_half_spaces([c], [0], params)
|
||||
assert len(hs) == 1
|
||||
assert hs[0].reference_sign == -1.0
|
||||
|
||||
|
||||
class TestAngleHalfSpace:
|
||||
"""Half-space tracking for Angle constraint."""
|
||||
|
||||
def test_90_degree_angle(self):
|
||||
"""90-degree angle constraint creates a half-space."""
|
||||
params = ParamTable()
|
||||
# Rotate body_b by 90 degrees about X
|
||||
q_90x = (math.cos(math.pi / 4), math.sin(math.pi / 4), 0, 0)
|
||||
body_a, body_b = _make_two_bodies(params, quat_b=q_90x)
|
||||
c = AngleConstraint(
|
||||
body_a,
|
||||
(1, 0, 0, 0),
|
||||
body_b,
|
||||
(1, 0, 0, 0),
|
||||
angle=math.pi / 2,
|
||||
)
|
||||
hs = compute_half_spaces([c], [0], params)
|
||||
assert len(hs) == 1
|
||||
|
||||
def test_zero_angle_no_halfspace(self):
|
||||
"""0-degree angle has no branch ambiguity."""
|
||||
params = ParamTable()
|
||||
body_a, body_b = _make_two_bodies(params)
|
||||
c = AngleConstraint(
|
||||
body_a,
|
||||
(1, 0, 0, 0),
|
||||
body_b,
|
||||
(1, 0, 0, 0),
|
||||
angle=0.0,
|
||||
)
|
||||
hs = compute_half_spaces([c], [0], params)
|
||||
assert len(hs) == 0
|
||||
|
||||
def test_180_angle_no_halfspace(self):
|
||||
"""180-degree angle has no branch ambiguity."""
|
||||
params = ParamTable()
|
||||
body_a, body_b = _make_two_bodies(params)
|
||||
c = AngleConstraint(
|
||||
body_a,
|
||||
(1, 0, 0, 0),
|
||||
body_b,
|
||||
(1, 0, 0, 0),
|
||||
angle=math.pi,
|
||||
)
|
||||
hs = compute_half_spaces([c], [0], params)
|
||||
assert len(hs) == 0
|
||||
|
||||
|
||||
class TestPerpendicularHalfSpace:
|
||||
"""Half-space tracking for Perpendicular constraint."""
|
||||
|
||||
def test_perpendicular_tracked(self):
|
||||
"""Perpendicular constraint creates a half-space."""
|
||||
params = ParamTable()
|
||||
# Rotate body_b by 90 degrees about X
|
||||
q_90x = (math.cos(math.pi / 4), math.sin(math.pi / 4), 0, 0)
|
||||
body_a, body_b = _make_two_bodies(params, quat_b=q_90x)
|
||||
c = PerpendicularConstraint(
|
||||
body_a,
|
||||
(1, 0, 0, 0),
|
||||
body_b,
|
||||
(1, 0, 0, 0),
|
||||
)
|
||||
hs = compute_half_spaces([c], [0], params)
|
||||
assert len(hs) == 1
|
||||
|
||||
|
||||
class TestNewtonPostStep:
|
||||
"""Verify Newton post_step callback works correctly."""
|
||||
|
||||
def test_callback_fires(self):
|
||||
"""post_step callback is invoked during Newton iterations."""
|
||||
params = ParamTable()
|
||||
x = params.add("x", 2.0)
|
||||
from kindred_solver.expr import Const
|
||||
|
||||
residuals = [x - Const(5.0)]
|
||||
|
||||
call_count = [0]
|
||||
|
||||
def counter(p):
|
||||
call_count[0] += 1
|
||||
|
||||
converged = newton_solve(residuals, params, post_step=counter)
|
||||
assert converged
|
||||
assert call_count[0] >= 1
|
||||
|
||||
def test_callback_does_not_break_convergence(self):
|
||||
"""A no-op callback doesn't prevent convergence."""
|
||||
params = ParamTable()
|
||||
x = params.add("x", 1.0)
|
||||
y = params.add("y", 1.0)
|
||||
from kindred_solver.expr import Const
|
||||
|
||||
residuals = [x - Const(3.0), y - Const(7.0)]
|
||||
|
||||
converged = newton_solve(residuals, params, post_step=lambda p: None)
|
||||
assert converged
|
||||
assert params.get_value("x") == pytest.approx(3.0)
|
||||
assert params.get_value("y") == pytest.approx(7.0)
|
||||
|
||||
|
||||
class TestMixedHalfSpaces:
|
||||
"""Multiple branching constraints in one system."""
|
||||
|
||||
def test_multiple_constraints(self):
|
||||
"""compute_half_spaces handles mixed constraint types."""
|
||||
params = ParamTable()
|
||||
body_a, body_b = _make_two_bodies(params, pos_b=(5, 0, 0))
|
||||
|
||||
dist_c = DistancePointPointConstraint(
|
||||
body_a,
|
||||
(0, 0, 0),
|
||||
body_b,
|
||||
(0, 0, 0),
|
||||
distance=5.0,
|
||||
)
|
||||
par_c = ParallelConstraint(body_a, (1, 0, 0, 0), body_b, (1, 0, 0, 0))
|
||||
|
||||
hs = compute_half_spaces([dist_c, par_c], [0, 1], params)
|
||||
assert len(hs) == 2
|
||||
|
||||
|
||||
class TestBuildWeightVector:
|
||||
"""Weight vector construction."""
|
||||
|
||||
def test_translation_weight_one(self):
|
||||
"""Translation params get weight 1.0."""
|
||||
from kindred_solver.preference import build_weight_vector
|
||||
|
||||
params = ParamTable()
|
||||
params.add("body/tx", 0.0)
|
||||
params.add("body/ty", 0.0)
|
||||
params.add("body/tz", 0.0)
|
||||
w = build_weight_vector(params)
|
||||
np.testing.assert_array_equal(w, [1.0, 1.0, 1.0])
|
||||
|
||||
def test_quaternion_weight_high(self):
|
||||
"""Quaternion params get QUAT_WEIGHT."""
|
||||
from kindred_solver.preference import QUAT_WEIGHT, build_weight_vector
|
||||
|
||||
params = ParamTable()
|
||||
params.add("body/qw", 1.0)
|
||||
params.add("body/qx", 0.0)
|
||||
params.add("body/qy", 0.0)
|
||||
params.add("body/qz", 0.0)
|
||||
w = build_weight_vector(params)
|
||||
np.testing.assert_array_equal(w, [QUAT_WEIGHT] * 4)
|
||||
|
||||
def test_mixed_params(self):
|
||||
"""Mixed translation and quaternion params get correct weights."""
|
||||
from kindred_solver.preference import QUAT_WEIGHT, build_weight_vector
|
||||
|
||||
params = ParamTable()
|
||||
params.add("b/tx", 0.0)
|
||||
params.add("b/qw", 1.0)
|
||||
params.add("b/ty", 0.0)
|
||||
params.add("b/qx", 0.0)
|
||||
w = build_weight_vector(params)
|
||||
assert w[0] == pytest.approx(1.0)
|
||||
assert w[1] == pytest.approx(QUAT_WEIGHT)
|
||||
assert w[2] == pytest.approx(1.0)
|
||||
assert w[3] == pytest.approx(QUAT_WEIGHT)
|
||||
|
||||
def test_fixed_params_excluded(self):
|
||||
"""Fixed params are not in free list, so not in weight vector."""
|
||||
from kindred_solver.preference import build_weight_vector
|
||||
|
||||
params = ParamTable()
|
||||
params.add("b/tx", 0.0, fixed=True)
|
||||
params.add("b/ty", 0.0)
|
||||
w = build_weight_vector(params)
|
||||
assert len(w) == 1
|
||||
assert w[0] == pytest.approx(1.0)
|
||||
|
||||
|
||||
class TestWeightedNewton:
|
||||
"""Weighted minimum-norm Newton solve."""
|
||||
|
||||
def test_well_constrained_same_result(self):
|
||||
"""Weighted and unweighted produce identical results for unique solution."""
|
||||
from kindred_solver.expr import Const
|
||||
|
||||
# Fully determined system: x = 3, y = 7
|
||||
params1 = ParamTable()
|
||||
x1 = params1.add("x", 1.0)
|
||||
y1 = params1.add("y", 1.0)
|
||||
r1 = [x1 - Const(3.0), y1 - Const(7.0)]
|
||||
|
||||
params2 = ParamTable()
|
||||
x2 = params2.add("x", 1.0)
|
||||
y2 = params2.add("y", 1.0)
|
||||
r2 = [x2 - Const(3.0), y2 - Const(7.0)]
|
||||
|
||||
newton_solve(r1, params1)
|
||||
newton_solve(r2, params2, weight_vector=np.array([1.0, 100.0]))
|
||||
|
||||
assert params1.get_value("x") == pytest.approx(
|
||||
params2.get_value("x"), abs=1e-10
|
||||
)
|
||||
assert params1.get_value("y") == pytest.approx(
|
||||
params2.get_value("y"), abs=1e-10
|
||||
)
|
||||
|
||||
def test_underconstrained_prefers_low_weight(self):
|
||||
"""Under-constrained: weighted solve moves high-weight params less."""
|
||||
from kindred_solver.expr import Const
|
||||
|
||||
# 1 equation, 2 unknowns: x + y = 10 (from x=0, y=0)
|
||||
params_unw = ParamTable()
|
||||
xu = params_unw.add("x", 0.0)
|
||||
yu = params_unw.add("y", 0.0)
|
||||
ru = [xu + yu - Const(10.0)]
|
||||
|
||||
params_w = ParamTable()
|
||||
xw = params_w.add("x", 0.0)
|
||||
yw = params_w.add("y", 0.0)
|
||||
rw = [xw + yw - Const(10.0)]
|
||||
|
||||
# Unweighted: lstsq gives equal movement
|
||||
newton_solve(ru, params_unw)
|
||||
|
||||
# Weighted: y is 100x more expensive to move
|
||||
newton_solve(rw, params_w, weight_vector=np.array([1.0, 100.0]))
|
||||
|
||||
# Both should satisfy x + y = 10
|
||||
assert params_unw.get_value("x") + params_unw.get_value("y") == pytest.approx(
|
||||
10.0
|
||||
)
|
||||
assert params_w.get_value("x") + params_w.get_value("y") == pytest.approx(10.0)
|
||||
|
||||
# Weighted solve should move y less than x
|
||||
assert abs(params_w.get_value("y")) < abs(params_w.get_value("x"))
|
||||
Reference in New Issue
Block a user