-
Notifications
You must be signed in to change notification settings - Fork 18
Refactor Species Class #159
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: dev-gsoc
Are you sure you want to change the base?
Changes from 1 commit
7968da8
4d06543
e4bf935
0645745
df637af
1bc44ff
60fe63f
a4cf3ec
9f9d35a
02b1ad7
bf3e715
b7bb79e
e6db81e
d71480c
df82230
0d78b49
9b71f9b
75ec145
e3d092f
f72034b
fa3441c
762f2b5
9be8ea9
8841c63
ba26664
49b65ef
eb2fb57
192c4b7
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -268,66 +268,6 @@ def __init__(self, data) -> None: | |
| self.nbasis = self.norba # number of spatial basis functions | ||
|
|
||
|
|
||
| @dataclass(eq=False, order=False) | ||
| class SpeciesData: | ||
| r"""Database entry fields for atomic and ionic species.""" | ||
|
|
||
| # Species info | ||
| elem: str = field(default_factory=default_required("elem", "str")) | ||
| atnum: int = field(default_factory=default_required("atnum", "int")) | ||
| nelec: int = field(default_factory=default_required("nelec", "int")) | ||
| nspin: int = field(default_factory=default_required("nspin", "int")) | ||
| nexc: int = field(default_factory=default_required("nexc", "int")) | ||
|
|
||
| # Scalar properties | ||
| atmass: float = field(default=None) | ||
| cov_radius: float = field(default=None) | ||
| vdw_radius: float = field(default=None) | ||
| at_radius: float = field(default=None) | ||
| polarizability: float = field(default=None) | ||
| dispersion: float = field(default=None) | ||
|
|
||
| # Scalar energy and CDFT-related properties | ||
| energy: float = field(default=None) | ||
| ip: float = field(default=None) | ||
| mu: float = field(default=None) | ||
| eta: float = field(default=None) | ||
|
|
||
| # Basis set name | ||
| obasis_name: str = field(default=None) | ||
|
|
||
| # Radial grid | ||
| rs: ndarray = field(default_factory=default_vector) | ||
|
|
||
| # Orbital energies | ||
| mo_energy_a: ndarray = field(default_factory=default_vector) | ||
| mo_energy_b: ndarray = field(default_factory=default_vector) | ||
|
|
||
| # Orbital occupations | ||
| mo_occs_a: ndarray = field(default_factory=default_vector) | ||
| mo_occs_b: ndarray = field(default_factory=default_vector) | ||
|
|
||
| # Orbital densities | ||
| mo_dens_a: ndarray = field(default_factory=default_matrix) | ||
| mo_dens_b: ndarray = field(default_factory=default_matrix) | ||
| dens_tot: ndarray = field(default_factory=default_matrix) | ||
|
|
||
| # Orbital density gradients | ||
| mo_d_dens_a: ndarray = field(default_factory=default_matrix) | ||
| mo_d_dens_b: ndarray = field(default_factory=default_matrix) | ||
| d_dens_tot: ndarray = field(default_factory=default_matrix) | ||
|
|
||
| # Orbital density Laplacian | ||
| mo_dd_dens_a: ndarray = field(default_factory=default_matrix) | ||
| mo_dd_dens_b: ndarray = field(default_factory=default_matrix) | ||
| dd_dens_tot: ndarray = field(default_factory=default_matrix) | ||
|
|
||
| # Orbital kinetic energy densities | ||
| mo_ked_a: ndarray = field(default_factory=default_matrix) | ||
| mo_ked_b: ndarray = field(default_factory=default_matrix) | ||
| ked_tot: ndarray = field(default_factory=default_matrix) | ||
|
|
||
|
|
||
| class Species: | ||
| r"""Properties of atomic and ionic species.""" | ||
|
|
||
|
|
@@ -345,7 +285,9 @@ def __init__(self, dataset, fields, spinpol=1): | |
|
|
||
| """ | ||
| self._dataset = dataset.lower() | ||
| self._data = SpeciesData(**fields) | ||
| # self._data = SpeciesData(**fields) | ||
| self._data = fields | ||
| print(f"species data: {self._data}") | ||
| self.spinpol = spinpol | ||
| self.ao = _AtomicOrbitals(self._data) | ||
|
|
||
|
|
@@ -838,47 +780,44 @@ def compile_species( | |
| makedirs(path.join(datapath, dataset.lower(), "raw"), exist_ok=True) | ||
| # Import the compile script for the appropriate dataset | ||
| submodule = import_module(f"atomdb.datasets.{dataset}.run") | ||
| creator = import_module(f"atomdb.datasets.{dataset}.h5file_creator") | ||
|
|
||
| dataset_def = submodule.run(elem, charge, mult, nexc, dataset, datapath) | ||
| fields = asdict(dataset_def) | ||
|
|
||
| creator.create_hdf5_file(fields, dataset, elem, charge, mult, nexc) | ||
| fields = submodule.run(elem, charge, mult, nexc, dataset, datapath) | ||
| dump(fields, dataset, elem, charge, mult, nexc) | ||
|
|
||
| # print all fields | ||
| for key, value in fields.items(): | ||
| if isinstance(value, np.ndarray): | ||
| print(f"{key}: shape={value.shape}, first 5 elements={value.flat[:5]}") | ||
| else: | ||
| print(f"{key}: {value}") | ||
|
|
||
| species = Species(dataset, fields) | ||
| return species | ||
| # for key, value in fields.items(): | ||
| # if isinstance(value, np.ndarray): | ||
| # print(f"{key}: shape={value.shape}, first 5 elements={value.flat[:5]}") | ||
| # else: | ||
| # print(f"{key}: {value}") | ||
|
|
||
| # species = Species(dataset, fields) | ||
| # return species | ||
|
|
||
| ## old stuff ## | ||
| # Compile the Species instance and dump the database entry | ||
| # species = submodule.run(elem, charge, mult, nexc, dataset, datapath) | ||
| # dump(species, datapath=datapath) | ||
|
|
||
|
|
||
| def dump(*species, datapath=DEFAULT_DATAPATH): | ||
| r"""Dump the Species instance(s) to a MessagePack file in the database. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| species: Iterable | ||
| Iterables of objects of class `Species` | ||
| datapath : str, optional | ||
| Path to the local AtomDB cache, by default DEFAULT_DATAPATH variable value. | ||
|
|
||
| def dump(fields, dataset, elem, charge, mult, nexc): | ||
|
||
| r"""Dump the Species instance(s) to a hdf5 file in the database. | ||
| """ | ||
| for s in species: | ||
| fn = datafile( | ||
| s._data.elem, s.charge, s.mult, nexc=s.nexc, dataset=s.dataset, datapath=datapath | ||
| ) | ||
| with open(fn, "wb") as f: | ||
| f.write(packb(asdict(s._data), default=encode)) | ||
| creator = import_module(f"atomdb.datasets.{dataset}.h5file_creator") | ||
| creator.create_hdf5_file(fields, dataset, elem, charge, mult, nexc) | ||
|
|
||
| # def dump(*species, datapath=DEFAULT_DATAPATH): | ||
| # r"""Dump the Species instance(s) to a MessagePack file in the database. | ||
| # | ||
| # Parameters | ||
| # ---------- | ||
| # species: Iterable | ||
| # Iterables of objects of class `Species` | ||
| # datapath : str, optional | ||
| # Path to the local AtomDB cache, by default DEFAULT_DATAPATH variable value. | ||
| # | ||
| # """ | ||
| # for s in species: | ||
| # fn = datafile( | ||
| # s._data.elem, s.charge, s.mult, nexc=s.nexc, dataset=s.dataset, datapath=datapath | ||
| # ) | ||
| # with open(fn, "wb") as f: | ||
| # f.write(packb(asdict(s._data), default=encode)) | ||
|
|
||
|
|
||
| def load( | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Here also replace the hardcoded value for the number of radial points (NPOINTS=10000)