Skip to content

Commit

Permalink
Merge pull request #12 from lantunes/stable_materials
Browse files Browse the repository at this point in the history
Add script for counting the number of valid, unique, and novel CIFs
  • Loading branch information
lantunes authored Jul 20, 2024
2 parents 2582acd + e8ed887 commit 292276d
Showing 1 changed file with 163 additions and 0 deletions.
163 changes: 163 additions & 0 deletions bin/check_valid_unique_novel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
import os
import argparse
import gzip
import pickle
import tarfile
import re
from pymatgen.core import Composition, Structure
from pymatgen.analysis.structure_matcher import StructureMatcher
from crystallm import extract_data_formula, is_valid
from postprocess import postprocess
from tqdm import tqdm

import warnings
warnings.filterwarnings("ignore")


def read_gen_cifs(gen_cifs_path):
cifs = []
if gen_cifs_path.endswith(".tar.gz"):
with tarfile.open(gen_cifs_path, "r:gz") as tar:
for member in tar.getmembers():
if member.isfile() and member.name.endswith(".cif"):
cif = tar.extractfile(member).read().decode()
cifs.append(cif)
else:
for filename in os.listdir(gen_cifs_path):
if filename.endswith(".cif"):
file_path = os.path.join(gen_cifs_path, filename)
with open(file_path, "r") as file:
cif = file.read()
cifs.append(cif)
return cifs


"""
Given a path to a directory containing generated post-processed CIFs, and
a path to a set of CIFs used for training a model, this script does the following:
1. prints the number of given CIFs which are valid
2. prints the number of given CIFs which are unique
3. prints the number of given CIFs which are novel relative to the training CIFs
4. copies the novel CIFs to a specified path
"""
if __name__ == '__main__':
parser = argparse.ArgumentParser(description="Count the number of valid, unique, and novel CIFs.")
parser.add_argument("gen_cifs",
help="Path to the directory or .tar.gz file containing the "
"generated post-processed CIF files.")
parser.add_argument("--base",
help="Path to the file containing the set of pre-processed CIFs used for training the model, "
"as a .pkl.gz file containing a Python list of 2-tuples: (ID, CIF).")
parser.add_argument("--out",
help="Path to the directory where the novel CIFs should be copied.")
# the default StructureMatcher settings are based on the MatterGen paper https://arxiv.org/abs/2312.03687
parser.add_argument("--ltol", required=False, default=0.2, type=float,
help="The 'ltol' argument for the pymatgen StructureMatcher")
parser.add_argument("--stol", required=False, default=0.3, type=float,
help="The 'stol' argument for the pymatgen StructureMatcher")
parser.add_argument("--angle_tol", required=False, default=5., type=float,
help="The 'angle_tol' argument for the pymatgen StructureMatcher")
args = parser.parse_args()

gen_cifs_path = args.gen_cifs
base_cifs_path = args.base
out_path = args.out
ltol = args.ltol
stol = args.stol
angle_tol = args.angle_tol

with gzip.open(base_cifs_path, "rb") as f:
base_cifs = pickle.load(f)

print(f"# base CIFs: {len(base_cifs):,}")

base_comps = {}
for _, cif in tqdm(base_cifs, desc="processing base CIFs..."):
cell_composition = extract_data_formula(cif)
formula = Composition(cell_composition).formula
if formula not in base_comps:
base_comps[formula] = []
# strip out any leading or trailing spaces
cif = re.sub(r"^[ \t]+|[ \t]+$", "", cif, flags=re.MULTILINE)
cif = cif.replace(" ", " ")
base_comps[formula].append(cif)

print(f"reading generated CIFs from {gen_cifs_path}...")
gen_cifs = read_gen_cifs(gen_cifs_path)

struct_matcher = StructureMatcher(ltol=ltol, stol=stol, angle_tol=angle_tol)

# keep only the valid CIFs
valid_gen_cifs = []
for gen_cif in gen_cifs:
try:
if is_valid(gen_cif, bond_length_acceptability_cutoff=1.0):
valid_gen_cifs.append(gen_cif)
except Exception:
pass

print(f"# valid CIFs: {len(valid_gen_cifs):,}")

# get unique generated CIFs;
# create a map cell_comp->list of structs, where each struct is unique;
# the number of unique structs is the number of unique materials
gen_comps = {}
for gen_cif in tqdm(valid_gen_cifs, desc="processing generated CIFs..."):
struct = Structure.from_str(gen_cif, fmt="cif")
cell_composition = struct.composition.formula
if cell_composition not in gen_comps:
gen_comps[cell_composition] = [(struct, gen_cif)]
else:
# check that none of the structs match
if not any([
struct_matcher.fit(struct, existing)
for existing, _ in gen_comps[cell_composition]
]):
gen_comps[cell_composition].append((struct, gen_cif))

# unnest structs
gen_structs = [i for g in gen_comps.values() for i in g]

print(f"# unique materials: {len(gen_structs):,}")

# only unique structures can be novel; if we had 1 unique structure out of 100 generated, we wouldn't have
# 100 novel structures--we'd have 1 (if it were indeed novel)

novel_cifs = []
novel_by_composition = 0

for gen_struct, gen_cif in tqdm(gen_structs, desc="checking novelty..."):
gen_formula = gen_struct.composition.formula

if gen_formula not in base_comps:
# this is a novel material as there is no composition match
# print(f"novel by formula: {gen_struct.composition.reduced_formula}")
novel_cifs.append(gen_cif)
novel_by_composition += 1
else:
# for each existing CIF, post-process it and convert to structure and check structure match
is_matched = False
for existing_cif in base_comps[gen_formula]:
existing_cif_post = postprocess(existing_cif, "N/A")
existing_struct = Structure.from_str(existing_cif_post, fmt="cif")
is_matched = struct_matcher.fit(gen_struct, existing_struct)
if is_matched:
# print(f"not novel: {gen_struct.composition.reduced_formula}")
break
if not is_matched:
# print(f"novel: {gen_struct.composition.reduced_formula}")
novel_cifs.append(gen_cif)

print(f"# novel materials: {len(novel_cifs):,}")
print(f"# novel materials by composition alone: {novel_by_composition:,}")

if len(novel_cifs) > 0:
if not os.path.exists(out_path):
print(f"creating {out_path}...")
os.makedirs(out_path)

print(f"writing novel CIFs...")
for i, novel_cif in enumerate(novel_cifs):
output_file_path = os.path.join(out_path, f"novel_{i+1}.cif")
with open(output_file_path, "w") as file:
file.write(novel_cif)

0 comments on commit 292276d

Please sign in to comment.