Skip to content

Commit

Permalink
Merge pull request #113 from Parietal-INRIA/feat/retrieve-parcels
Browse files Browse the repository at this point in the history
[ENH] Retrieve labels and parcellation image from PairwiseAlignment
  • Loading branch information
pbarbarant authored Dec 5, 2024
2 parents 36d88d0 + fd0c2f3 commit 077501e
Show file tree
Hide file tree
Showing 5 changed files with 153 additions and 67 deletions.
42 changes: 34 additions & 8 deletions fmralign/pairwise_alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import numpy as np
from joblib import Memory, Parallel, delayed
from sklearn.base import BaseEstimator, TransformerMixin, clone
from sklearn.utils.validation import check_is_fitted

from fmralign import alignment_methods
from fmralign._utils import _transform_one_img
Expand Down Expand Up @@ -85,7 +86,7 @@ class PairwiseAlignment(BaseEstimator, TransformerMixin):

def __init__(
self,
alignment_method,
alignment_method="identity",
n_pieces=1,
clustering="kmeans",
mask=None,
Expand Down Expand Up @@ -199,7 +200,7 @@ def fit(self, X, Y):
-------
self
"""
self.pmasker = ParcellationMasker(
self.parcel_masker = ParcellationMasker(
n_pieces=self.n_pieces,
clustering=self.clustering,
mask=self.mask,
Expand All @@ -217,11 +218,13 @@ def fit(self, X, Y):
verbose=self.verbose,
)

parceled_source, parceled_target = self.pmasker.fit_transform([X, Y])
self.masker = self.pmasker.masker_
self.mask = self.pmasker.masker_.mask_img_
self.labels_ = self.pmasker.labels
self.n_pieces = self.pmasker.n_pieces
parceled_source, parceled_target = self.parcel_masker.fit_transform(
[X, Y]
)
self.masker = self.parcel_masker.masker_
self.mask = self.parcel_masker.masker_.mask_img_
self.labels_ = self.parcel_masker.labels
self.n_pieces = self.parcel_masker.n_pieces

self.fit_ = Parallel(
self.n_jobs, prefer="threads", verbose=self.verbose
Expand Down Expand Up @@ -252,7 +255,7 @@ def transform(self, X):
"This instance has not been fitted yet. "
"Please call 'fit' before 'transform'."
)
parceled_data_list = self.pmasker.transform(X)
parceled_data_list = self.parcel_masker.transform(X)
transformed_img = Parallel(
self.n_jobs, prefer="threads", verbose=self.verbose
)(
Expand All @@ -274,3 +277,26 @@ def fit_transform(self):
raise AttributeError(
"type object 'PairwiseAlignment' has no attribute 'fit_transform'"
)

def get_parcellation(self):
"""Get the parcellation masker used for alignment.
Returns
-------
labels: `list` of `int`
Labels of the parcellation masker.
parcellation_img: Niimg-like object
Parcellation image.
"""
if hasattr(self, "parcel_masker"):
check_is_fitted(self)
labels = self.parcel_masker.get_labels()
parcellation_img = self.parcel_masker.get_parcellation()
return labels, parcellation_img
else:
raise AttributeError(
(
"Parcellation has not been computed yet,"
"please fit the alignment estimator first."
)
)
10 changes: 10 additions & 0 deletions fmralign/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,16 @@ def get_labels(self):
)
return self.labels

def get_parcellation(self):
"""Return the parcellation image.
Returns
-------
parcellation : `nibabel.Nifti1Image`
Parcellation image.
"""
return self.masker_.inverse_transform(self.get_labels())

def fit(self, imgs, y=None):
"""Fit the masker and compute the parcellation.
Expand Down
20 changes: 11 additions & 9 deletions fmralign/srm.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ def fit(self, imgs):
Length : n_samples
"""
self.pmasker = ParcellationMasker(
self.parcel_masker = ParcellationMasker(
n_pieces=self.n_pieces,
clustering=self.clustering,
mask=self.mask,
Expand All @@ -217,11 +217,11 @@ def fit(self, imgs):
n_jobs=self.n_jobs,
verbose=self.verbose,
)
parceled_data = self.pmasker.fit_transform(imgs)
self.masker_ = self.pmasker.masker_
self.mask = self.pmasker.masker_.mask_img_
self.labels_ = self.pmasker.labels
self.n_pieces = self.pmasker.n_pieces
parceled_data = self.parcel_masker.fit_transform(imgs)
self.masker_ = self.parcel_masker.masker_
self.mask = self.parcel_masker.masker_.mask_img_
self.labels_ = self.parcel_masker.labels
self.n_pieces = self.parcel_masker.n_pieces

outputs = Parallel(
n_jobs=self.n_jobs, prefer="threads", verbose=self.verbose
Expand All @@ -233,7 +233,7 @@ def fit(self, imgs):
for i in range(self.n_pieces)
)

self.labels_ = self.pmasker.labels
self.labels_ = self.parcel_masker.labels
self.fit_ = [output[0] for output in outputs]
self.reduced_sr = [output[1] for output in outputs]
return self
Expand All @@ -242,7 +242,9 @@ def add_subjects(self, imgs):
"""Add subject without recalculating SR"""
for i in range(self.n_pieces):
self.fit_[i]
X_i = _get_parcel_across_subjects(self.pmasker.transform(imgs), i)
X_i = _get_parcel_across_subjects(
self.parcel_masker.transform(imgs), i
)
srm = self.fit_[i]
srm.add_subjects(X_i, self.reduced_sr[i])
return self
Expand All @@ -262,7 +264,7 @@ def transform(self, imgs):

n_comps = self.srm.n_components
aligned_imgs = []
imgs_prep = self.pmasker.transform(imgs)
imgs_prep = self.parcel_masker.transform(imgs)
bag_align = []
for i, piece_srm in enumerate(self.fit_):
X_i = [parceled_data[i].T for parceled_data in imgs_prep]
Expand Down
27 changes: 27 additions & 0 deletions fmralign/tests/test_pairwise_alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import numpy as np
import pytest
from nibabel import Nifti1Image
from nilearn.image import new_img_like
from nilearn.maskers import NiftiMasker

Expand Down Expand Up @@ -103,3 +104,29 @@ def test_models_against_identity():
ground_truth, masker.transform(im_test)
)
assert algo_score >= identity_baseline_score


def test_parcellation_retrieval():
"""Test that PairwiseAlignment returns both the\n
labels and the parcellation image"""
n_pieces = 3
img1, _ = random_niimg((8, 7, 6))
img2, _ = random_niimg((8, 7, 6))
alignment = PairwiseAlignment(n_pieces=n_pieces)
alignment.fit(img1, img2)

labels, parcellation_image = alignment.get_parcellation()
assert isinstance(labels, np.ndarray)
assert len(np.unique(labels)) == n_pieces
assert isinstance(parcellation_image, Nifti1Image)
assert parcellation_image.shape == img1.shape


def test_parcellation_before_fit():
"""Test that PairwiseAlignment raises an error if\n
the parcellation is retrieved before fitting"""
alignment = PairwiseAlignment()
with pytest.raises(
AttributeError, match="Parcellation has not been computed yet"
):
alignment.get_parcellation()
Loading

0 comments on commit 077501e

Please sign in to comment.