From 452ca2eae9bf9e902984921d44148c46182799b4 Mon Sep 17 00:00:00 2001
From: Pau Ferri-Vicedo
Date: Wed, 29 May 2024 12:03:00 -0400
Subject: [PATCH] cation/anion identifier functions work through atomtypes dict, get_catan_distances now combines MinDistanceFitness + MinDistanceCationAnionFitness, works

---
Review notes (annotations only; no token of the diff below was altered).

Layout: this patch was recovered from a copy in which all newlines and
indentation had been collapsed.  Line breaks follow the +/- markers and
reproduce the stated 183 insertions / 48 deletions.  Indentation and the
position of blank *context* lines are a best guess -- re-check them against
the hunk header counts before feeding this to `git am` / `git apply`.

What the patch does, as visible in the hunks:
  * get_zeolite_oxygens() is replaced by get_atomtypes_indexes(), which maps
    each site.species_string to a list of (index, site.label) tuples.
  * find_cation_index() drops its `pose` parameter, takes the new dict, looks
    for a guest C with a bond count of 3 or a guest N with a bond count of 4,
    and now raises ValueError where the old code returned None.
  * find_acid_sites() drops `pose`/`zeolite_oxygens`, takes the new dict, and
    its O-H exclusion cutoff changes from 1.15 to 1.10.
  * get_catan_distances() gains a `complex` parameter and returns a
    (bool, list) tuple instead of a list.
  * MinDistanceCationAnionFitness.__call__ returns 1 or -np.inf instead of a
    normalized distance.

Points to confirm with the author:
  * NOTE(review): in the nitrogen loop of find_cation_index, the branches
    commented "N-N", "N=N" and "N≡N" test `i in carbon_indexes`, while the
    branches commented "C-N", "C=N" and "C≡N" test `i in nitrogen_indexes`.
    The neighbour sets look swapped relative to the comments.
  * NOTE(review): several elif distance windows overlap, and the first match
    wins.  C-N 1.38-1.48 is tested before C=N 1.33-1.43, so a distance of
    1.38-1.43 is always counted as a single bond.
  * NOTE(review): the comments "exp C=O dist 1.28 A" and "exp C≡O dist 1.18 A"
    quote the upper edge of their windows, unlike the other comments, which
    quote the centre.  The last nitrogen branch is labelled "C=O" although
    it sits in the nitrogen loop.
  * NOTE(review): `-np.inf` needs `np` in scope; no numpy import is visible
    in this patch -- verify that threshold.py already imports it.
  * NOTE(review): in MinDistanceCationAnionFitness.__call__, `acid_al_indexes`
    and `distances` are assigned but not used in the lines shown.
  * NOTE(review): callers outside this file that relied on the old signatures
    or return values of find_cation_index / find_acid_sites /
    get_catan_distances are not covered by this patch.

 VOID/fitness/threshold.py | 231 ++++++++++++++++++++++++++++++--------
 1 file changed, 183 insertions(+), 48 deletions(-)

diff --git a/VOID/fitness/threshold.py b/VOID/fitness/threshold.py
index cecc4ba..6a2d541 100644
--- a/VOID/fitness/threshold.py
+++ b/VOID/fitness/threshold.py
@@ -62,61 +62,193 @@ def get_distances(self, complex):
         else:
             raise ValueError("structure type not supported")
 
-    def get_zeolite_oxygens(self, pose):
-        """Collect all the O atoms in the structure."""
-        return [index for index, site in enumerate(pose) if site.species_string == "O"]
+    def get_atomtypes_indexes(self, pose):
+        """Collect all the atomtypes indexes in the structure."""
+        atomtypes_indexes = {}
 
-    def find_cation_index(self, pose, distance_matrices):
-        """Identify the cation position in the guest."""
         for index, site in enumerate(pose):
-            element = site.species_string
-            if element == "C":
-                bonds = sum(1 for dist in distance_matrices[index] if 0 < dist < 1.6)
-                if bonds == 3:
-                    return index
-        return None
-
-    def find_acid_sites(self, pose, distance_matrices, zeolite_oxygens):
+            atom_type = site.species_string
+            feature = site.label  # labels the atomtype as host or guest
+            if atom_type not in atomtypes_indexes:
+                atomtypes_indexes[atom_type] = []
+            atomtypes_indexes[atom_type].append((index, feature))
+
+        return atomtypes_indexes
+
+    def find_cation_index(self, distance_matrices, atomtypes_indexes):
+        """Identify the cation position in the guest."""
+
+        carbon_indexes = [
+            idx for idx, lbl in atomtypes_indexes.get("C", []) if lbl == "guest"
+        ]
+        nitrogen_indexes = [
+            idx for idx, lbl in atomtypes_indexes.get("N", []) if lbl == "guest"
+        ]
+        oxygen_indexes = [
+            idx for idx, lbl in atomtypes_indexes.get("O", []) if lbl == "guest"
+        ]
+        hydrogen_indexes = [
+            idx for idx, lbl in atomtypes_indexes.get("H", []) if lbl == "guest"
+        ]
+        for carbon_index in carbon_indexes:
+            bonds = 0
+            for i, dist in enumerate(distance_matrices[carbon_index]):
+                if (
+                    i in carbon_indexes
+                    or i in hydrogen_indexes
+                    or i in nitrogen_indexes
+                    or i in oxygen_indexes
+                ):
+                    if 0 < dist < 1.15 and i in hydrogen_indexes:  # exp C-H dist 1.09 A
+                        bonds += 1
+                    elif (
+                        1.5 < dist < 1.6 and i in carbon_indexes
+                    ):  # exp C-C dist 1.55 A
+                        bonds += 1
+                    elif (
+                        1.29 < dist < 1.39 and i in carbon_indexes
+                    ):  # exp C=C dist 1.34 A
+                        bonds += 2
+                    elif (
+                        1.15 < dist < 1.25 and i in carbon_indexes
+                    ):  # exp C≡C dist 1.20 A
+                        bonds += 3
+                    elif (
+                        1.38 < dist < 1.48 and i in nitrogen_indexes
+                    ):  # exp C-N dist 1.43 A
+                        bonds += 1
+                    elif (
+                        1.33 < dist < 1.43 and i in nitrogen_indexes
+                    ):  # exp C=N dist 1.38 A
+                        bonds += 2
+                    elif (
+                        1.11 < dist < 1.21 and i in nitrogen_indexes
+                    ):  # exp C≡N dist 1.16 A
+                        bonds += 3
+                    elif (
+                        1.38 < dist < 1.48 and i in oxygen_indexes
+                    ):  # exp C-O dist 1.43 A
+                        bonds += 1
+                    elif (
+                        1.18 < dist < 1.28 and i in oxygen_indexes
+                    ):  # exp C=O dist 1.28 A
+                        bonds += 2
+                    elif (
+                        1.08 < dist < 1.18 and i in oxygen_indexes
+                    ):  # exp C≡O dist 1.18 A
+                        bonds += 3
+
+            if bonds == 3:
+                return carbon_index
+
+        for nitrogen_index in nitrogen_indexes:
+            bonds = 0
+            for i, dist in enumerate(distance_matrices[nitrogen_index]):
+                if (
+                    i in carbon_indexes
+                    or i in hydrogen_indexes
+                    or i in nitrogen_indexes
+                    or i in oxygen_indexes
+                ):
+                    if 0 < dist < 1.05 and i in hydrogen_indexes:  # exp N-H dist 1.00 A
+                        bonds += 1
+                    elif (
+                        1.42 < dist < 1.52 and i in carbon_indexes
+                    ):  # exp N-N dist 1.47 A
+                        bonds += 1
+                    elif (
+                        1.19 < dist < 1.29 and i in carbon_indexes
+                    ):  # exp N=N dist 1.24 A
+                        bonds += 2
+                    elif (
+                        1.05 < dist < 1.15 and i in carbon_indexes
+                    ):  # exp N≡N dist 1.10 A
+                        bonds += 3
+                    elif (
+                        1.38 < dist < 1.48 and i in nitrogen_indexes
+                    ):  # exp C-N dist 1.43 A
+                        bonds += 1
+                    elif (
+                        1.33 < dist < 1.43 and i in nitrogen_indexes
+                    ):  # exp C=N dist 1.38 A
+                        bonds += 2
+                    elif (
+                        1.11 < dist < 1.21 and i in nitrogen_indexes
+                    ):  # exp C≡N dist 1.16 A
+                        bonds += 3
+                    elif (
+                        1.39 < dist < 1.49 and i in oxygen_indexes
+                    ):  # exp N-O dist 1.44 A
+                        bonds += 1
+                    elif (
+                        1.15 < dist < 1.25 and i in oxygen_indexes
+                    ):  # exp C=O dist 1.20 A
+                        bonds += 2
+
+            if bonds == 4:
+                return nitrogen_index
+
+        # If no cation index is found, raise an error with a custom message
+        raise ValueError(
+            "Cation index could not be found. Please check the molecule you are docking."
+        )
+
+    def find_acid_sites(self, distance_matrices, atomtypes_indexes):
         """Identify the acid sites in the zeolite."""
         acid_oxygens = []
         acid_al_indexes = []
 
-        for index, site in enumerate(pose):
-            if site.species_string == "Al":
-                candidate_oxygens = [
-                    dist_index
-                    for dist_index, dist in enumerate(distance_matrices[index])
-                    if 0 < dist < 1.8 and dist_index in zeolite_oxygens
-                ]
-                if len(candidate_oxygens) == 4 and all(
-                    all(
-                        not (bond_dist < 1.15 and bond_dist != 0.0)
-                        for bond_dist in distance_matrices[ox_index]
-                    )
-                    for ox_index in candidate_oxygens
-                ):  # 1.15 accounts for O-H bond
-                    acid_oxygens.append(candidate_oxygens)
-                    acid_al_indexes.append(index)
+        host_oxygens = [
+            idx for idx, lbl in atomtypes_indexes.get("O", []) if lbl == "host"
+        ]
+        host_aluminum = [
+            idx for idx, lbl in atomtypes_indexes.get("Al", []) if lbl == "host"
+        ]
+        ## Room to add more metals if needed
+
+        for al_index in host_aluminum:
+            candidate_oxygens = [
+                dist_index
+                for dist_index, dist in enumerate(distance_matrices[al_index])
+                if 0 < dist < 1.8 and dist_index in host_oxygens
+            ]
+            if len(candidate_oxygens) == 4 and all(
+                all(
+                    not (bond_dist < 1.10 and bond_dist != 0.0)
+                    for bond_dist in distance_matrices[ox_index]
+                )
+                for ox_index in candidate_oxygens
+            ):  # 1.10 accounts for O-H bond
+                acid_oxygens.append(candidate_oxygens)
+                acid_al_indexes.append(al_index)
 
         return acid_oxygens, acid_al_indexes
 
-    def get_catan_distances(self, acid_oxygens, cation_index, distance_matrices):
+    def get_catan_distances(
+        self, acid_oxygens, cation_index, distance_matrices, complex
+    ):
         """Check the cation-anion distances for the different acid sites in the zeolite."""
         distances_catan = []
         for acid_al in acid_oxygens:
             distances_cation_anion = [
                 distance_matrices[cation_index][ox_index] for ox_index in acid_al
             ]
+
+            distances_catan.append(distances_cation_anion)
+
+        if (
+            any(dist < 2.0 for sublist in distances_catan for dist in sublist)
+            and self.normalize(self.get_distances(complex).min() - self.threshold) > 0
+        ):
+            print("Optimal distance found! Aborting the run")
             print(
-                "Distances between cation and acid oxygens are:", distances_cation_anion
+                "Distances between cation and acid oxygens are:",
+                distances_catan,
             )
-            distances_catan.append(distances_cation_anion)
-            # if any(dist < 2.0 for dist in distances_cation_anion):
-            # print("Optimal distance found! Aborting the run")
-            # return True
-            # return False
+            return True, distances_catan
 
-        return distances_catan
+        else:
+            return False, distances_catan
 
     def normalize(self, value):
         if self.step:
@@ -134,25 +266,28 @@ def __call__(self, complex):
 
 class MinDistanceCationAnionFitness(ThresholdFitness):
     PARSER_NAME = "min_catan_distance"
-    HELP = "Complexes have positive score if the minimum distance between host anion and guest cation is above the given threshold"
+    HELP = "Complexes have positive score if the minimum distance between host anion and guest cation is below the given threshold plus Complexes have positive score if the minimum distance between host and guest is above the given threshold"
 
     def __call__(self, complex):
-        # print(complex.pose)
 
         pose = complex.pose
         distance_matrices = complex.pose.distance_matrix
-        zeolite_oxygens = self.get_zeolite_oxygens(pose)
-        cation_index = self.find_cation_index(
-            pose,
-            distance_matrices,
-        )
+        atomtypes_indexes = self.get_atomtypes_indexes(pose)
+
+        cation_index = self.find_cation_index(distance_matrices, atomtypes_indexes)
         acid_sites, acid_al_indexes = self.find_acid_sites(
-            pose, distance_matrices, zeolite_oxygens
+            distance_matrices, atomtypes_indexes
         )
-        return self.normalize(
-            min(self.get_catan_distances(acid_sites, cation_index, distance_matrices))
-            - self.threshold
+
+        converged, distances = self.get_catan_distances(
+            acid_sites, cation_index, distance_matrices, complex
         )
+        if converged:
+            return 1
+
+        else:
+            return -np.inf
+
 
 class MeanDistanceFitness(ThresholdFitness):
     PARSER_NAME = "mean_distance"