Skip to content

Commit

Permalink
Initial functions for the MinDistanceCationAnionFitness class in the …
Browse files Browse the repository at this point in the history
…MC docking
  • Loading branch information
Pau Ferri-Vicedo committed May 28, 2024
1 parent 23c2495 commit f216c88
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 2 deletions.
14 changes: 12 additions & 2 deletions VOID/fitness/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,20 @@
from .base import Fitness
from .threshold import MinDistanceFitness, MeanDistanceFitness, SumInvDistanceFitness
from .target import MinDistanceGaussianTarget, MeanDistanceGaussianTarget, MaxDistanceGaussianTarget
from .threshold import (
MinDistanceFitness,
MeanDistanceFitness,
SumInvDistanceFitness,
MinDistanceCationAnionFitness,
)
from .target import (
MinDistanceGaussianTarget,
MeanDistanceGaussianTarget,
MaxDistanceGaussianTarget,
)
from .union import MultipleFitness

__all__ = [
MinDistanceFitness,
MinDistanceCationAnionFitness,
MeanDistanceFitness,
SumInvDistanceFitness,
MinDistanceGaussianTarget,
Expand Down
80 changes: 80 additions & 0 deletions VOID/fitness/threshold.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from .base import Fitness

from pymatgen.core.sites import Site


THRESHOLD = 1.5
DEFAULT_STRUCTURE = "complex"
Expand Down Expand Up @@ -60,6 +62,62 @@ def get_distances(self, complex):
else:
raise ValueError("structure type not supported")

def get_zeolite_oxygens(self, pose):
"""Collect all the O atoms in the structure."""
return [index for index, site in enumerate(pose) if site.species_string == "O"]

def find_cation_index(self, pose, distance_matrices):
"""Identify the cation position in the guest."""
for index, site in enumerate(pose):
element = site.species_string
if element == "C":
bonds = sum(1 for dist in distance_matrices[index] if 0 < dist < 1.6)
if bonds == 3:
return index
return None

def find_acid_sites(self, pose, distance_matrices, zeolite_oxygens):
"""Identify the acid sites in the zeolite."""
acid_oxygens = []
acid_al_indexes = []

for index, site in enumerate(pose):
if site.species_string == "Al":
candidate_oxygens = [
dist_index
for dist_index, dist in enumerate(distance_matrices[index])
if 0 < dist < 1.8 and dist_index in zeolite_oxygens
]
if len(candidate_oxygens) == 4 and all(
all(
not (bond_dist < 1.15 and bond_dist != 0.0)
for bond_dist in distance_matrices[ox_index]
)
for ox_index in candidate_oxygens
): # 1.15 accounts for O-H bond
acid_oxygens.append(candidate_oxygens)
acid_al_indexes.append(index)

return acid_oxygens, acid_al_indexes

def get_catan_distances(self, acid_oxygens, cation_index, distance_matrices):
"""Check the cation-anion distances for the different acid sites in the zeolite."""
distances_catan = []
for acid_al in acid_oxygens:
distances_cation_anion = [
distance_matrices[cation_index][ox_index] for ox_index in acid_al
]
print(
"Distances between cation and acid oxygens are:", distances_cation_anion
)
distances_catan.append(distances_cation_anion)
# if any(dist < 2.0 for dist in distances_cation_anion):
# print("Optimal distance found! Aborting the run")
# return True
# return False

return distances_catan

def normalize(self, value):
if self.step:
return 0 if value > 0 else -np.inf
Expand All @@ -74,6 +132,28 @@ def __call__(self, complex):
return self.normalize(self.get_distances(complex).min() - self.threshold)


class MinDistanceCationAnionFitness(ThresholdFitness):
PARSER_NAME = "min_catan_distance"
HELP = "Complexes have positive score if the minimum distance between host anion and guest cation is above the given threshold"

def __call__(self, complex):
# print(complex.pose)
pose = complex.pose
distance_matrices = complex.pose.distance_matrix
zeolite_oxygens = self.get_zeolite_oxygens(pose)
cation_index = self.find_cation_index(
pose,
distance_matrices,
)
acid_sites, acid_al_indexes = self.find_acid_sites(
pose, distance_matrices, zeolite_oxygens
)
return self.normalize(
min(self.get_catan_distances(acid_sites, cation_index, distance_matrices))
- self.threshold
)


class MeanDistanceFitness(ThresholdFitness):
PARSER_NAME = "mean_distance"
HELP = "Complexes have positive score if the mean distance between host and guest is above the given threshold"
Expand Down

0 comments on commit f216c88

Please sign in to comment.