Skip to content

Commit

Permalink
updated
Browse files Browse the repository at this point in the history
  • Loading branch information
mushroomfire committed Apr 12, 2024
1 parent 94f4718 commit 41d4623
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 13 deletions.
20 changes: 10 additions & 10 deletions mdapy/nep/_nep.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@ struct Atom {
std::vector<double> box, position, potential, force, virial, descriptor;
};

class NepCalculator
class NEPCalculator
{
public:
NepCalculator(std::string);
NEPCalculator(std::string);
void setAtoms(py::array, py::array, py::array);
std::tuple<std::vector<double>, std::vector<double>, std::vector<double>> calculate(py::array, py::array, py::array);
py::dict info;
Expand All @@ -36,7 +36,7 @@ class NepCalculator
bool HAS_SETATOMS=false;
};

NepCalculator::NepCalculator(std::string _model_file)
NEPCalculator::NEPCalculator(std::string _model_file)
{
model_file = _model_file;
calc = NEP3(model_file);
Expand All @@ -54,7 +54,7 @@ NepCalculator::NepCalculator(std::string _model_file)
info["element_list"] = calc.element_list;
}

void NepCalculator::setAtoms(
void NEPCalculator::setAtoms(
py::array _type,
py::array _box,
py::array _position)
Expand Down Expand Up @@ -84,7 +84,7 @@ void NepCalculator::setAtoms(
HAS_SETATOMS = true;
}

std::tuple<std::vector<double>, std::vector<double>, std::vector<double>> NepCalculator::calculate(
std::tuple<std::vector<double>, std::vector<double>, std::vector<double>> NEPCalculator::calculate(
py::array _type,
py::array _box,
py::array _position
Expand All @@ -100,7 +100,7 @@ std::tuple<std::vector<double>, std::vector<double>, std::vector<double>> NepCal
return std::make_tuple(atom.potential, atom.force, atom.virial);
}

std::vector<double> NepCalculator::get_descriptors(
std::vector<double> NEPCalculator::get_descriptors(
py::array _type,
py::array _box,
py::array _position
Expand All @@ -115,10 +115,10 @@ std::vector<double> NepCalculator::get_descriptors(

PYBIND11_MODULE(_nep, m){
m.doc() = "nep";
py::class_<NepCalculator>(m, "NepCalculator")
py::class_<NEPCalculator>(m, "NEPCalculator")
.def(py::init<std::string>())
.def_readonly("info", &NepCalculator::info)
.def("calculate", &NepCalculator::calculate)
.def("get_descriptors", &NepCalculator::get_descriptors)
.def_readonly("info", &NEPCalculator::info)
.def("calculate", &NEPCalculator::calculate)
.def("get_descriptors", &NEPCalculator::get_descriptors)
;
}
6 changes: 3 additions & 3 deletions mdapy/potential.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,13 @@
from tool_function import _check_repeat_cutoff
from replicate import Replicate
from neighbor import Neighbor
from nep._nep import NepCalculator
from nep._nep import NEPCalculator
except Exception:
from .plotset import set_figure
from .tool_function import _check_repeat_cutoff
from .replicate import Replicate
from .neighbor import Neighbor
from _nep import NepCalculator
from _nep import NEPCalculator


@ti.data_oriented
Expand Down Expand Up @@ -704,7 +704,7 @@ class NEP:
def __init__(self, filename) -> None:

self.filename = filename
self._nep = NepCalculator(filename)
self._nep = NEPCalculator(filename)
self.info = self._nep.info
self.rc = max(self.info["radial_cutoff"], self.info["angular_cutoff"])

Expand Down

0 comments on commit 41d4623

Please sign in to comment.