From 3e687bca19e0e2a5d65089cc4aafddfea88297be Mon Sep 17 00:00:00 2001 From: Pau Ferri-Vicedo Date: Mon, 18 Nov 2024 17:43:56 -0500 Subject: [PATCH] added **kwargs parsing on the __init__ for MinDistanceCationAnionFitness class and improved the type on the argument parsing for --cation_indexes and --acid_sites variables so they work with both htvs and command line input formats --- VOID/fitness/threshold.py | 31 ++++++++++--------------------- 1 file changed, 10 insertions(+), 21 deletions(-) diff --git a/VOID/fitness/threshold.py b/VOID/fitness/threshold.py index a348d09..815acbd 100644 --- a/VOID/fitness/threshold.py +++ b/VOID/fitness/threshold.py @@ -1,10 +1,9 @@ import numpy as np import argparse from .base import Fitness -import ipdb THRESHOLD = 1.5 -THRESHOLD_CATAN = 2.0 +THRESHOLD_CATAN = 3.5 DEFAULT_STRUCTURE = "complex" STRUCTURE_CHOICES = ["complex", "guest", "host"] DEFAULT_STEP = False @@ -31,11 +30,10 @@ def __init__( super().__init__() self.threshold = threshold self.step = step + self.extra_args = kwargs if structure not in STRUCTURE_CHOICES: - raise ValueError( - "structure has to be one of: {}".format(", ".join(STRUCTURE_CHOICES)) - ) + raise ValueError("structure has to be one of: {}".format(", ".join(STRUCTURE_CHOICES))) self.structure = structure @staticmethod @@ -84,11 +82,7 @@ def get_cation_anion_distances(self, acid_sites, cation_indexes, distance_matric distances_cation_anion = [] for cation in cation_indexes: - distances = [ - distance_matrices[cation][anion_index] - for anion_list in acid_sites - for anion_index in anion_list - ] + distances = [distance_matrices[cation][anion_index] for anion_list in acid_sites for anion_index in anion_list] distances_cation_anion.append(distances) return distances_cation_anion @@ -117,8 +111,9 @@ def __init__( structure=DEFAULT_STRUCTURE, cation_indexes=None, acid_sites=None, + **kwargs, ): - super().__init__(threshold) + super().__init__(threshold, structure, **kwargs) self.threshold_catan = threshold_catan self.cation_indexes = cation_indexes self.acid_sites = acid_sites @@ -135,13 +130,13 @@ def add_arguments(parser): ) parser.add_argument( "--cation_indexes", - type=list, + type=lambda x: [int(i) for i in x.split(",")], help="indexes for the atoms holding a positive charge in the molecule (default: %(default)s)", default=CATION_INDEXES, ) parser.add_argument( "--acid_sites", - type=list, + type=lambda x: [list(map(int, group.split(","))) for group in x.split(";")], help="list of indexes for the O atoms that hold a negative charge (default: %(default)s)", default=ACID_SITES, ) @@ -163,11 +158,7 @@ def __call__(self, complex): ) if ( - any( - distance < self.threshold_catan - for distance_list in cation_anion_distances - for distance in distance_list - ) + any(distance < self.threshold_catan for distance_list in cation_anion_distances for distance in distance_list) and self.normalize(self.get_distances(complex).min() - self.threshold) > 0 ): print("Optimal cation-anion distance found! Aborting the run") @@ -182,9 +173,7 @@ class MeanDistanceFitness(ThresholdFitness): HELP = "Complexes have positive score if the mean distance between host and guest is above the given threshold" def __call__(self, complex, axis=-1): - return self.normalize( - self.get_distances(complex).min(axis=axis).mean() - self.threshold - ) + return self.normalize(self.get_distances(complex).min(axis=axis).mean() - self.threshold) class SumInvDistanceFitness(ThresholdFitness):