- Add per-entity DOF analysis via Jacobian SVD (diagnostics.py)
- Add overconstrained detection: redundant vs conflicting constraints
- Add half-space tracking to preserve configuration branch (preference.py)
- Add minimum-movement weighting for least-squares solve
- Extend BFGS fallback with weight vector and quaternion renormalization
- Add snapshot/restore and env accessor to ParamTable
- Fix DistancePointPointConstraint sign for half-space tracking
149 lines
4.3 KiB
Python
149 lines
4.3 KiB
Python
"""Tests for the parameter table."""
|
|
|
|
import numpy as np
|
|
import pytest
|
|
from kindred_solver.expr import Var
|
|
from kindred_solver.params import ParamTable
|
|
|
|
|
|
class TestParamTable:
    """Unit tests for ``ParamTable``.

    Covers parameter registration, fixing/unfixing, the free-parameter
    vector API, the value environment, snapshots, and movement cost.
    """

    def test_add_and_get(self):
        """``add`` returns a ``Var`` carrying the name and stores the value."""
        pt = ParamTable()
        v = pt.add("x", 3.0)
        assert isinstance(v, Var)
        assert v.name == "x"
        assert pt.get_value("x") == 3.0

    def test_duplicate_raises(self):
        """Registering the same name twice raises ``ValueError``."""
        pt = ParamTable()
        pt.add("x")
        with pytest.raises(ValueError, match="Duplicate"):
            pt.add("x")

    def test_fixed(self):
        """``fixed=True`` at creation keeps a parameter out of the free set."""
        pt = ParamTable()
        pt.add("x", 1.0, fixed=True)
        pt.add("y", 2.0, fixed=False)
        assert pt.is_fixed("x")
        assert not pt.is_fixed("y")
        assert pt.free_names() == ["y"]

    def test_fix(self):
        """``fix`` removes an existing free parameter from the free set."""
        pt = ParamTable()
        pt.add("x", 1.0)
        assert "x" in pt.free_names()
        pt.fix("x")
        assert "x" not in pt.free_names()
        assert pt.is_fixed("x")

    def test_env(self):
        """``get_env`` includes both free and fixed parameters."""
        pt = ParamTable()
        pt.add("a", 1.0)
        pt.add("b", 2.0, fixed=True)
        env = pt.get_env()
        assert env == {"a": 1.0, "b": 2.0}

    def test_free_vector(self):
        """``get_free_vector`` returns only free values, omitting fixed ones."""
        pt = ParamTable()
        pt.add("a", 1.0)
        pt.add("b", 2.0, fixed=True)
        pt.add("c", 3.0)
        vec = pt.get_free_vector()
        np.testing.assert_array_equal(vec, [1.0, 3.0])

    def test_set_free_vector(self):
        """``set_free_vector`` writes each entry back to its free parameter."""
        pt = ParamTable()
        pt.add("a", 0.0)
        pt.add("b", 0.0)
        pt.set_free_vector(np.array([5.0, 7.0]))
        assert pt.get_value("a") == 5.0
        assert pt.get_value("b") == 7.0

    def test_set_free_vector_skips_fixed(self):
        """A fixed parameter between free ones is skipped and left untouched."""
        pt = ParamTable()
        pt.add("a", 0.0)
        pt.add("b", 1.0, fixed=True)
        pt.add("c", 0.0)
        pt.set_free_vector(np.array([5.0, 7.0]))
        assert pt.get_value("a") == 5.0
        # The fixed value must not absorb a vector entry.
        assert pt.get_value("b") == 1.0
        assert pt.get_value("c") == 7.0
        # Reading the vector back must reproduce what was written.
        np.testing.assert_array_equal(pt.get_free_vector(), [5.0, 7.0])

    def test_n_free(self):
        """``n_free`` counts only the non-fixed parameters."""
        pt = ParamTable()
        pt.add("a", 0.0)
        pt.add("b", 0.0, fixed=True)
        pt.add("c", 0.0)
        assert pt.n_free() == 2

    def test_unfix(self):
        """``unfix`` returns a fixed parameter to the free set."""
        pt = ParamTable()
        pt.add("a", 1.0)
        pt.add("b", 2.0)
        pt.fix("a")
        assert pt.is_fixed("a")
        assert "a" not in pt.free_names()

        pt.unfix("a")
        assert not pt.is_fixed("a")
        assert "a" in pt.free_names()
        assert pt.n_free() == 2

    def test_fix_unfix_roundtrip(self):
        """Fix then unfix preserves value and makes param free again."""
        pt = ParamTable()
        pt.add("x", 5.0)
        pt.add("y", 3.0)
        pt.fix("x")
        pt.set_value("x", 10.0)
        pt.unfix("x")
        assert pt.get_value("x") == 10.0
        assert "x" in pt.free_names()
        # x moves to end of free list
        assert pt.free_names() == ["y", "x"]

    def test_unfix_noop_if_already_free(self):
        """Unfixing a free parameter is a no-op."""
        pt = ParamTable()
        pt.add("a", 1.0)
        pt.unfix("a")
        assert pt.free_names() == ["a"]
        assert pt.n_free() == 1

    def test_snapshot_restore_roundtrip(self):
        """Snapshot captures values; restore brings them back."""
        pt = ParamTable()
        pt.add("x", 1.0)
        pt.add("y", 2.0)
        pt.add("z", 3.0, fixed=True)
        snap = pt.snapshot()
        pt.set_value("x", 99.0)
        pt.set_value("y", 88.0)
        pt.set_value("z", 77.0)
        pt.restore(snap)
        assert pt.get_value("x") == 1.0
        assert pt.get_value("y") == 2.0
        # Fixed parameters are restored as well.
        assert pt.get_value("z") == 3.0

    def test_snapshot_is_independent_copy(self):
        """Mutating snapshot dict does not affect the table."""
        pt = ParamTable()
        pt.add("a", 5.0)
        snap = pt.snapshot()
        snap["a"] = 999.0
        assert pt.get_value("a") == 5.0

    def test_movement_cost_no_weights(self):
        """Movement cost is sum of squared displacements for free params."""
        pt = ParamTable()
        pt.add("x", 0.0)
        pt.add("y", 0.0)
        pt.add("z", 0.0, fixed=True)
        snap = pt.snapshot()
        pt.set_value("x", 3.0)
        pt.set_value("y", 4.0)
        pt.set_value("z", 100.0)  # fixed — ignored
        # 3^2 + 4^2 = 25; the fixed z displacement contributes nothing.
        assert pt.movement_cost(snap) == pytest.approx(25.0)

    def test_movement_cost_with_weights(self):
        """Weighted movement cost scales each displacement."""
        pt = ParamTable()
        pt.add("a", 0.0)
        pt.add("b", 0.0)
        snap = pt.snapshot()
        pt.set_value("a", 1.0)
        pt.set_value("b", 1.0)
        weights = {"a": 4.0, "b": 9.0}
        # cost = 1^2*4 + 1^2*9 = 13
        assert pt.movement_cost(snap, weights) == pytest.approx(13.0)
|