diff --git a/test/relu.cpp b/test/relu.cpp index a32a380c3..083b6f30d 100644 --- a/test/relu.cpp +++ b/test/relu.cpp @@ -92,6 +92,21 @@ TEST_CASE("def ctor") } } +TEST_CASE("normalise") +{ + { + auto ex = relu(fix(.1_dbl)); + ex = normalise(unfix(ex)); + REQUIRE(ex == .1_dbl); + } + + { + auto ex = relup(fix(-.1_dbl)); + ex = normalise(unfix(ex)); + REQUIRE(ex == 0_dbl); + } +} + TEST_CASE("diff") { auto [x, y] = make_vars("x", "y");