From cd6b0a910337abbfb0a19652e6e10941f7bfed99 Mon Sep 17 00:00:00 2001 From: Pablo Zubieta <8410335+pabloferz@users.noreply.github.com> Date: Thu, 1 Feb 2024 16:07:10 -0600 Subject: [PATCH] Couple of fixes (#296) In #292, we introduced a mechanism to count the number of calls to a method, but mistakenly that PR broke `SpectralABF` and `FUNN`, fixing those here. In addition this makes the analysis for `SpectralABF` match the rest of the methods (it was giving the free energy surfaces with opposite sign to the rest of the methods). --- pysages/methods/funn.py | 2 +- pysages/methods/spectral_abf.py | 22 ++++++++++++++-------- 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/pysages/methods/funn.py b/pysages/methods/funn.py index 9c1c2dde..ad268d54 100644 --- a/pysages/methods/funn.py +++ b/pysages/methods/funn.py @@ -198,7 +198,7 @@ def update(state, data): ) bias = (-Jxi.T @ F).reshape(state.bias.shape) # - return FUNNState(xi, bias, hist, Fsum, F, Wp, state.Wp, nn, state.ncalls) + return FUNNState(xi, bias, hist, Fsum, F, Wp, state.Wp, nn, ncalls) return snapshot, initialize, generalize(update, helpers) diff --git a/pysages/methods/spectral_abf.py b/pysages/methods/spectral_abf.py index 68c21c29..deccce1e 100644 --- a/pysages/methods/spectral_abf.py +++ b/pysages/methods/spectral_abf.py @@ -194,7 +194,7 @@ def update(state, data): ) bias = np.reshape(-Jxi.T @ force, state.bias.shape) # - return SpectralABFState(xi, bias, hist, Fsum, force, Wp, state.Wp, fun, state.ncalls) + return SpectralABFState(xi, bias, hist, Fsum, force, Wp, state.Wp, fun, ncalls) return snapshot, initialize, generalize(update, helpers) @@ -228,6 +228,7 @@ def build_force_estimator(method: SpectralABF): """ N = method.N grid = method.grid + dims = grid.shape.size model = method.model get_grad = build_grad_evaluator(model) @@ -242,17 +243,17 @@ def _estimate_force(state): return cond(state.pred, interpolate_force, average_force, state) if method.restraints is None: - estimate_force = _estimate_force + ob_force = jit(lambda state: np.zeros(dims)) else: lo, hi, kl, kh = method.restraints - def restraints_force(state): + def ob_force(state): xi = state.xi.reshape(grid.shape.size) return apply_restraints(lo, hi, kl, kh, xi) - def estimate_force(state): - ob = np.any(np.array(state.ind) == grid.shape) # Out of bounds condition - return cond(ob, restraints_force, _estimate_force, state) + def estimate_force(state): + ob = np.any(np.array(state.ind) == grid.shape) # Out of bounds condition + return cond(ob, ob_force, _estimate_force, state) return estimate_force @@ -303,7 +304,11 @@ def average_forces(hist, Fsum): return Fsum / np.maximum(hist, 1) def build_fes_fn(fun): - return jit(lambda x: evaluate(fun, x)) + def fes_fn(x): + A = evaluate(fun, x) + return A.max() - A + + return jit(fes_fn) def first_or_all(seq): return seq[0] if len(seq) == 1 else seq @@ -318,7 +323,7 @@ def first_or_all(seq): fes_fn = build_fes_fn(s.fun) hists.append(s.hist) mean_forces.append(average_forces(s.hist, s.Fsum)) - free_energies.append(fes_fn(mesh)) + free_energies.append(fes_fn(mesh).reshape(grid.shape)) funs.append(s.fun) fes_fns.append(fes_fn) @@ -330,4 +335,5 @@ def first_or_all(seq): fun=first_or_all(funs), fes_fn=first_or_all(fes_fns), ) + return numpyfy_vals(ana_result)