"""Tests for the expression DAG.""" import math import pytest from kindred_solver.expr import ( ONE, ZERO, Add, Const, Cos, Div, Mul, Neg, Pow, Sin, Sqrt, Sub, Var, ) class TestConst: def test_eval(self): assert Const(3.0).eval({}) == 3.0 def test_diff(self): assert Const(5.0).diff("x") == ZERO def test_eq(self): assert Const(1.0) == Const(1.0) assert Const(1.0) != Const(2.0) def test_vars(self): assert Const(1.0).vars() == set() class TestVar: def test_eval(self): assert Var("x").eval({"x": 7.0}) == 7.0 def test_diff_self(self): assert Var("x").diff("x") == ONE def test_diff_other(self): assert Var("x").diff("y") == ZERO def test_vars(self): assert Var("x").vars() == {"x"} class TestOperators: def test_add(self): x = Var("x") e = x + 3.0 assert isinstance(e, Add) assert e.eval({"x": 2.0}) == 5.0 def test_radd(self): x = Var("x") e = 3.0 + x assert e.eval({"x": 2.0}) == 5.0 def test_sub(self): x = Var("x") e = x - 1.0 assert e.eval({"x": 5.0}) == 4.0 def test_mul(self): x = Var("x") e = x * 2.0 assert e.eval({"x": 3.0}) == 6.0 def test_div(self): x = Var("x") e = x / 2.0 assert e.eval({"x": 6.0}) == 3.0 def test_neg(self): x = Var("x") e = -x assert e.eval({"x": 3.0}) == -3.0 def test_pow(self): x = Var("x") e = x**2 assert e.eval({"x": 3.0}) == 9.0 class TestDiff: def test_add(self): x = Var("x") e = x + Const(5.0) d = e.diff("x").simplify() assert d.eval({}) == 1.0 def test_mul_product_rule(self): x = Var("x") e = x * x # x^2 d = e.diff("x").simplify() # d/dx(x*x) = x + x = 2x assert d.eval({"x": 3.0}) == 6.0 def test_pow_const_exp(self): x = Var("x") e = Pow(x, Const(3.0)) # x^3 d = e.diff("x").simplify() # d/dx(x^3) = 3x^2 assert abs(d.eval({"x": 2.0}) - 12.0) < 1e-10 def test_sin(self): x = Var("x") e = Sin(x) d = e.diff("x").simplify() # d/dx sin(x) = cos(x) assert abs(d.eval({"x": 0.0}) - 1.0) < 1e-10 def test_cos(self): x = Var("x") e = Cos(x) d = e.diff("x").simplify() # d/dx cos(x) = -sin(x) assert abs(d.eval({"x": 0.0}) - 0.0) < 1e-10 def test_div_quotient_rule(self): x = Var("x") e = Div(x, x + Const(1.0)) # x / (x+1) d = e.diff("x") # d/dx x/(x+1) = 1/(x+1)^2 val = d.eval({"x": 1.0}) assert abs(val - 0.25) < 1e-10 def test_sqrt(self): x = Var("x") e = Sqrt(x) d = e.diff("x") # d/dx sqrt(x) = 1/(2*sqrt(x)) val = d.eval({"x": 4.0}) assert abs(val - 0.25) < 1e-10 class TestSimplify: def test_add_zero(self): x = Var("x") e = Add(ZERO, x).simplify() assert isinstance(e, Var) and e.name == "x" def test_mul_one(self): x = Var("x") e = Mul(ONE, x).simplify() assert isinstance(e, Var) and e.name == "x" def test_mul_zero(self): x = Var("x") e = Mul(ZERO, x).simplify() assert e == ZERO def test_const_fold(self): e = Add(Const(2.0), Const(3.0)).simplify() assert isinstance(e, Const) and e.value == 5.0 def test_neg_neg(self): x = Var("x") e = Neg(Neg(x)).simplify() assert isinstance(e, Var) and e.name == "x" def test_pow_zero(self): x = Var("x") e = Pow(x, ZERO).simplify() assert e == ONE def test_pow_one(self): x = Var("x") e = Pow(x, ONE).simplify() assert isinstance(e, Var) and e.name == "x" def test_sub_zero(self): x = Var("x") e = Sub(x, ZERO).simplify() assert isinstance(e, Var) and e.name == "x" class TestComplex: def test_polynomial(self): """Test a quadratic: 3x^2 + 2x + 1, derivative = 6x + 2.""" x = Var("x") e = Const(3.0) * x * x + Const(2.0) * x + Const(1.0) assert abs(e.eval({"x": 2.0}) - 17.0) < 1e-10 d = e.diff("x").simplify() assert abs(d.eval({"x": 2.0}) - 14.0) < 1e-10 def test_vars_set(self): x, y = Var("x"), Var("y") e = x * y + Sin(x) assert e.vars() == {"x", "y"}