diff --git a/fmralign/pairwise_alignment.py b/fmralign/pairwise_alignment.py index b56da57..d05c803 100644 --- a/fmralign/pairwise_alignment.py +++ b/fmralign/pairwise_alignment.py @@ -6,6 +6,7 @@ import numpy as np from joblib import Memory, Parallel, delayed from sklearn.base import BaseEstimator, TransformerMixin, clone +from sklearn.utils.validation import check_is_fitted from fmralign import alignment_methods from fmralign._utils import _transform_one_img @@ -85,7 +86,7 @@ class PairwiseAlignment(BaseEstimator, TransformerMixin): def __init__( self, - alignment_method, + alignment_method="identity", n_pieces=1, clustering="kmeans", mask=None, @@ -199,7 +200,7 @@ def fit(self, X, Y): ------- self """ - self.pmasker = ParcellationMasker( + self.parcel_masker = ParcellationMasker( n_pieces=self.n_pieces, clustering=self.clustering, mask=self.mask, @@ -217,11 +218,13 @@ def fit(self, X, Y): verbose=self.verbose, ) - parceled_source, parceled_target = self.pmasker.fit_transform([X, Y]) - self.masker = self.pmasker.masker_ - self.mask = self.pmasker.masker_.mask_img_ - self.labels_ = self.pmasker.labels - self.n_pieces = self.pmasker.n_pieces + parceled_source, parceled_target = self.parcel_masker.fit_transform( + [X, Y] + ) + self.masker = self.parcel_masker.masker_ + self.mask = self.parcel_masker.masker_.mask_img_ + self.labels_ = self.parcel_masker.labels + self.n_pieces = self.parcel_masker.n_pieces self.fit_ = Parallel( self.n_jobs, prefer="threads", verbose=self.verbose @@ -252,7 +255,7 @@ def transform(self, X): "This instance has not been fitted yet. " "Please call 'fit' before 'transform'." ) - parceled_data_list = self.pmasker.transform(X) + parceled_data_list = self.parcel_masker.transform(X) transformed_img = Parallel( self.n_jobs, prefer="threads", verbose=self.verbose )( @@ -274,3 +277,26 @@ def fit_transform(self): raise AttributeError( "type object 'PairwiseAlignment' has no attribute 'fit_transform'" ) + + def get_parcellation(self): + """Get the parcellation masker used for alignment. + + Returns + ------- + labels: `list` of `int` + Labels of the parcellation masker. + parcellation_img: Niimg-like object + Parcellation image. + """ + if hasattr(self, "parcel_masker"): + check_is_fitted(self) + labels = self.parcel_masker.get_labels() + parcellation_img = self.parcel_masker.get_parcellation() + return labels, parcellation_img + else: + raise AttributeError( + ( + "Parcellation has not been computed yet," + "please fit the alignment estimator first." + ) + ) diff --git a/fmralign/preprocessing.py b/fmralign/preprocessing.py index 29b6ce2..63521e8 100644 --- a/fmralign/preprocessing.py +++ b/fmralign/preprocessing.py @@ -186,6 +186,16 @@ def get_labels(self): ) return self.labels + def get_parcellation(self): + """Return the parcellation image. + + Returns + ------- + parcellation : `nibabel.Nifti1Image` + Parcellation image. + """ + return self.masker_.inverse_transform(self.get_labels()) + def fit(self, imgs, y=None): """Fit the masker and compute the parcellation. diff --git a/fmralign/srm.py b/fmralign/srm.py index 7b7c74a..c5fc497 100644 --- a/fmralign/srm.py +++ b/fmralign/srm.py @@ -200,7 +200,7 @@ def fit(self, imgs): Length : n_samples """ - self.pmasker = ParcellationMasker( + self.parcel_masker = ParcellationMasker( n_pieces=self.n_pieces, clustering=self.clustering, mask=self.mask, @@ -217,11 +217,11 @@ def fit(self, imgs): n_jobs=self.n_jobs, verbose=self.verbose, ) - parceled_data = self.pmasker.fit_transform(imgs) - self.masker_ = self.pmasker.masker_ - self.mask = self.pmasker.masker_.mask_img_ - self.labels_ = self.pmasker.labels - self.n_pieces = self.pmasker.n_pieces + parceled_data = self.parcel_masker.fit_transform(imgs) + self.masker_ = self.parcel_masker.masker_ + self.mask = self.parcel_masker.masker_.mask_img_ + self.labels_ = self.parcel_masker.labels + self.n_pieces = self.parcel_masker.n_pieces outputs = Parallel( n_jobs=self.n_jobs, prefer="threads", verbose=self.verbose @@ -233,7 +233,7 @@ def fit(self, imgs): for i in range(self.n_pieces) ) - self.labels_ = self.pmasker.labels + self.labels_ = self.parcel_masker.labels self.fit_ = [output[0] for output in outputs] self.reduced_sr = [output[1] for output in outputs] return self @@ -242,7 +242,9 @@ def add_subjects(self, imgs): """Add subject without recalculating SR""" for i in range(self.n_pieces): self.fit_[i] - X_i = _get_parcel_across_subjects(self.pmasker.transform(imgs), i) + X_i = _get_parcel_across_subjects( + self.parcel_masker.transform(imgs), i + ) srm = self.fit_[i] srm.add_subjects(X_i, self.reduced_sr[i]) return self @@ -262,7 +264,7 @@ def transform(self, imgs): n_comps = self.srm.n_components aligned_imgs = [] - imgs_prep = self.pmasker.transform(imgs) + imgs_prep = self.parcel_masker.transform(imgs) bag_align = [] for i, piece_srm in enumerate(self.fit_): X_i = [parceled_data[i].T for parceled_data in imgs_prep] diff --git a/fmralign/tests/test_pairwise_alignment.py b/fmralign/tests/test_pairwise_alignment.py index 02e494c..cd4a375 100644 --- a/fmralign/tests/test_pairwise_alignment.py +++ b/fmralign/tests/test_pairwise_alignment.py @@ -3,6 +3,7 @@ import numpy as np import pytest +from nibabel import Nifti1Image from nilearn.image import new_img_like from nilearn.maskers import NiftiMasker @@ -103,3 +104,29 @@ def test_models_against_identity(): ground_truth, masker.transform(im_test) ) assert algo_score >= identity_baseline_score + + +def test_parcellation_retrieval(): + """Test that PairwiseAlignment returns both the\n + labels and the parcellation image""" + n_pieces = 3 + img1, _ = random_niimg((8, 7, 6)) + img2, _ = random_niimg((8, 7, 6)) + alignment = PairwiseAlignment(n_pieces=n_pieces) + alignment.fit(img1, img2) + + labels, parcellation_image = alignment.get_parcellation() + assert isinstance(labels, np.ndarray) + assert len(np.unique(labels)) == n_pieces + assert isinstance(parcellation_image, Nifti1Image) + assert parcellation_image.shape == img1.shape + + +def test_parcellation_before_fit(): + """Test that PairwiseAlignment raises an error if\n + the parcellation is retrieved before fitting""" + alignment = PairwiseAlignment() + with pytest.raises( + AttributeError, match="Parcellation has not been computed yet" + ): + alignment.get_parcellation() diff --git a/fmralign/tests/test_preprocessing.py b/fmralign/tests/test_preprocessing.py index 692d169..1c39e46 100644 --- a/fmralign/tests/test_preprocessing.py +++ b/fmralign/tests/test_preprocessing.py @@ -10,56 +10,56 @@ def test_init_default_params(): """Test that ParcellationMasker initializes with default parameters""" - pmasker = ParcellationMasker() - assert pmasker.n_pieces == 1 - assert pmasker.clustering == "kmeans" - assert pmasker.mask is None - assert pmasker.smoothing_fwhm is None - assert pmasker.standardize is False - assert pmasker.detrend is False - assert pmasker.labels is None + parcel_masker = ParcellationMasker() + assert parcel_masker.n_pieces == 1 + assert parcel_masker.clustering == "kmeans" + assert parcel_masker.mask is None + assert parcel_masker.smoothing_fwhm is None + assert parcel_masker.standardize is False + assert parcel_masker.detrend is False + assert parcel_masker.labels is None def test_init_custom_params(): """Test that ParcellationMasker initializes with custom parameters""" - pmasker = ParcellationMasker( + parcel_masker = ParcellationMasker( n_pieces=2, clustering="ward", standardize=True, detrend=True, n_jobs=2 ) - assert pmasker.n_pieces == 2 - assert pmasker.clustering == "ward" - assert pmasker.standardize is True - assert pmasker.detrend is True - assert pmasker.n_jobs == 2 + assert parcel_masker.n_pieces == 2 + assert parcel_masker.clustering == "ward" + assert parcel_masker.standardize is True + assert parcel_masker.detrend is True + assert parcel_masker.n_jobs == 2 def test_fit_single_image(): """Test that ParcellationMasker fits a single image""" img, _ = random_niimg((8, 7, 6)) - pmasker = ParcellationMasker(n_pieces=2) - fitted_pmasker = pmasker.fit(img) + parcel_masker = ParcellationMasker(n_pieces=2) + fitted_parcel_masker = parcel_masker.fit(img) - assert hasattr(fitted_pmasker, "masker_") - assert fitted_pmasker.labels is not None - assert isinstance(fitted_pmasker.labels, np.ndarray) - assert len(np.unique(fitted_pmasker.labels)) == 2 # n_pieces=2 + assert hasattr(fitted_parcel_masker, "masker_") + assert fitted_parcel_masker.labels is not None + assert isinstance(fitted_parcel_masker.labels, np.ndarray) + assert len(np.unique(fitted_parcel_masker.labels)) == 2 # n_pieces=2 def test_fit_multiple_images(): """Test that ParcellationMasker fits multiple images""" imgs = [random_niimg((8, 7, 6))[0]] * 3 - pmasker = ParcellationMasker(n_pieces=2) - fitted_pmasker = pmasker.fit(imgs) + parcel_masker = ParcellationMasker(n_pieces=2) + fitted_parcel_masker = parcel_masker.fit(imgs) - assert hasattr(fitted_pmasker, "masker_") - assert fitted_pmasker.labels is not None + assert hasattr(fitted_parcel_masker, "masker_") + assert fitted_parcel_masker.labels is not None def test_transform_single_image(): """Test that ParcellationMasker transforms a single image""" img, _ = random_niimg((8, 7, 6)) - pmasker = ParcellationMasker(n_pieces=2) - pmasker.fit(img) - transformed_data = pmasker.transform(img) + parcel_masker = ParcellationMasker(n_pieces=2) + parcel_masker.fit(img) + transformed_data = parcel_masker.transform(img) assert isinstance(transformed_data, list) assert len(transformed_data) == 1 @@ -69,9 +69,9 @@ def test_transform_single_image(): def test_transform_multiple_images(): """Test that ParcellationMasker transforms multiple images""" imgs = [random_niimg((8, 7, 6))[0]] * 3 - pmasker = ParcellationMasker(n_pieces=2) - pmasker.fit(imgs) - transformed_data = pmasker.transform(imgs) + parcel_masker = ParcellationMasker(n_pieces=2) + parcel_masker.fit(imgs) + transformed_data = parcel_masker.transform(imgs) assert isinstance(transformed_data, list) assert len(transformed_data) == 3 @@ -83,16 +83,16 @@ def test_transform_multiple_images(): def test_get_labels_before_fit(): """Test that ParcellationMasker raises ValueError if get_labels is called before fit""" - pmasker = ParcellationMasker() + parcel_masker = ParcellationMasker() with pytest.raises(ValueError, match="Labels have not been computed yet"): - pmasker.get_labels() + parcel_masker.get_labels() def test_get_labels_after_fit(): img, _ = random_niimg((8, 7, 6)) - pmasker = ParcellationMasker(n_pieces=2) - pmasker.fit(img) - labels = pmasker.get_labels() + parcel_masker = ParcellationMasker(n_pieces=2) + parcel_masker.fit(img) + labels = parcel_masker.get_labels() assert labels is not None assert isinstance(labels, np.ndarray) @@ -108,13 +108,13 @@ def test_different_shaped_images(): different_img = Nifti1Image(different_data, np.eye(4)) imgs = [img, different_img] - pmasker = ParcellationMasker() + parcel_masker = ParcellationMasker() with pytest.raises( NotImplementedError, match="fmralign does not support images of different shapes", ): - pmasker.fit(imgs) + parcel_masker.fit(imgs) def test_clustering_with_mask(): @@ -124,11 +124,13 @@ def test_clustering_with_mask(): clustering_data[5:, :, :] = 0 clustering_img = Nifti1Image(clustering_data, np.eye(4)) img, dummy_mask = random_niimg((8, 7, 6)) - pmasker = ParcellationMasker(clustering=clustering_img, mask=dummy_mask) + parcel_masker = ParcellationMasker( + clustering=clustering_img, mask=dummy_mask + ) with pytest.warns( UserWarning, match="Mask used was bigger than clustering provided" ): - pmasker.fit(img) + parcel_masker.fit(img) def test_memory_caching(tmp_path): @@ -136,8 +138,8 @@ def test_memory_caching(tmp_path): img, _ = random_niimg((8, 7, 6)) # Test that memory caching works memory = Memory(location=str(tmp_path), verbose=0) - pmasker = ParcellationMasker(memory=memory, memory_level=1) - pmasker.fit(img) + parcel_masker = ParcellationMasker(memory=memory, memory_level=1) + parcel_masker.fit(img) # Check that cache directory is not empty cache_files = list(tmp_path.glob("joblib/*")) @@ -148,9 +150,9 @@ def test_memory_caching(tmp_path): def test_parallel_processing(n_jobs): """Test parallel processing with joblib""" imgs = [random_niimg((8, 7, 6))[0]] * 3 - pmasker = ParcellationMasker(n_pieces=2, n_jobs=n_jobs) - pmasker.fit(imgs) - transformed_data = pmasker.transform(imgs) + parcel_masker = ParcellationMasker(n_pieces=2, n_jobs=n_jobs) + parcel_masker.fit(imgs) + transformed_data = parcel_masker.transform(imgs) assert len(transformed_data) == 3 @@ -158,9 +160,9 @@ def test_parallel_processing(n_jobs): def test_smoothing_parameter(): """Test that ParcellationMasker applies smoothing""" img, _ = random_niimg((8, 7, 6)) - pmasker = ParcellationMasker(smoothing_fwhm=4.0) - pmasker.fit(img) - transformed_data = pmasker.transform(img) + parcel_masker = ParcellationMasker(smoothing_fwhm=4.0) + parcel_masker.fit(img) + transformed_data = parcel_masker.transform(img) assert isinstance(transformed_data, list) assert len(transformed_data) == 1 @@ -169,11 +171,30 @@ def test_smoothing_parameter(): def test_standardization(): """Test that ParcellationMasker standardizes data""" img, _ = random_niimg((8, 7, 6, 20)) - pmasker = ParcellationMasker(standardize=True) - pmasker.fit(img) - transformed_data = pmasker.transform(img) + parcel_masker = ParcellationMasker(standardize=True) + parcel_masker.fit(img) + transformed_data = parcel_masker.transform(img) # Check if data is standardized (mean ≈ 0, std ≈ 1) data_array = transformed_data[0].data assert np.abs(np.mean(data_array)) < 1e-5 assert np.abs(np.std(data_array) - 1.0) < 1e-5 + + +def test_get_parcellation(): + """Test that ParcellationMasker returns the parcellation mask""" + n_pieces = 2 + img, _ = random_niimg((8, 7, 6)) + parcel_masker = ParcellationMasker(n_pieces=n_pieces) + parcel_masker.fit(img) + parcellation_img = parcel_masker.get_parcellation() + labels = parcel_masker.get_labels() + + assert isinstance(parcellation_img, Nifti1Image) + assert parcellation_img.shape == img.shape + + masker = parcel_masker.masker_ + data = masker.transform(parcellation_img) + + assert np.allclose(data, labels) + assert len(np.unique(data)) == n_pieces