Skip to content

Commit

Permalink
Make FES calculation more uniform accross methods
Browse files Browse the repository at this point in the history
  • Loading branch information
pabloferz committed Jan 31, 2024
1 parent ce9524a commit 14c868c
Showing 1 changed file with 13 additions and 7 deletions.
20 changes: 13 additions & 7 deletions pysages/methods/spectral_abf.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,7 @@ def build_force_estimator(method: SpectralABF):
"""
N = method.N
grid = method.grid
dims = grid.shape.size
model = method.model
get_grad = build_grad_evaluator(model)

Expand All @@ -242,17 +243,17 @@ def _estimate_force(state):
return cond(state.pred, interpolate_force, average_force, state)

if method.restraints is None:
estimate_force = _estimate_force
ob_force = jit(lambda state: np.zeros(dims))
else:
lo, hi, kl, kh = method.restraints

def restraints_force(state):
def ob_force(state):
xi = state.xi.reshape(grid.shape.size)
return apply_restraints(lo, hi, kl, kh, xi)

def estimate_force(state):
ob = np.any(np.array(state.ind) == grid.shape) # Out of bounds condition
return cond(ob, restraints_force, _estimate_force, state)
def estimate_force(state):
ob = np.any(np.array(state.ind) == grid.shape) # Out of bounds condition
return cond(ob, ob_force, _estimate_force, state)

return estimate_force

Expand Down Expand Up @@ -303,7 +304,11 @@ def average_forces(hist, Fsum):
return Fsum / np.maximum(hist, 1)

def build_fes_fn(fun):
return jit(lambda x: evaluate(fun, x))
def fes_fn(x):
A = evaluate(fun, x)
return A.max() - A

return jit(fes_fn)

def first_or_all(seq):
return seq[0] if len(seq) == 1 else seq
Expand All @@ -318,7 +323,7 @@ def first_or_all(seq):
fes_fn = build_fes_fn(s.fun)
hists.append(s.hist)
mean_forces.append(average_forces(s.hist, s.Fsum))
free_energies.append(fes_fn(mesh))
free_energies.append(fes_fn(mesh).reshape(grid.shape))
funs.append(s.fun)
fes_fns.append(fes_fn)

Expand All @@ -330,4 +335,5 @@ def first_or_all(seq):
fun=first_or_all(funs),
fes_fn=first_or_all(fes_fns),
)

return numpyfy_vals(ana_result)

0 comments on commit 14c868c

Please sign in to comment.