diff --git a/modisco/metaclusterers.py b/modisco/metaclusterers.py index 59ddafd..4ecd7c3 100644 --- a/modisco/metaclusterers.py +++ b/modisco/metaclusterers.py @@ -105,7 +105,8 @@ def fit(self, seqlets): self.get_vector_from_seqlet(x) for x in seqlets])) self._fit(attribute_vectors) - self.fit_called = True + self.fit_called = True + return self def _fit(self, attribute_vectors): raise NotImplementedError() @@ -159,6 +160,11 @@ def weak_vector_to_pattern(self, vector): assert False return to_return + def get_all_possible_compatible_patterns(self, pattern): + all_possible_patterns = list( + itertools.product(*[(x,0) for x in pattern])) + return all_possible_patterns + def check_pattern_compatibility(self, pattern_to_check, reference_pattern): return all([(pattern_elem==reference_elem or reference_elem==0) for pattern_elem, reference_elem @@ -287,17 +293,15 @@ def save_hdf5(self, grp): def _fit(self, attribute_vectors): - all_possible_activity_patterns =\ - list(itertools.product(*[(1,-1,0) for x - in range(attribute_vectors.shape[1])])) + all_possible_activity_patterns = set() activity_pattern_to_attribute_vectors = defaultdict(list) for vector in attribute_vectors: vector_activity_pattern = self.vector_to_pattern(vector) compatible_activity_patterns =\ - self.get_compatible_patterns( - vector_activity_pattern, all_possible_activity_patterns) + self.get_all_possible_compatible_patterns(vector_activity_pattern) for compatible_activity_pattern in compatible_activity_patterns: + all_possible_activity_patterns.add(compatible_activity_pattern) activity_pattern_to_attribute_vectors[ self.pattern_to_str( compatible_activity_pattern)].append(vector) diff --git a/modisco/tfmodisco_workflow/workflow.py b/modisco/tfmodisco_workflow/workflow.py index a42e3b8..02d906a 100644 --- a/modisco/tfmodisco_workflow/workflow.py +++ b/modisco/tfmodisco_workflow/workflow.py @@ -157,7 +157,8 @@ def __init__(self, max_passing_windows_frac=0.2, separate_pos_neg_thresholds=False, verbose=True, - min_seqlets_per_task=None): + min_seqlets_per_task=None, + max_seqlets_during_metacluster_fit=np.inf): if (min_seqlets_per_task is not None): raise DeprecationWarning( @@ -180,6 +181,7 @@ def __init__(self, self.max_passing_windows_frac = max_passing_windows_frac self.separate_pos_neg_thresholds = separate_pos_neg_thresholds self.verbose = verbose + self.max_seqlets_during_metacluster_fit = max_seqlets_during_metacluster_fit self.build() @@ -248,11 +250,12 @@ def __call__(self, task_names, contrib_scores, +" Consider dropping target_seqlet_fdr") - if int(self.min_metacluster_size_frac * len(seqlets)) > self.min_metacluster_size: - print("min_metacluster_size_frac * len(seqlets) = {0} is more than min_metacluster_size={1}.".\ - format(int(self.min_metacluster_size_frac * len(seqlets)), self.min_metacluster_size)) + if int(self.min_metacluster_size_frac + * min(len(seqlets),self.max_seqlets_during_metacluster_fit)) > self.min_metacluster_size: + print("min_metacluster_size_frac * min(len(seqlets),self.max_seqlets_during_metacluster_fit) = {0} is more than min_metacluster_size={1}.".\ + format(int(self.min_metacluster_size_frac * min(len(seqlets),self.max_seqlets_during_metacluster_fit)), self.min_metacluster_size)) print("Using it as a new min_metacluster_size") - self.min_metacluster_size = int(self.min_metacluster_size_frac * len(seqlets)) + self.min_metacluster_size = int(self.min_metacluster_size_frac * min(len(seqlets),self.max_seqlets_during_metacluster_fit)) if (self.weak_threshold_for_counting_sign is None): @@ -288,7 +291,13 @@ def __call__(self, task_names, contrib_scores, weak_threshold_for_counting_sign= weak_threshold_for_counting_sign) - metaclustering_results = metaclusterer.fit_transform(seqlets) + if (len(seqlets) > self.max_seqlets_during_metacluster_fit): + indices = np.random.RandomState(1234).choice( + a=len(seqlets), size=self.max_seqlets_during_metacluster_fit, replace=False) + seqlets_to_metacluster = [seqlets[x] for x in indices] + else: + seqlets_to_metacluster = seqlets + metaclustering_results = metaclusterer.fit(seqlets_to_metacluster).transform(seqlets) metacluster_indices = np.array( metaclustering_results.metacluster_indices) metacluster_idx_to_activity_pattern =\ diff --git a/setup.py b/setup.py index 9b69175..a035367 100644 --- a/setup.py +++ b/setup.py @@ -6,7 +6,7 @@ description='TF MOtif Discovery from Importance SCOres', long_description="""Algorithm for discovering consolidated patterns from base-pair-level importance scores""", url='https://github.com/kundajelab/tfmodisco', - version='0.5.1.2', + version='0.5.1.3', packages=find_packages(), package_data={ '': ['cluster/phenograph/louvain/*convert*', 'cluster/phenograph/louvain/*community*', 'cluster/phenograph/louvain/*hierarchy*']