From 5eaf6292fe1aba35ee2c9c6be8a02494f45981d7 Mon Sep 17 00:00:00 2001 From: Hatem Helal Date: Mon, 29 Apr 2024 11:58:47 +0000 Subject: [PATCH] separate orthonormal transformations --- mess/hamiltonian.py | 4 +-- mess/orthnorm.py | 73 +++++++++++++++++++++++++++++++++++++++++++++ mess/scf.py | 23 ++------------ 3 files changed, 78 insertions(+), 22 deletions(-) create mode 100644 mess/orthnorm.py diff --git a/mess/hamiltonian.py b/mess/hamiltonian.py index 716ab10..ae644b5 100644 --- a/mess/hamiltonian.py +++ b/mess/hamiltonian.py @@ -12,7 +12,7 @@ from mess.integrals import eri_basis, kinetic_basis, nuclear_basis, overlap_basis from mess.interop import to_pyscf from mess.mesh import Mesh, density, density_and_grad, xcmesh_from_pyscf -from mess.scf import otransform_symmetric +from mess.orthnorm import canonical from mess.structure import nuclear_energy from mess.types import FloatNxN, OrthNormTransform from mess.xcfunctional import ( @@ -182,7 +182,7 @@ class Hamiltonian(eqx.Module): def __init__( self, basis: Basis, - ont: OrthNormTransform = otransform_symmetric, + ont: OrthNormTransform = canonical, xc_method: xcstr = "lda", ): super().__init__() diff --git a/mess/orthnorm.py b/mess/orthnorm.py new file mode 100644 index 0000000..07d5e0e --- /dev/null +++ b/mess/orthnorm.py @@ -0,0 +1,73 @@ +# Copyright (c) 2024 Graphcore Ltd. All rights reserved. +import jax.numpy as jnp +import jax.numpy.linalg as jnl + +from mess.types import FloatNxN + +"""Orthonormal transformation. + +Evaluates the transformation matrix :math:`X` that satisfies + +.. math:: \mathbf{X}^T \mathbf{S} \mathbf{X} = \mathbb{I} + +where :math:`\mathbf{S}` is the overlap matrix of the non-orthonormal basis and +:math:`\mathbb{I}` is the identity matrix. + +This module implements a few commonly used orthonormalisation transforms. +""" + + +def canonical(S: FloatNxN) -> FloatNxN: + """Canonical orthonormal transformation + + .. math:: \mathbf{X} = \mathbf{U} \mathbf{s}^{-1/2} + + where :math:`\mathbf{U}` and :math:`\mathbf{s}` are the eigenvectors and + eigenvalues of the overlap matrix :math:`\mathbf{S}`. + + Args: + S (FloatNxN): overlap matrix for the non-orthonormal basis. + + Returns: + FloatNxN: canonical orthonormal transformation matrix + """ + s, U = jnl.eigh(S) + s = jnp.diag(jnp.power(s, -0.5)) + return U @ s + + +def symmetric(S: FloatNxN) -> FloatNxN: + """Symmetric orthonormal transformation + + .. math:: \mathbf{X} = \mathbf{U} \mathbf{s}^{-1/2} \mathbf{U}^T + + where :math:`\mathbf{U}` and :math:`\mathbf{s}` are the eigenvectors and + eigenvalues of the overlap matrix :math:`\mathbf{S}`. + + Args: + S (FloatNxN): overlap matrix for the non-orthonormal basis. + + Returns: + FloatNxN: symmetric orthonormal transformation matrix + """ + s, U = jnl.eigh(S) + s = jnp.diag(jnp.power(s, -0.5)) + return U @ s @ U.T + + +def cholesky(S: FloatNxN) -> FloatNxN: + """Cholesky orthonormal transformation + + .. math:: \mathbf{X} = (\mathbf{L}^{-1})^T + + where :math:`\mathbf{L}` is the lower triangular matrix the satisfies the Cholesky + decomposition of the overlap matrix :math:`\mathbf{S}`. + + Args: + S (FloatNxN): overlap matrix for the non-orthonormal basis. + + Returns: + FloatNxN: cholesky orthonormal transformation matrix + """ + L = jnl.cholesky(S) + return jnl.inv(L).T diff --git a/mess/scf.py b/mess/scf.py index d61c804..aeaa583 100644 --- a/mess/scf.py +++ b/mess/scf.py @@ -1,6 +1,4 @@ # Copyright (c) 2024 Graphcore Ltd. All rights reserved. -from typing import Callable - import jax.numpy as jnp import jax.numpy.linalg as jnl from jax.lax import while_loop @@ -8,28 +6,13 @@ from mess.basis import Basis from mess.integrals import eri_basis, kinetic_basis, nuclear_basis, overlap_basis from mess.structure import nuclear_energy - - -def otransform_canonical(S): - s, U = jnl.eigh(S) - s = jnp.diag(jnp.power(s, -0.5)) - return U @ s - - -def otransform_symmetric(S): - s, U = jnl.eigh(S) - s = jnp.diag(jnp.power(s, -0.5)) - return U @ s @ U.T - - -def otransform_cholesky(S): - L = jnl.cholesky(S) - return jnl.inv(L).T +from mess.orthnorm import cholesky +from mess.types import OrthNormTransform def scf( basis: Basis, - otransform: Callable = otransform_cholesky, + otransform: OrthNormTransform = cholesky, max_iters: int = 32, tolerance: float = 1e-4, ):