From 094493fe841f3df375d98112174f5ce060654e65 Mon Sep 17 00:00:00 2001 From: Hatem Helal Date: Mon, 22 Apr 2024 20:30:48 +0100 Subject: [PATCH] moving around orthonormal transform --- mess/hamiltonian.py | 15 ++++++--------- mess/types.py | 2 +- 2 files changed, 7 insertions(+), 10 deletions(-) diff --git a/mess/hamiltonian.py b/mess/hamiltonian.py index 47582e3..f8559b8 100644 --- a/mess/hamiltonian.py +++ b/mess/hamiltonian.py @@ -12,7 +12,7 @@ from mess.mesh import Mesh, density, density_and_grad, xcmesh_from_pyscf from mess.scf import otransform_symmetric from mess.structure import nuclear_energy -from mess.types import FloatNxN, OrthTransform +from mess.types import FloatNxN, OrthNormTransform from mess.xcfunctional import ( lda_correlation_vwn, lda_exchange, @@ -157,7 +157,6 @@ def build_xcfunc(xc_method: str, basis: Basis, two_electron: TwoElectron) -> eqx class Hamiltonian(eqx.Module): - X: FloatNxN H_core: FloatNxN basis: Basis two_electron: TwoElectron @@ -166,14 +165,11 @@ class Hamiltonian(eqx.Module): def __init__( self, basis: Basis, - otransform: OrthTransform = otransform_symmetric, xc_method: str = "lda", ): super().__init__() - S = overlap_basis(basis) - self.X = otransform(S) - self.H_core = kinetic_basis(basis) + nuclear_basis(basis).sum(axis=0) self.basis = basis + self.H_core = kinetic_basis(basis) + nuclear_basis(basis).sum(axis=0) self.two_electron = TwoElectron(basis, backend="pyscf") self.xcfunc = build_xcfunc(xc_method, basis, self.two_electron) @@ -187,16 +183,17 @@ def __call__(self, P: FloatNxN) -> float: @jax.jit -def minimise(H: Hamiltonian): +def minimise(H: Hamiltonian, ont: OrthNormTransform = otransform_symmetric): def f(Z, _): - C = H.X @ jnl.qr(Z).Q + C = X @ jnl.qr(Z).Q P = H.basis.density_matrix(C) return H(P) + X = ont(overlap_basis(H.basis)) solver = optx.BFGS(rtol=1e-5, atol=1e-6) Z = jnp.eye(H.basis.num_orbitals) sol = optx.minimise(f, solver, Z) - C = H.X @ jnl.qr(sol.value).Q + C = X @ jnl.qr(sol.value).Q P = H.basis.density_matrix(C) E_elec = H(P) E_total = E_elec + nuclear_energy(H.basis.structure) diff --git a/mess/types.py b/mess/types.py index 5296dfd..4f45e7a 100644 --- a/mess/types.py +++ b/mess/types.py @@ -19,7 +19,7 @@ asintarray = partial(jnp.asarray, dtype=jnp.int32) -OrthTransform = Callable[[FloatNxN], FloatNxN] +OrthNormTransform = Callable[[FloatNxN], FloatNxN] def default_fptype():