"""Tests for the parameter table.""" import numpy as np import pytest from kindred_solver.expr import Var from kindred_solver.params import ParamTable class TestParamTable: def test_add_and_get(self): pt = ParamTable() v = pt.add("x", 3.0) assert isinstance(v, Var) assert v.name == "x" assert pt.get_value("x") == 3.0 def test_duplicate_raises(self): pt = ParamTable() pt.add("x") with pytest.raises(ValueError, match="Duplicate"): pt.add("x") def test_fixed(self): pt = ParamTable() pt.add("x", 1.0, fixed=True) pt.add("y", 2.0, fixed=False) assert pt.is_fixed("x") assert not pt.is_fixed("y") assert pt.free_names() == ["y"] def test_fix(self): pt = ParamTable() pt.add("x", 1.0) assert "x" in pt.free_names() pt.fix("x") assert "x" not in pt.free_names() assert pt.is_fixed("x") def test_env(self): pt = ParamTable() pt.add("a", 1.0) pt.add("b", 2.0, fixed=True) env = pt.get_env() assert env == {"a": 1.0, "b": 2.0} def test_free_vector(self): pt = ParamTable() pt.add("a", 1.0) pt.add("b", 2.0, fixed=True) pt.add("c", 3.0) vec = pt.get_free_vector() np.testing.assert_array_equal(vec, [1.0, 3.0]) def test_set_free_vector(self): pt = ParamTable() pt.add("a", 0.0) pt.add("b", 0.0) pt.set_free_vector(np.array([5.0, 7.0])) assert pt.get_value("a") == 5.0 assert pt.get_value("b") == 7.0 def test_n_free(self): pt = ParamTable() pt.add("a", 0.0) pt.add("b", 0.0, fixed=True) pt.add("c", 0.0) assert pt.n_free() == 2 def test_unfix(self): pt = ParamTable() pt.add("a", 1.0) pt.add("b", 2.0) pt.fix("a") assert pt.is_fixed("a") assert "a" not in pt.free_names() pt.unfix("a") assert not pt.is_fixed("a") assert "a" in pt.free_names() assert pt.n_free() == 2 def test_fix_unfix_roundtrip(self): """Fix then unfix preserves value and makes param free again.""" pt = ParamTable() pt.add("x", 5.0) pt.add("y", 3.0) pt.fix("x") pt.set_value("x", 10.0) pt.unfix("x") assert pt.get_value("x") == 10.0 assert "x" in pt.free_names() # x moves to end of free list assert pt.free_names() == ["y", "x"] def test_unfix_noop_if_already_free(self): """Unfixing a free parameter is a no-op.""" pt = ParamTable() pt.add("a", 1.0) pt.unfix("a") assert pt.free_names() == ["a"] assert pt.n_free() == 1 def test_snapshot_restore_roundtrip(self): """Snapshot captures values; restore brings them back.""" pt = ParamTable() pt.add("x", 1.0) pt.add("y", 2.0) pt.add("z", 3.0, fixed=True) snap = pt.snapshot() pt.set_value("x", 99.0) pt.set_value("y", 88.0) pt.set_value("z", 77.0) pt.restore(snap) assert pt.get_value("x") == 1.0 assert pt.get_value("y") == 2.0 assert pt.get_value("z") == 3.0 def test_snapshot_is_independent_copy(self): """Mutating snapshot dict does not affect the table.""" pt = ParamTable() pt.add("a", 5.0) snap = pt.snapshot() snap["a"] = 999.0 assert pt.get_value("a") == 5.0 def test_movement_cost_no_weights(self): """Movement cost is sum of squared displacements for free params.""" pt = ParamTable() pt.add("x", 0.0) pt.add("y", 0.0) pt.add("z", 0.0, fixed=True) snap = pt.snapshot() pt.set_value("x", 3.0) pt.set_value("y", 4.0) pt.set_value("z", 100.0) # fixed — ignored assert pt.movement_cost(snap) == pytest.approx(25.0) def test_movement_cost_with_weights(self): """Weighted movement cost scales each displacement.""" pt = ParamTable() pt.add("a", 0.0) pt.add("b", 0.0) snap = pt.snapshot() pt.set_value("a", 1.0) pt.set_value("b", 1.0) weights = {"a": 4.0, "b": 9.0} # cost = 1^2*4 + 1^2*9 = 13 assert pt.movement_cost(snap, weights) == pytest.approx(13.0)