diff --git a/deepblast/tests/test_nw.py b/deepblast/tests/test_nw.py index c14c851..7f07ae8 100644 --- a/deepblast/tests/test_nw.py +++ b/deepblast/tests/test_nw.py @@ -79,5 +79,32 @@ def test_hessian_needlemanwunsch_function_Arand(self): gradgradcheck(needle, inputs, eps=1e-2) +class TestRegressNeedlemanWunsch(unittest.TestCase): + def setUp(self): + torch.manual_seed(2) + # initialize random replacement costs [0,1] + self.theta = torch.rand(1, + 4, + 6, + requires_grad=True, + dtype=torch.float32).squeeze() + # define gap costs as -3, for theta + max... trick: compute A as follow + self.A = (torch.ones_like(self.theta) * -3) - self.theta + # external knowledge, only works for manual_seed(2) + self.exp_hardmax = -3.2049 + + def test_hardmax(self): + # TODO: use_numba is a global variable and I don't know how to change + # it to False for this test from this method. + needle = NeedlemanWunschDecoder('hardmax') + obs = needle(self.theta, self.A).detach().numpy() + self.assertAlmostEqual(self.exp_hardmax, obs, places=5) + + def test_hardmax_numba(self): + needle = NeedlemanWunschDecoder('hardmax') + obs = needle(self.theta, self.A).detach().numpy() + self.assertAlmostEqual(self.exp_hardmax, obs, places=5) + + if __name__ == "__main__": unittest.main()