"""Tests for the parameter table.""" import numpy as np import pytest from kindred_solver.expr import Var from kindred_solver.params import ParamTable class TestParamTable: def test_add_and_get(self): pt = ParamTable() v = pt.add("x", 3.0) assert isinstance(v, Var) assert v.name == "x" assert pt.get_value("x") == 3.0 def test_duplicate_raises(self): pt = ParamTable() pt.add("x") with pytest.raises(ValueError, match="Duplicate"): pt.add("x") def test_fixed(self): pt = ParamTable() pt.add("x", 1.0, fixed=True) pt.add("y", 2.0, fixed=False) assert pt.is_fixed("x") assert not pt.is_fixed("y") assert pt.free_names() == ["y"] def test_fix(self): pt = ParamTable() pt.add("x", 1.0) assert "x" in pt.free_names() pt.fix("x") assert "x" not in pt.free_names() assert pt.is_fixed("x") def test_env(self): pt = ParamTable() pt.add("a", 1.0) pt.add("b", 2.0, fixed=True) env = pt.get_env() assert env == {"a": 1.0, "b": 2.0} def test_free_vector(self): pt = ParamTable() pt.add("a", 1.0) pt.add("b", 2.0, fixed=True) pt.add("c", 3.0) vec = pt.get_free_vector() np.testing.assert_array_equal(vec, [1.0, 3.0]) def test_set_free_vector(self): pt = ParamTable() pt.add("a", 0.0) pt.add("b", 0.0) pt.set_free_vector(np.array([5.0, 7.0])) assert pt.get_value("a") == 5.0 assert pt.get_value("b") == 7.0 def test_n_free(self): pt = ParamTable() pt.add("a", 0.0) pt.add("b", 0.0, fixed=True) pt.add("c", 0.0) assert pt.n_free() == 2 def test_unfix(self): pt = ParamTable() pt.add("a", 1.0) pt.add("b", 2.0) pt.fix("a") assert pt.is_fixed("a") assert "a" not in pt.free_names() pt.unfix("a") assert not pt.is_fixed("a") assert "a" in pt.free_names() assert pt.n_free() == 2 def test_fix_unfix_roundtrip(self): """Fix then unfix preserves value and makes param free again.""" pt = ParamTable() pt.add("x", 5.0) pt.add("y", 3.0) pt.fix("x") pt.set_value("x", 10.0) pt.unfix("x") assert pt.get_value("x") == 10.0 assert "x" in pt.free_names() # x moves to end of free list assert pt.free_names() == ["y", "x"] def test_unfix_noop_if_already_free(self): """Unfixing a free parameter is a no-op.""" pt = ParamTable() pt.add("a", 1.0) pt.unfix("a") assert pt.free_names() == ["a"] assert pt.n_free() == 1