From 14c868c4a798f8ef07065c31f6d74fe3a8adfbe4 Mon Sep 17 00:00:00 2001 From: Pablo Zubieta <8410335+pabloferz@users.noreply.github.com> Date: Tue, 30 Jan 2024 16:51:48 -0600 Subject: [PATCH] Make FES calculation more uniform accross methods --- pysages/methods/spectral_abf.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/pysages/methods/spectral_abf.py b/pysages/methods/spectral_abf.py index 8afe63d2..deccce1e 100644 --- a/pysages/methods/spectral_abf.py +++ b/pysages/methods/spectral_abf.py @@ -228,6 +228,7 @@ def build_force_estimator(method: SpectralABF): """ N = method.N grid = method.grid + dims = grid.shape.size model = method.model get_grad = build_grad_evaluator(model) @@ -242,17 +243,17 @@ def _estimate_force(state): return cond(state.pred, interpolate_force, average_force, state) if method.restraints is None: - estimate_force = _estimate_force + ob_force = jit(lambda state: np.zeros(dims)) else: lo, hi, kl, kh = method.restraints - def restraints_force(state): + def ob_force(state): xi = state.xi.reshape(grid.shape.size) return apply_restraints(lo, hi, kl, kh, xi) - def estimate_force(state): - ob = np.any(np.array(state.ind) == grid.shape) # Out of bounds condition - return cond(ob, restraints_force, _estimate_force, state) + def estimate_force(state): + ob = np.any(np.array(state.ind) == grid.shape) # Out of bounds condition + return cond(ob, ob_force, _estimate_force, state) return estimate_force @@ -303,7 +304,11 @@ def average_forces(hist, Fsum): return Fsum / np.maximum(hist, 1) def build_fes_fn(fun): - return jit(lambda x: evaluate(fun, x)) + def fes_fn(x): + A = evaluate(fun, x) + return A.max() - A + + return jit(fes_fn) def first_or_all(seq): return seq[0] if len(seq) == 1 else seq @@ -318,7 +323,7 @@ def first_or_all(seq): fes_fn = build_fes_fn(s.fun) hists.append(s.hist) mean_forces.append(average_forces(s.hist, s.Fsum)) - free_energies.append(fes_fn(mesh)) + free_energies.append(fes_fn(mesh).reshape(grid.shape)) funs.append(s.fun) fes_fns.append(fes_fn) @@ -330,4 +335,5 @@ def first_or_all(seq): fun=first_or_all(funs), fes_fn=first_or_all(fes_fns), ) + return numpyfy_vals(ana_result)