diff --git a/moldocker/samplers/voronoi.py b/moldocker/samplers/voronoi.py index 731fa01..d76344f 100644 --- a/moldocker/samplers/voronoi.py +++ b/moldocker/samplers/voronoi.py @@ -17,6 +17,7 @@ MIN_VORONOI_RADIUS = 3.0 REMOVE_SPECIES = [] NUM_CLUSTERS = 10 +PYMATGEN_RADII = False class VoronoiSampler(Sampler): @@ -29,11 +30,13 @@ def __init__( probe_radius=PROBE_RADIUS, remove_species=REMOVE_SPECIES, min_radius=MIN_VORONOI_RADIUS, + pymatgen_radii=PYMATGEN_RADII, **kwargs ): self.probe_radius = probe_radius self.remove_species = remove_species self.min_radius = min_radius + self.pymatgen_radii = pymatgen_radii @staticmethod def add_arguments(parser): @@ -62,10 +65,7 @@ def remove_species_from_structure(self): self._structure.remove_species(species) def get_voronoi_structures(self): - try: - radii = self.get_atomic_radii() - except (ValueError, TypeError) as e: - radii = None + radii = self.get_atomic_radii() with suppress_stdout(): nodes, edge_center, face_center = zeopp.get_voronoi_nodes( @@ -75,14 +75,21 @@ def get_voronoi_structures(self): return nodes, edge_center, face_center def get_atomic_radii(self): - bv = BVAnalyzer() - valences = bv.get_valences(self._structure) - elements = [site.species_string for site in self._structure.sites] - - valence_dict = dict(zip(elements, valences)) - radii = {} - for k, v in valence_dict.items(): - radii[k] = float(Specie(k, v).ionic_radius) + if not self.pymatgen_radii: + return None + + try: + bv = BVAnalyzer() + valences = bv.get_valences(self._structure) + elements = [site.species_string for site in self._structure.sites] + + valence_dict = dict(zip(elements, valences)) + radii = {} + for k, v in valence_dict.items(): + radii[k] = float(Specie(k, v).ionic_radius) + + except (ValueError, TypeError) as e: + radii = None return radii diff --git a/moldocker/utils/timing.py b/moldocker/utils/timing.py index 054f1a9..d3515c3 100644 --- a/moldocker/utils/timing.py +++ b/moldocker/utils/timing.py @@ -1,5 +1,6 @@ from timeit import default_timer as timer + def time_fn(fn): def wrapped(*args, **kwargs): start = timer()