Skip to content
This repository has been archived by the owner on Sep 24, 2024. It is now read-only.

Commit

Permalink
moving around orthonormal transform
Browse files Browse the repository at this point in the history
  • Loading branch information
hatemhelal committed Apr 22, 2024
1 parent 10c1bde commit 094493f
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 10 deletions.
15 changes: 6 additions & 9 deletions mess/hamiltonian.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from mess.mesh import Mesh, density, density_and_grad, xcmesh_from_pyscf
from mess.scf import otransform_symmetric
from mess.structure import nuclear_energy
from mess.types import FloatNxN, OrthTransform
from mess.types import FloatNxN, OrthNormTransform
from mess.xcfunctional import (
lda_correlation_vwn,
lda_exchange,
Expand Down Expand Up @@ -157,7 +157,6 @@ def build_xcfunc(xc_method: str, basis: Basis, two_electron: TwoElectron) -> eqx


class Hamiltonian(eqx.Module):
X: FloatNxN
H_core: FloatNxN
basis: Basis
two_electron: TwoElectron
Expand All @@ -166,14 +165,11 @@ class Hamiltonian(eqx.Module):
def __init__(
self,
basis: Basis,
otransform: OrthTransform = otransform_symmetric,
xc_method: str = "lda",
):
super().__init__()
S = overlap_basis(basis)
self.X = otransform(S)
self.H_core = kinetic_basis(basis) + nuclear_basis(basis).sum(axis=0)
self.basis = basis
self.H_core = kinetic_basis(basis) + nuclear_basis(basis).sum(axis=0)
self.two_electron = TwoElectron(basis, backend="pyscf")
self.xcfunc = build_xcfunc(xc_method, basis, self.two_electron)

Expand All @@ -187,16 +183,17 @@ def __call__(self, P: FloatNxN) -> float:


@jax.jit
def minimise(H: Hamiltonian):
def minimise(H: Hamiltonian, ont: OrthNormTransform = otransform_symmetric):
def f(Z, _):
C = H.X @ jnl.qr(Z).Q
C = X @ jnl.qr(Z).Q
P = H.basis.density_matrix(C)
return H(P)

X = ont(overlap_basis(H.basis))
solver = optx.BFGS(rtol=1e-5, atol=1e-6)
Z = jnp.eye(H.basis.num_orbitals)
sol = optx.minimise(f, solver, Z)
C = H.X @ jnl.qr(sol.value).Q
C = X @ jnl.qr(sol.value).Q
P = H.basis.density_matrix(C)
E_elec = H(P)
E_total = E_elec + nuclear_energy(H.basis.structure)
Expand Down
2 changes: 1 addition & 1 deletion mess/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

asintarray = partial(jnp.asarray, dtype=jnp.int32)

OrthTransform = Callable[[FloatNxN], FloatNxN]
OrthNormTransform = Callable[[FloatNxN], FloatNxN]


def default_fptype():
Expand Down

0 comments on commit 094493f

Please sign in to comment.