Skip to content

Commit

Permalink
added **kwargs parsing on the __init__ for MinDistanceCationAnionFitn…
Browse files Browse the repository at this point in the history
…ess class and improved the type on the argument parsing for --cation_indexes and --acid_sites variables so they work with both htvs and command line input formats
  • Loading branch information
Pau Ferri-Vicedo committed Nov 18, 2024
1 parent 92b69a6 commit 3e687bc
Showing 1 changed file with 10 additions and 21 deletions.
31 changes: 10 additions & 21 deletions VOID/fitness/threshold.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import numpy as np
import argparse
from .base import Fitness
import ipdb

THRESHOLD = 1.5
THRESHOLD_CATAN = 2.0
THRESHOLD_CATAN = 3.5
DEFAULT_STRUCTURE = "complex"
STRUCTURE_CHOICES = ["complex", "guest", "host"]
DEFAULT_STEP = False
Expand All @@ -31,11 +30,10 @@ def __init__(
super().__init__()
self.threshold = threshold
self.step = step
self.extra_args = kwargs

if structure not in STRUCTURE_CHOICES:
raise ValueError(
"structure has to be one of: {}".format(", ".join(STRUCTURE_CHOICES))
)
raise ValueError("structure has to be one of: {}".format(", ".join(STRUCTURE_CHOICES)))
self.structure = structure

@staticmethod
Expand Down Expand Up @@ -84,11 +82,7 @@ def get_cation_anion_distances(self, acid_sites, cation_indexes, distance_matric

distances_cation_anion = []
for cation in cation_indexes:
distances = [
distance_matrices[cation][anion_index]
for anion_list in acid_sites
for anion_index in anion_list
]
distances = [distance_matrices[cation][anion_index] for anion_list in acid_sites for anion_index in anion_list]
distances_cation_anion.append(distances)
return distances_cation_anion

Expand Down Expand Up @@ -117,8 +111,9 @@ def __init__(
structure=DEFAULT_STRUCTURE,
cation_indexes=None,
acid_sites=None,
**kwargs,
):
super().__init__(threshold)
super().__init__(threshold, structure, **kwargs)
self.threshold_catan = threshold_catan
self.cation_indexes = cation_indexes
self.acid_sites = acid_sites
Expand All @@ -135,13 +130,13 @@ def add_arguments(parser):
)
parser.add_argument(
"--cation_indexes",
type=list,
type=lambda x: [int(i) for i in x.split(",")],
help="indexes for the atoms holding a positive charge in the molecule (default: %(default)s)",
default=CATION_INDEXES,
)
parser.add_argument(
"--acid_sites",
type=list,
type=lambda x: [list(map(int, group.split(","))) for group in x.split(";")],
help="list of indexes for the O atoms that hold a negative charge (default: %(default)s)",
default=ACID_SITES,
)
Expand All @@ -163,11 +158,7 @@ def __call__(self, complex):
)

if (
any(
distance < self.threshold_catan
for distance_list in cation_anion_distances
for distance in distance_list
)
any(distance < self.threshold_catan for distance_list in cation_anion_distances for distance in distance_list)
and self.normalize(self.get_distances(complex).min() - self.threshold) > 0
):
print("Optimal cation-anion distance found! Aborting the run")
Expand All @@ -182,9 +173,7 @@ class MeanDistanceFitness(ThresholdFitness):
HELP = "Complexes have positive score if the mean distance between host and guest is above the given threshold"

def __call__(self, complex, axis=-1):
return self.normalize(
self.get_distances(complex).min(axis=axis).mean() - self.threshold
)
return self.normalize(self.get_distances(complex).min(axis=axis).mean() - self.threshold)


class SumInvDistanceFitness(ThresholdFitness):
Expand Down

0 comments on commit 3e687bc

Please sign in to comment.