From 1749a3aa8c3726d1394a67d783faa136ea7220a2 Mon Sep 17 00:00:00 2001 From: Pierre-Louis Barbarant Date: Sat, 7 Dec 2024 02:00:59 +0100 Subject: [PATCH 01/23] Add index_by_parcel --- fmralign/template_alignment.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/fmralign/template_alignment.py b/fmralign/template_alignment.py index f07779e..3a6c758 100644 --- a/fmralign/template_alignment.py +++ b/fmralign/template_alignment.py @@ -225,6 +225,29 @@ def _predict_from_template_and_mapping(template, test_index, mapping): return transformed_image +def index_by_parcel(subjects_data): + """ + Index data by parcel. + + Parameters + ---------- + subjects_data: list of list of numpy.ndarray + Each element of the list is the list of parcels + data for one subject. + + Returns + ------- + list of list of numpy.ndarray + Each element of the list is the list of subjects + data for one parcel. + """ + n_pieces = subjects_data[0].n_pieces + return [ + [subject_data[i] for subject_data in subjects_data] + for i in range(n_pieces) + ] + + class TemplateAlignment(BaseEstimator, TransformerMixin): """ Decompose the source images into regions and summarize subjects information From 00751ab2500c4c4535460971e2a2d92628acc8b5 Mon Sep 17 00:00:00 2001 From: Pierre-Louis Barbarant Date: Sat, 7 Dec 2024 02:01:29 +0100 Subject: [PATCH 02/23] Modify fit method --- fmralign/template_alignment.py | 66 ++++++++++++++++++---------------- 1 file changed, 36 insertions(+), 30 deletions(-) diff --git a/fmralign/template_alignment.py b/fmralign/template_alignment.py index 3a6c758..c087230 100644 --- a/fmralign/template_alignment.py +++ b/fmralign/template_alignment.py @@ -392,39 +392,45 @@ def fit(self, imgs): """ - # Check if the input is a list, if list of lists, concatenate each subjects - # data into one unique image. - if not isinstance(imgs, (list, np.ndarray)) or len(imgs) < 2: - raise ValueError( - "The method TemplateAlignment.fit() need a list input. " - "Each element of the list (Niimg-like or list of Niimgs) " - "is the data for one subject." 
- ) - else: - if isinstance(imgs[0], (list, np.ndarray)): - imgs = [concat_imgs(img) for img in imgs] + self.parcel_masker = ParcellationMasker( + n_pieces=self.n_pieces, + clustering=self.clustering, + mask=self.mask, + smoothing_fwhm=self.smoothing_fwhm, + standardize=self.standardize, + detrend=self.detrend, + low_pass=self.low_pass, + high_pass=self.high_pass, + t_r=self.t_r, + target_affine=self.target_affine, + target_shape=self.target_shape, + memory=self.memory, + memory_level=self.memory_level, + n_jobs=self.n_jobs, + verbose=self.verbose, + ) - self.masker_ = check_embedded_masker(self) - self.masker_.n_jobs = self.n_jobs # self.n_jobs + subjects_data = self.parcel_masker.fit_transform(imgs) + parcels_data = index_by_parcel(subjects_data) + self.masker = self.parcel_masker.masker_ + self.mask = self.parcel_masker.masker_.mask_img_ + self.labels_ = self.parcel_masker.labels + self.n_pieces = self.parcel_masker.n_pieces - # if masker_ has been provided a mask_img - if self.masker_.mask_img is None: - self.masker_.fit(imgs) - else: - self.masker_.fit() + self.fit_ = Parallel( + self.n_jobs, prefer="threads", verbose=self.verbose + )( + delayed(_fit_local_template)( + parcel_i, + self.n_iter, + self.scale_template, + self.alignment_method, + ) + for parcel_i in parcels_data + ) - self.template, self.template_history = _create_template( - imgs, - self.n_iter, - self.scale_template, - self.alignment_method, - self.n_pieces, - self.clustering, - self.masker_, - self.memory, - self.memory_level, - self.n_jobs, - self.verbose, + self.template, self.template_history = _reconstruct_template( + self.fit_, self.labels_, self.masker ) if self.save_template is not None: self.template.to_filename(self.save_template) From 916fdfc539ad4fef735e702441ced6888245089e Mon Sep 17 00:00:00 2001 From: Pierre-Louis Barbarant Date: Sat, 7 Dec 2024 02:02:17 +0100 Subject: [PATCH 03/23] Replace _create_template by _fit_local_template --- fmralign/template_alignment.py | 34 +++++++++++----------------------- 1 file changed, 11 insertions(+), 23 deletions(-) diff --git a/fmralign/template_alignment.py b/fmralign/template_alignment.py index c087230..8aab339 100644 --- a/fmralign/template_alignment.py +++ b/fmralign/template_alignment.py @@ -84,18 +84,11 @@ def _align_images_to_template( return aligned_imgs -def _create_template( - imgs, +def _fit_local_template( + subjects_data, n_iter, scale_template, alignment_method, - n_pieces, - clustering, - masker, - memory, - memory_level, - n_jobs, - verbose, ): """ Create template through alternate minimization. 
@@ -127,28 +120,23 @@ def _create_template( List of the intermediate templates computed at the end of each iteration """ - aligned_imgs = imgs + aligned_data = subjects_data template_history = [] for iter in range(n_iter): - template = _rescaled_euclidean_mean( - aligned_imgs, masker, scale_template - ) + template = _rescaled_euclidean_mean(aligned_data, scale_template) if 0 < iter < n_iter - 1: template_history.append(template) - aligned_imgs = _align_images_to_template( - imgs, + aligned_data, subjects_estimators = _align_images_to_template( + subjects_data, template, alignment_method, - n_pieces, - clustering, - masker, - memory, - memory_level, - n_jobs, - verbose, ) - return template, template_history + return { + "template_data": template, + "template_history": template_history, + "estimators": subjects_estimators, + } def _map_template_to_image( From 9f9fe292dd47f64c68c755aa549a6c0dd791226e Mon Sep 17 00:00:00 2001 From: Pierre-Louis Barbarant Date: Sat, 7 Dec 2024 02:02:52 +0100 Subject: [PATCH 04/23] Add template reconstruction + history --- fmralign/template_alignment.py | 94 +++++++++++++++++++++------------- 1 file changed, 59 insertions(+), 35 deletions(-) diff --git a/fmralign/template_alignment.py b/fmralign/template_alignment.py index 8aab339..9d4b3f4 100644 --- a/fmralign/template_alignment.py +++ b/fmralign/template_alignment.py @@ -8,14 +8,15 @@ import numpy as np from joblib import Memory, Parallel, delayed -from nilearn._utils.masker_validation import check_embedded_masker -from nilearn.image import concat_imgs, index_img, load_img +from nilearn.image import index_img, load_img from sklearn.base import BaseEstimator, TransformerMixin -from fmralign.pairwise_alignment import PairwiseAlignment +from fmralign._utils import _parcels_to_array +from fmralign.pairwise_alignment import PairwiseAlignment, fit_one_piece +from fmralign.preprocessing import ParcellationMasker -def _rescaled_euclidean_mean(imgs, masker, scale_average=False): +def _rescaled_euclidean_mean(subjects_data, scale_average=False): """ Make the Euclidian average of images. @@ -32,34 +33,60 @@ def _rescaled_euclidean_mean(imgs, masker, scale_average=False): Returns ------- - average_img: Niimg + average_img: ndarray Average of imgs, with same shape as one img """ - masked_imgs = [masker.transform(img) for img in imgs] - average_img = np.mean(masked_imgs, axis=0) + average_data = np.mean(subjects_data, axis=0) scale = 1 if scale_average: X_norm = 0 - for img in masked_imgs: - X_norm += np.linalg.norm(img) - X_norm /= len(masked_imgs) - scale = X_norm / np.linalg.norm(average_img) - average_img *= scale + for data in subjects_data: + X_norm += np.linalg.norm(data) + X_norm /= len(subjects_data) + scale = X_norm / np.linalg.norm(average_data) + average_data *= scale - return masker.inverse_transform(average_img) + return average_data + + +def _reconstruct_template(fit, labels, masker): + """ + Reconstruct template from fit output. + + Parameters + ---------- + fit: list of list of np.ndarray + Each element of the list is the list of parcels data for one subject. + labels: np.ndarray + Labels of the parcels. + masker: instance of NiftiMasker or MultiNiftiMasker + Masker to be used on the data. 
+ + Returns + ------- + template_img: 4D Niimg object + Models the barycenter of input imgs + template_history: list of 4D Niimgs + List of the intermediate templates computed at the end of each iteration + """ + template_parcels = [fit_i["template_data"] for fit_i in fit] + template_data = _parcels_to_array(template_parcels, labels) + template_img = masker.inverse_transform(template_data) + + n_iter = len(fit[0]["template_history"]) + template_history = [] + for i in range(n_iter): + template_parcels = [fit_j["template_history"][i] for fit_j in fit] + template_data = _parcels_to_array(template_parcels, labels) + template_history.append(masker.inverse_transform(template_data)) + + return template_img, template_history def _align_images_to_template( - imgs, + subjects_data, template, alignment_method, - n_pieces, - clustering, - masker, - memory, - memory_level, - n_jobs, - verbose, ): """ Convenience function. @@ -67,21 +94,18 @@ def _align_images_to_template( aligning each of them to a common target, the template. All arguments are used in PairwiseAlignment. """ - aligned_imgs = [] - for img in imgs: - piecewise_estimator = PairwiseAlignment( - n_pieces=n_pieces, - alignment_method=alignment_method, - clustering=clustering, - mask=masker, - memory=memory, - memory_level=memory_level, - n_jobs=n_jobs, - verbose=verbose, + aligned_data = [] + piecewise_estimators = [] + for subject_data in subjects_data: + piecewise_estimator = fit_one_piece( + subject_data, + template, + alignment_method, ) - piecewise_estimator.fit(img, template) - aligned_imgs.append(piecewise_estimator.transform(img)) - return aligned_imgs + piecewise_estimator.fit(subject_data, template) + piecewise_estimators.append(piecewise_estimator) + aligned_data.append(piecewise_estimator.transform(subject_data)) + return aligned_data, piecewise_estimators def _fit_local_template( From 7c1cc25232ca57a72ebee7bbb20689bfd53435ce Mon Sep 17 00:00:00 2001 From: Pierre-Louis Barbarant Date: Sat, 7 Dec 2024 02:39:20 +0100 Subject: [PATCH 05/23] Rename get_parcellation method to get_parcellation_img for clarity --- fmralign/preprocessing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fmralign/preprocessing.py b/fmralign/preprocessing.py index 63521e8..da3f798 100644 --- a/fmralign/preprocessing.py +++ b/fmralign/preprocessing.py @@ -186,7 +186,7 @@ def get_labels(self): ) return self.labels - def get_parcellation(self): + def get_parcellation_img(self): """Return the parcellation image. Returns From 22d2b6eecfe91e243d9892d67f05058e443970f3 Mon Sep 17 00:00:00 2001 From: Pierre-Louis Barbarant Date: Sat, 7 Dec 2024 02:39:40 +0100 Subject: [PATCH 06/23] Rename transform parameter from X to img for clarity and update related references --- fmralign/pairwise_alignment.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/fmralign/pairwise_alignment.py b/fmralign/pairwise_alignment.py index d05c803..4a7495c 100644 --- a/fmralign/pairwise_alignment.py +++ b/fmralign/pairwise_alignment.py @@ -237,12 +237,12 @@ def fit(self, X, Y): return self - def transform(self, X): + def transform(self, img): """Predict data from X. Parameters ---------- - X: Niimg-like object + img: Niimg-like object Source data Returns @@ -255,7 +255,7 @@ def transform(self, X): "This instance has not been fitted yet. " "Please call 'fit' before 'transform'." 
) - parceled_data_list = self.parcel_masker.transform(X) + parceled_data_list = self.parcel_masker.transform(img) transformed_img = Parallel( self.n_jobs, prefer="threads", verbose=self.verbose )( @@ -266,7 +266,6 @@ def transform(self, X): return transformed_img[0] else: return transformed_img - return transformed_img # Make inherited function harmless def fit_transform(self): @@ -291,7 +290,7 @@ def get_parcellation(self): if hasattr(self, "parcel_masker"): check_is_fitted(self) labels = self.parcel_masker.get_labels() - parcellation_img = self.parcel_masker.get_parcellation() + parcellation_img = self.parcel_masker.get_parcellation_img() return labels, parcellation_img else: raise AttributeError( From 931511c3a6b6ff4f1eeb0f1f37936e9cfb14e78a Mon Sep 17 00:00:00 2001 From: Pierre-Louis Barbarant Date: Sat, 7 Dec 2024 02:40:35 +0100 Subject: [PATCH 07/23] Rename test_get_parcellation to test_get_parcellation_img --- fmralign/tests/test_preprocessing.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/fmralign/tests/test_preprocessing.py b/fmralign/tests/test_preprocessing.py index 1c39e46..5ff180e 100644 --- a/fmralign/tests/test_preprocessing.py +++ b/fmralign/tests/test_preprocessing.py @@ -181,13 +181,13 @@ def test_standardization(): assert np.abs(np.std(data_array) - 1.0) < 1e-5 -def test_get_parcellation(): +def test_get_parcellation_img(): """Test that ParcellationMasker returns the parcellation mask""" n_pieces = 2 img, _ = random_niimg((8, 7, 6)) parcel_masker = ParcellationMasker(n_pieces=n_pieces) parcel_masker.fit(img) - parcellation_img = parcel_masker.get_parcellation() + parcellation_img = parcel_masker.get_parcellation_img() labels = parcel_masker.get_labels() assert isinstance(parcellation_img, Nifti1Image) From 5612e751b1169abc37d8f64e50331d89585e43fc Mon Sep 17 00:00:00 2001 From: Pierre-Louis Barbarant Date: Sat, 7 Dec 2024 02:49:10 +0100 Subject: [PATCH 08/23] Refactor transform method in TemplateAlignment for clarity and functionality; add get_parcellation method to retrieve parcellation masker details. --- fmralign/template_alignment.py | 122 +++++++++++++++++---------------- 1 file changed, 64 insertions(+), 58 deletions(-) diff --git a/fmralign/template_alignment.py b/fmralign/template_alignment.py index 9d4b3f4..63fa2d6 100644 --- a/fmralign/template_alignment.py +++ b/fmralign/template_alignment.py @@ -8,10 +8,11 @@ import numpy as np from joblib import Memory, Parallel, delayed -from nilearn.image import index_img, load_img +from nilearn.image import index_img from sklearn.base import BaseEstimator, TransformerMixin +from sklearn.utils.validation import check_is_fitted -from fmralign._utils import _parcels_to_array +from fmralign._utils import _parcels_to_array, _transform_one_img from fmralign.pairwise_alignment import PairwiseAlignment, fit_one_piece from fmralign.preprocessing import ParcellationMasker @@ -447,7 +448,7 @@ def fit(self, imgs): if self.save_template is not None: self.template.to_filename(self.save_template) - def transform(self, imgs, train_index, test_index): + def transform(self, img, subject_index=None): """ Learn alignment between new subject and template calculated during fit, then predict other conditions for this new subject. @@ -474,66 +475,48 @@ def fit(self, imgs): Each Niimg has the same length as the list test_index """ - - if not isinstance(imgs, (list, np.ndarray)): - raise ValueError( - "The method TemplateAlignment.transform() need a list input.
" - "Each element of the list (Niimg-like or list of Niimgs) " - "is the data used to align one new subject with images " - "indexed by train_index." - ) - else: - if isinstance(imgs[0], (list, np.ndarray)) and len(imgs[0]) != len( - train_index - ): - raise ValueError( - "Each element of imgs (Niimg-like or list of Niimgs) " - "should have the same length as the length of train_index." - ) - elif load_img(imgs[0]).shape[-1] != len(train_index): - raise ValueError( - "Each element of imgs (Niimg-like or list of Niimgs) " - "should have the same length as the length of train_index." - ) - - template_length = self.template.shape[-1] - if not ( - all(i < template_length for i in test_index) - and all(i < template_length for i in train_index) - ): + if not hasattr(self, "fit_"): raise ValueError( - f"Template has {template_length} images but you provided a " - "greater index in train_index or test_index." + "This instance has not been fitted yet. " + "Please call 'fit' before 'transform'." ) - fitted_mappings = Parallel( - self.n_jobs, prefer="threads", verbose=self.verbose - )( - delayed(_map_template_to_image)( - img, - train_index, - self.template, - self.alignment_method, - self.n_pieces, - self.clustering, - self.masker_, - self.memory, - self.memory_level, - self.n_jobs, - self.verbose, + if subject_index is None: + alignment_estimator = PairwiseAlignment( + n_pieces=self.n_pieces, + alignment_method=self.alignment_method, + clustering=self.parcel_masker.get_parcellation_img(), + mask=self.masker, + smoothing_fwhm=self.smoothing_fwhm, + standardize=self.standardize, + detrend=self.detrend, + target_affine=self.target_affine, + target_shape=self.target_shape, + low_pass=self.low_pass, + high_pass=self.high_pass, + t_r=self.t_r, + memory=self.memory, + memory_level=self.memory_level, + n_jobs=self.n_jobs, + verbose=self.verbose, ) - for img in imgs - ) - - predicted_imgs = Parallel( - self.n_jobs, prefer="threads", verbose=self.verbose - )( - delayed(_predict_from_template_and_mapping)( - self.template, test_index, mapping + alignment_estimator.fit(img, self.template) + return alignment_estimator.transform(img) + else: + parceled_data_list = self.parcel_masker.transform(img) + subject_estimators = [ + fit_i["estimators"][subject_index] for fit_i in self.fit_ + ] + transformed_img = Parallel( + self.n_jobs, prefer="threads", verbose=self.verbose + )( + delayed(_transform_one_img)(parceled_data, subject_estimators) + for parceled_data in parceled_data_list ) - for mapping in fitted_mappings - ) - return predicted_imgs + if len(transformed_img) == 1: + return transformed_img[0] + else: + return transformed_img # Make inherited function harmless def fit_transform(self): @@ -541,3 +524,26 @@ def fit_transform(self): raise AttributeError( "type object 'PairwiseAlignment' has no attribute 'fit_transform'" ) + + def get_parcellation(self): + """Get the parcellation masker used for alignment. + + Returns + ------- + labels: `list` of `int` + Labels of the parcellation masker. + parcellation_img: Niimg-like object + Parcellation image. + """ + if hasattr(self, "parcel_masker"): + check_is_fitted(self) + labels = self.parcel_masker.get_labels() + parcellation_img = self.parcel_masker.get_parcellation_img() + return labels, parcellation_img + else: + raise AttributeError( + ( + "Parcellation has not been computed yet," + "please fit the alignment estimator first." 
+ ) + ) From 169fd45d074d4905ee87f3cb5c22592404c80b45 Mon Sep 17 00:00:00 2001 From: Pierre-Louis Barbarant Date: Sat, 7 Dec 2024 03:02:26 +0100 Subject: [PATCH 09/23] Remove unused _map_template_to_image function to streamline codebase --- fmralign/template_alignment.py | 49 ---------------------------------- 1 file changed, 49 deletions(-) diff --git a/fmralign/template_alignment.py b/fmralign/template_alignment.py index 63fa2d6..5c7c095 100644 --- a/fmralign/template_alignment.py +++ b/fmralign/template_alignment.py @@ -164,55 +164,6 @@ def _fit_local_template( } -def _map_template_to_image( - imgs, - train_index, - template, - alignment_method, - n_pieces, - clustering, - masker, - memory, - memory_level, - n_jobs, - verbose, -): - """ - Learn alignment operator from the template toward new images. - - Parameters - ---------- - imgs: list of 3D Niimgs - Target images to learn mapping from the template to a new subject - train_index: list of int - Matching index between imgs and the corresponding template images to use - to learn alignment. len(train_index) must be equal to len(imgs) - template: list of 3D Niimgs - Learnt in a first step now used as source image - All other arguments are the same are passed to PairwiseAlignment - - - Returns - ------- - mapping: instance of PairwiseAlignment class - Alignment estimator fitted to align the template with the input images - """ - - mapping_image = index_img(template, train_index) - mapping = PairwiseAlignment( - n_pieces=n_pieces, - alignment_method=alignment_method, - clustering=clustering, - mask=masker, - memory=memory, - memory_level=memory_level, - n_jobs=n_jobs, - verbose=verbose, - ) - mapping.fit(mapping_image, imgs) - return mapping - - def _predict_from_template_and_mapping(template, test_index, mapping): """ From a template and an alignment estimator, predict new contrasts. From 3bf5571a86fba77e0bacee8d876e7f8200b15b6e Mon Sep 17 00:00:00 2001 From: Pierre-Louis Barbarant Date: Sat, 7 Dec 2024 03:04:29 +0100 Subject: [PATCH 10/23] Remove unused _predict_from_template_and_mapping function --- fmralign/template_alignment.py | 26 -------------------------- 1 file changed, 26 deletions(-) diff --git a/fmralign/template_alignment.py b/fmralign/template_alignment.py index 5c7c095..1a0731f 100644 --- a/fmralign/template_alignment.py +++ b/fmralign/template_alignment.py @@ -8,7 +8,6 @@ import numpy as np from joblib import Memory, Parallel, delayed -from nilearn.image import index_img from sklearn.base import BaseEstimator, TransformerMixin from sklearn.utils.validation import check_is_fitted @@ -164,31 +163,6 @@ def _fit_local_template( } -def _predict_from_template_and_mapping(template, test_index, mapping): - """ - From a template and an alignment estimator, predict new contrasts. - - Parameters - ---------- - template: list of 3D Niimgs - Learnt in a first step now used to predict some new data - test_index: - Index of the images not used to learn the alignment mapping and so - predictable without overfitting - mapping: instance of PairwiseAlignment class - Alignment estimator that must have been fitted already - - Returns - ------- - transformed_image: list of Niimgs - Prediction corresponding to each template image with index in test_index - once realigned to the new subjects - """ - image_to_transform = index_img(template, test_index) - transformed_image = mapping.transform(image_to_transform) - return transformed_image - - def index_by_parcel(subjects_data): """ Index data by parcel. 
From a88606c77a753450aed1c2930f0bb268d6940f94 Mon Sep 17 00:00:00 2001 From: Pierre-Louis Barbarant Date: Sat, 7 Dec 2024 15:50:35 +0100 Subject: [PATCH 11/23] Add tests for parcellation retrieval in TemplateAlignment --- fmralign/tests/test_template_alignment.py | 77 +++++++++++++++++++---- 1 file changed, 65 insertions(+), 12 deletions(-) diff --git a/fmralign/tests/test_template_alignment.py b/fmralign/tests/test_template_alignment.py index 245eddd..eb72f9f 100644 --- a/fmralign/tests/test_template_alignment.py +++ b/fmralign/tests/test_template_alignment.py @@ -1,5 +1,6 @@ -# -*- coding: utf-8 -*- +import numpy as np import pytest +from nibabel import Nifti1Image from nilearn.image import concat_imgs, index_img, math_img from nilearn.maskers import NiftiMasker from numpy.testing import assert_array_almost_equal @@ -28,11 +29,11 @@ def test_template_identity(): subs = [sub_1, sub_2, sub_3] - # test euclidian mean function - euclidian_template = _rescaled_euclidean_mean(subs, masker) - assert_array_almost_equal( - ref_template.get_fdata(), euclidian_template.get_fdata() - ) + # # test euclidian mean function + # euclidian_template = _rescaled_euclidean_mean(subs, masker) + # assert_array_almost_equal( + # ref_template.get_fdata(), euclidian_template.get_fdata() + # ) # test different fit() accept list of list of 3D Niimgs as input. algo = TemplateAlignment(alignment_method="identity", mask=masker) @@ -58,7 +59,8 @@ def test_template_identity(): algo.fit(subs) # test template assert_array_almost_equal( - ref_template.get_fdata(), algo.template.get_fdata() + ref_template.get_fdata(), + algo.template.get_fdata(), ) predicted_imgs = algo.transform( [index_img(sub_1, range(8))], @@ -67,7 +69,8 @@ def test_template_identity(): ) ground_truth = index_img(ref_template, range(8, 10)) assert_array_almost_equal( - ground_truth.get_fdata(), predicted_imgs[0].get_fdata() + ground_truth.get_fdata(), + predicted_imgs[0].get_fdata(), ) # test transform() with wrong indexes length or content (on previous fitted algo) @@ -81,7 +84,7 @@ def test_template_identity(): ], ) - for train_ind, test_ind in zip(train_inds, test_inds): + for train_ind, test_ind in zip(train_inds, test_inds, strict=False): with pytest.raises(Exception): assert algo.transform( [index_img(sub_1, range(2))], @@ -98,7 +101,9 @@ def test_template_identity(): ) assert algo.fit([im]) assert algo.transform( - [im], train_index=train_inds[-1], test_index=test_inds[-1] + [im], + train_index=train_inds[-1], + test_index=test_inds[-1], ) @@ -129,16 +134,64 @@ def test_template_closer_to_target(): "diagonal", ]: algo = TemplateAlignment( - alignment_method=alignment_method, n_pieces=3, mask=masker + alignment_method=alignment_method, + n_pieces=3, + mask=masker, ) # Learn template algo.fit(subs) # Assess template is closer to mean than both images template_data = masker.transform(algo.template) template_mean_distance = zero_mean_coefficient_determination( - avg_data, template_data + avg_data, + template_data, ) assert template_mean_distance >= mean_distance_1 assert ( template_mean_distance >= mean_distance_2 - 1.0e-2 ) # for robustness + + +def test_parcellation_retrieval(): + """Test that TemplateAlignment returns both the\n + labels and the parcellation image + """ + n_pieces = 3 + imgs = [random_niimg((8, 7, 6))[0]] * 3 + alignment = TemplateAlignment(n_pieces=n_pieces) + alignment.fit(imgs) + + labels, parcellation_image = alignment.get_parcellation() + assert isinstance(labels, np.ndarray) + assert len(np.unique(labels)) == n_pieces + 
assert isinstance(parcellation_image, Nifti1Image) + assert parcellation_image.shape == imgs[0].shape + + +def test_parcellation_before_fit(): + """Test that TemplateAlignment raises an error if\n + the parcellation is retrieved before fitting + """ + alignment = TemplateAlignment() + with pytest.raises( + AttributeError, + match="Parcellation has not been computed yet", + ): + alignment.get_parcellation() + + +if __name__ == "__main__": + # imgs = [random_niimg((8, 7, 6, 100))[0]] * 3 + # template_estimator = TemplateAlignment( + # alignment_method="identity", + # n_pieces=2, + # ) + # template_estimator.fit(imgs) + # res_img = template_estimator.transform(random_niimg((8, 7, 6, 100))[0]) + # print(res_img.shape) + # print(template_estimator.template.shape) + # print(template_estimator.get_parcellation()) + test_template_identity() + # test_template_closer_to_target() + # test_parcellation_retrieval() + # test_parcellation_before_fit() From 218fa6cb9a1d2e8ba1c91161ec161025a76533f4 Mon Sep 17 00:00:00 2001 From: Pierre-Louis Barbarant Date: Sun, 8 Dec 2024 21:35:49 +0100 Subject: [PATCH 12/23] Refactor variable names in TemplateAlignment for clarity in parcel data processing --- fmralign/template_alignment.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/fmralign/template_alignment.py b/fmralign/template_alignment.py index 1a0731f..64ccaf2 100644 --- a/fmralign/template_alignment.py +++ b/fmralign/template_alignment.py @@ -348,8 +348,8 @@ def fit(self, imgs): verbose=self.verbose, ) - subjects_data = self.parcel_masker.fit_transform(imgs) - parcels_data = index_by_parcel(subjects_data) + subjects_parcels = self.parcel_masker.fit_transform(imgs) + parcels_data = _index_by_parcel(subjects_parcels) self.masker = self.parcel_masker.masker_ self.mask = self.parcel_masker.masker_.mask_img_ self.labels_ = self.parcel_masker.labels From bf33bbd1137b1ed8f31ce9021c940e4cb0ce80a9 Mon Sep 17 00:00:00 2001 From: Pierre-Louis Barbarant Date: Sun, 8 Dec 2024 21:36:06 +0100 Subject: [PATCH 13/23] Rename index_by_parcel function to _index_by_parcel for clarity and update parameter name for consistency --- fmralign/template_alignment.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/fmralign/template_alignment.py b/fmralign/template_alignment.py index 64ccaf2..6ca92fe 100644 --- a/fmralign/template_alignment.py +++ b/fmralign/template_alignment.py @@ -163,13 +163,13 @@ def _fit_local_template( } -def index_by_parcel(subjects_data): +def _index_by_parcel(subjects_parcels): """ Index data by parcel. Parameters ---------- - subjects_data: list of list of numpy.ndarray + subjects_parcels: list of list of numpy.ndarray Each element of the list is the list of parcels data for one subject. @@ -179,9 +179,9 @@ def index_by_parcel(subjects_data): Each element of the list is the list of subjects data for one parcel. 
""" - n_pieces = subjects_data[0].n_pieces + n_pieces = subjects_parcels[0].n_pieces return [ - [subject_data[i] for subject_data in subjects_data] + [subject_parcels[i] for subject_parcels in subjects_parcels] for i in range(n_pieces) ] From 1d32ad7bd6903588a8f9874056d706d1bb6eacbb Mon Sep 17 00:00:00 2001 From: Pierre-Louis Barbarant Date: Sun, 8 Dec 2024 21:36:18 +0100 Subject: [PATCH 14/23] Set default values for parameters in _fit_local_template function for improved usability --- fmralign/template_alignment.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/fmralign/template_alignment.py b/fmralign/template_alignment.py index 6ca92fe..7cafbcc 100644 --- a/fmralign/template_alignment.py +++ b/fmralign/template_alignment.py @@ -110,9 +110,9 @@ def _align_images_to_template( def _fit_local_template( subjects_data, - n_iter, - scale_template, - alignment_method, + n_iter=2, + scale_template=False, + alignment_method="identity", ): """ Create template through alternate minimization. From 885a05b32b4d888a59dd0890c338324013428401 Mon Sep 17 00:00:00 2001 From: Pierre-Louis Barbarant Date: Sun, 8 Dec 2024 21:37:17 +0100 Subject: [PATCH 15/23] Add unit tests for template alignment functions and parcellation processing --- fmralign/tests/test_template_alignment.py | 189 ++++++++++++++-------- 1 file changed, 125 insertions(+), 64 deletions(-) diff --git a/fmralign/tests/test_template_alignment.py b/fmralign/tests/test_template_alignment.py index eb72f9f..9349371 100644 --- a/fmralign/tests/test_template_alignment.py +++ b/fmralign/tests/test_template_alignment.py @@ -1,20 +1,107 @@ import numpy as np import pytest from nibabel import Nifti1Image -from nilearn.image import concat_imgs, index_img, math_img +from nilearn.image import concat_imgs, math_img from nilearn.maskers import NiftiMasker from numpy.testing import assert_array_almost_equal +from fmralign._utils import ParceledData +from fmralign.preprocessing import ParcellationMasker from fmralign.template_alignment import ( TemplateAlignment, + _align_images_to_template, + _fit_local_template, + _index_by_parcel, + _reconstruct_template, _rescaled_euclidean_mean, ) from fmralign.tests.utils import ( random_niimg, + sample_parceled_data, + sample_subjects_data, zero_mean_coefficient_determination, ) +@pytest.mark.parametrize("scale_average", [True, False]) +def test_rescaled_euclidean_mean(scale_average): + subjects_data = sample_subjects_data() + average_data = _rescaled_euclidean_mean(subjects_data) + assert average_data.shape == subjects_data[0].shape + assert average_data.dtype == subjects_data[0].dtype + + if scale_average is False: + assert np.allclose(average_data, np.mean(subjects_data, axis=0)) + + +def test_reconstruct_template(): + n_subjects = 3 + n_iter = 3 + n_pieces = 2 + imgs = [random_niimg((8, 7, 6, 20))[0]] * n_subjects + parcel_masker = ParcellationMasker(n_pieces=n_pieces) + subjects_parcels = parcel_masker.fit_transform(imgs) + parcels_subjects = _index_by_parcel(subjects_parcels) + masker = parcel_masker.masker_ + labels = parcel_masker.labels + + fit = [ + _fit_local_template(parcel_i, n_iter=n_iter) + for parcel_i in parcels_subjects + ] + template, template_history = _reconstruct_template(fit, labels, masker) + + assert template.shape == imgs[0].shape + assert len(template_history) == n_iter - 2 + for template_i in template_history: + assert template_i.shape == imgs[0].shape + + +def test_align_images_to_template(): + subjects_data = sample_subjects_data() + template = 
_rescaled_euclidean_mean(subjects_data) + aligned_data, subjects_estimators = _align_images_to_template( + subjects_data, + template, + alignment_method="identity", + ) + assert len(aligned_data) == len(subjects_data) + assert len(subjects_estimators) == len(subjects_data) + assert aligned_data[0].shape == subjects_data[0].shape + + +def test_fit_local_template(): + n_subjects = 3 + n_iter = 3 + subjects_data = sample_subjects_data(n_subjects=n_subjects) + fit = _fit_local_template( + subjects_data, + n_iter=n_iter, + alignment_method="identity", + scale_template=False, + ) + template_data = fit["template_data"] + template_history = fit["template_history"] + estimators = fit["estimators"] + + assert template_data.shape == subjects_data[0].shape + assert len(template_history) == n_iter - 2 + assert len(estimators) == n_subjects + + +def test_index_by_parcel(): + n_subjects = 3 + n_pieces = 2 + subjects_parcels = [ + ParceledData(*sample_parceled_data(n_pieces)) + for _ in range(n_subjects) + ] + parcels_subjects = _index_by_parcel(subjects_parcels) + assert len(parcels_subjects) == n_pieces + assert len(parcels_subjects[0]) == n_subjects + assert parcels_subjects[0][0].shape == subjects_parcels[0][0].shape + + def test_template_identity(): n = 10 im, mask_img = random_niimg((6, 5, 3)) @@ -29,15 +116,9 @@ def test_template_identity(): subs = [sub_1, sub_2, sub_3] - # # test euclidian mean function - # euclidian_template = _rescaled_euclidean_mean(subs, masker) - # assert_array_almost_equal( - # ref_template.get_fdata(), euclidian_template.get_fdata() - # ) - # test different fit() accept list of list of 3D Niimgs as input. algo = TemplateAlignment(alignment_method="identity", mask=masker) - algo.fit([n * [im]] * 3) + algo.fit([concat_imgs(n * [im])] * 3) # test template assert_array_almost_equal(sub_1.get_fdata(), algo.template.get_fdata()) @@ -55,55 +136,53 @@ def test_template_identity(): for args in args_list: algo = TemplateAlignment(**args) - # Learning a template which is algo.fit(subs) # test template assert_array_almost_equal( ref_template.get_fdata(), algo.template.get_fdata(), ) - predicted_imgs = algo.transform( - [index_img(sub_1, range(8))], - train_index=range(8), - test_index=range(8, 10), + predicted_imgs = algo.transform(ref_template) + assert_array_almost_equal( + predicted_imgs.get_fdata(), + ref_template.get_fdata(), ) - ground_truth = index_img(ref_template, range(8, 10)) + predicted_imgs = algo.transform(ref_template, subject_index=1) assert_array_almost_equal( - ground_truth.get_fdata(), - predicted_imgs[0].get_fdata(), + predicted_imgs.get_fdata(), + ref_template.get_fdata(), ) - # test transform() with wrong indexes length or content (on previous fitted algo) - train_inds, test_inds = ( - [[0, 1], [1, 10], [4, 11], [0, 1, 2]], - [ - [6, 8, 29], - [4, 6], - [4, 11], - [4, 5], - ], + +def test_template_diagonal(): + n = 10 + im, mask_img = random_niimg((6, 5, 3)) + + sub_1 = concat_imgs(n * [im]) + sub_2 = math_img("2 * img", img=sub_1) + sub_3 = math_img("3 * img", img=sub_1) + + ref_template = sub_2 + masker = NiftiMasker(mask_img=mask_img) + masker.fit() + + subs = [sub_1, sub_2, sub_3] + + # Test without subject_index + algo = TemplateAlignment(alignment_method="diagonal", mask=masker) + algo.fit(subs) + predicted_imgs = algo.transform(sub_1, subject_index=None) + assert_array_almost_equal( + ref_template.get_fdata(), + predicted_imgs.get_fdata(), ) - for train_ind, test_ind in zip(train_inds, test_inds, strict=False): - with pytest.raises(Exception): - assert 
algo.transform( - [index_img(sub_1, range(2))], - train_index=train_ind, - test_index=test_ind, - ) - - # test wrong images input in fit() and transform method - with pytest.raises(Exception): - assert algo.transform( - [n * [im]] * 2, - train_index=train_inds[-1], - test_index=test_inds[-1], - ) - assert algo.fit([im]) - assert algo.transform( - [im], - train_index=train_inds[-1], - test_index=test_inds[-1], + # Test with subject_index + for i, sub in enumerate(subs): + predicted_imgs = algo.transform(sub, subject_index=i) + assert_array_almost_equal( + predicted_imgs.get_fdata(), + ref_template.get_fdata(), ) @@ -121,8 +200,7 @@ def test_template_closer_to_target(): sub_1 = masker.transform(subject_1) sub_2 = masker.transform(subject_2) subs = [subject_1, subject_2] - average_img = _rescaled_euclidean_mean(subs, masker) - avg_data = masker.transform(average_img) + avg_data = np.mean([sub_1, sub_2], axis=0) mean_distance_1 = zero_mean_coefficient_determination(sub_1, avg_data) mean_distance_2 = zero_mean_coefficient_determination(sub_2, avg_data) @@ -178,20 +256,3 @@ def test_parcellation_before_fit(): match="Parcellation has not been computed yet", ): alignment.get_parcellation() - - -if __name__ == "__main__": - # imgs = [random_niimg((8, 7, 6, 100))[0]] * 3 - # template_estimator = TemplateAlignment( - # alignment_method="identity", - # n_pieces=2, - # ) - # template_estimator.fit(imgs) - # res_img = template_estimator.transform(random_niimg((8, 7, 6, 100))[0]) - # print(res_img.shape) - # print(template_estimator.template.shape) - # print(template_estimator.get_parcellation()) - test_template_identity() - # test_template_closer_to_target() - # test_parcellation_retrieval() - # test_parcellation_before_fit() From 5883e48b393d74bcc832831334715e596f1787ad Mon Sep 17 00:00:00 2001 From: Pierre-Louis Barbarant Date: Sun, 8 Dec 2024 21:37:31 +0100 Subject: [PATCH 16/23] Add function to sample subjects data for testing purposes --- fmralign/tests/utils.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/fmralign/tests/utils.py b/fmralign/tests/utils.py index 35133c6..61af947 100644 --- a/fmralign/tests/utils.py +++ b/fmralign/tests/utils.py @@ -132,3 +132,9 @@ def sample_parceled_data(n_pieces=1): data = masker.fit_transform(img) labels = _make_parcellation(img, "kmeans", n_pieces, masker) return data, masker, labels + + +def sample_subjects_data(n_subjects=3): + """Sample data in one parcel for n_subjects""" + subjects_data = [np.random.rand(10, 20) for _ in range(n_subjects)] + return subjects_data From 9411aee12a208401ad98fc0f5bd91b73d89e26fc Mon Sep 17 00:00:00 2001 From: Pierre-Louis Barbarant Date: Sun, 8 Dec 2024 22:01:15 +0100 Subject: [PATCH 17/23] Update documentation for image processing functions to clarify parameter types and improve readability --- fmralign/template_alignment.py | 93 +++++++++++++++++++--------------- 1 file changed, 51 insertions(+), 42 deletions(-) diff --git a/fmralign/template_alignment.py b/fmralign/template_alignment.py index 7cafbcc..779a789 100644 --- a/fmralign/template_alignment.py +++ b/fmralign/template_alignment.py @@ -18,22 +18,19 @@ def _rescaled_euclidean_mean(subjects_data, scale_average=False): """ - Make the Euclidian average of images. + Make the Euclidian average of `numpy.ndarray`. Parameters ---------- - imgs: list of Niimgs - Each img is 3D by default, but can also be 4D. - masker: instance of NiftiMasker or MultiNiftiMasker - Masker to be used on the data. 
+ subjects_data: `list` of `numpy.ndarray` + Each element of the list is the data for one subject. scale_average: boolean - If true, the returned average is scaled to have the average norm of imgs - If false, it will usually have a smaller norm than initial average - because noise will cancel across images + If true, average is rescaled so that it keeps the same norm as the + average of training images. Returns ------- - average_img: ndarray + average_data: ndarray Average of imgs, with same shape as one img """ average_data = np.mean(subjects_data, axis=0) @@ -55,9 +52,9 @@ def _reconstruct_template(fit, labels, masker): Parameters ---------- - fit: list of list of np.ndarray + fit: list of list of numpy.ndarray Each element of the list is the list of parcels data for one subject. - labels: np.ndarray + labels: numpy.ndarray Labels of the parcels. masker: instance of NiftiMasker or MultiNiftiMasker Masker to be used on the data. @@ -90,9 +87,24 @@ def _align_images_to_template( ): """ Convenience function. - For a list of images, return the list of estimators (PairwiseAlignment instances) + For a list of ndarrays, return the list of alignment estimators aligning each of them to a common target, the template. - All arguments are used in PairwiseAlignment. + + Parameters + ---------- + subjects_data: `list` of `numpy.ndarray` + Each element of the list is the data for one subject. + template: `numpy.ndarray` + The target data. + alignment_method: string + Algorithm used to perform alignment between sources and template. + + Returns + ------- + aligned_data: `list` of `numpy.ndarray` + List of aligned data. + piecewise_estimators: `list` of `PairwiseAlignment` + List of `Alignment` estimators. """ aligned_data = [] piecewise_estimators = [] @@ -117,31 +129,33 @@ def _fit_local_template( """ Create template through alternate minimization. Compute iteratively : - * T minimizing sum(||R_i X_i-T||) which is the mean of aligned images (RX_i) + * T minimizing sum(||R X-T||) which is the mean of aligned images (RX) * align initial images to new template T - (find transform R_i minimizing ||R_i X_i-T|| for each img X_i) + (find transform R minimizing ||R X-T|| for each img X) Parameters ---------- - imgs: List of Niimg-like objects - See http://nilearn.github.io/manipulating_images/input_output.html - source data. Every img must have the same length (n_sample) - scale_template: boolean - If true, template is rescaled after each inference so that it keeps - the same norm as the average of training images. + imgs: `list` of `numpy.ndarray` + Each element of the list is the data for one subject. n_iter: int Number of iterations in the alternate minimization. Each image is aligned n_iter times to the evolving template. If n_iter = 0, the template is simply the mean of the input images. - All other arguments are the same are passed to PairwiseAlignment + scale_template: boolean + If true, template is rescaled after each inference so that it keeps + the same norm as the average of training images. + alignment_method: string + Algorithm used to perform alignment between sources and template. Returns ------- - template: list of 3D Niimgs of length (n_sample) - Models the barycenter of input imgs - template_history: list of list of 3D Niimgs - List of the intermediate templates computed at the end of each iteration + template_data: `numpy.ndarray` + Template data. + template_history: `list` of `numpy.ndarray` + List of the intermediate templates computed at the end of each iteration. 
+ piecewise_estimators: `list` of `PairwiseAlignment` + List of `Alignment` estimators. """ aligned_data = subjects_data @@ -375,29 +389,24 @@ def fit(self, imgs): def transform(self, img, subject_index=None): """ - Learn alignment between new subject and template calculated during fit, - then predict other conditions for this new subject. - Alignment is learnt between imgs and conditions in the template indexed by train_index. - Prediction correspond to conditions in the template index by test_index. + Transform a (new) subject image into the template space. Parameters ---------- - imgs: List of 3D Niimg-like objects - Target subjects known data. - Every img must have length (number of sample) train_index. - train_index: list of ints - Indexes of the 3D samples used to map each img to the template. - Every index should be smaller than the number of images in the template. - test_index: list of ints - Indexes of the 3D samples to predict from the template and the mapping. - Every index should be smaller than the number of images in the template. + img: 4D Niimg-like object + Subject image. + subject_index: int, optional (default = None) + Index of the subject to be transformed. It should + correspond to the index of the subject in the list of + subjects used to fit the template. If None, a new + `PairwiseAlignment` object is fitted between the new + subject and the template before transforming. Returns ------- - predicted_imgs: List of 3D Niimg-like objects - Target subjects predicted data. - Each Niimg has the same length as the list test_index + predicted_imgs: 4D Niimg object + Transformed data. """ if not hasattr(self, "fit_"): From e55d98247229c221291e4d53394d00909ea4a1c0 Mon Sep 17 00:00:00 2001 From: Pierre-Louis Barbarant Date: Mon, 9 Dec 2024 16:06:24 +0100 Subject: [PATCH 18/23] Enhance tutorial on template-based prediction by improving clarity and detail, correcting typos, and updating variable names for consistency --- examples/plot_template_alignment.py | 80 +++++++++++------------------ 1 file changed, 29 insertions(+), 51 deletions(-) diff --git a/examples/plot_template_alignment.py b/examples/plot_template_alignment.py index 21225ec..f865021 100644 --- a/examples/plot_template_alignment.py +++ b/examples/plot_template_alignment.py @@ -1,13 +1,14 @@ # -*- coding: utf-8 -*- - """ Template-based prediction. ========================== -In this tutorial, we show how to better predict new contrasts for a target -subject using many source subjects corresponding contrasts. For this purpose, -we create a template to which we align the target subject, using shared information. -We then predict new images for the target and compare them to a baseline. +In this tutorial, we show how to improve inter-subject similarity using a template +computed across multiple source subjects. For this purpose, we create a template +using Procrustes alignment (hyperalignment) to which we align the target subject, +using shared information. We then compare the voxelwise similarity between the +target subject and the template to the similarity between the target subject and +the anatomical euclidean average of the source subjects. We mostly rely on Python common packages and on nilearn to handle functional data in a clean fashion. @@ -36,7 +37,7 @@ ) ############################################################################### -# Definine a masker +# Define a masker # ----------------- # We define a nilearn masker that will be used to handle relevant data. 
# For more information, visit : @@ -64,22 +65,17 @@ template_train = [] for i in range(5): template_train.append(concat_imgs(imgs[i])) -target_train = df[df.subject == "sub-07"][df.acquisition == "ap"].path.values -# For subject sub-07, we split it in two folds: -# - target train: sub-07 AP contrasts, used to learn alignment to template -# - target test: sub-07 PA contrasts, used as a ground truth to score predictions -# We make a single 4D Niimg from our list of 3D filenames +# sub-07 will be our left-out subject. +# We make a single 4D Niimg from our list of 3D filenames. -target_train = concat_imgs(target_train) -target_test = df[df.subject == "sub-07"][df.acquisition == "pa"].path.values +left_out_subject = concat_imgs(imgs[5]) ############################################################################### # Compute a baseline (average of subjects) # ---------------------------------------- # We create an image with as many contrasts as any subject representing for # each contrast the average of all train subjects maps. -# import numpy as np @@ -92,70 +88,53 @@ # --------------------------------------------- # We define an estimator using the class TemplateAlignment: # * We align the whole brain through 'multiple' local alignments. -# * These alignments are calculated on a parcellation of the brain in 150 pieces, +# * These alignments are calculated on a parcellation of the brain in 50 pieces, # this parcellation creates group of functionnally similar voxels. # * The template is created iteratively, aligning all subjects data into a # common space, from which the template is inferred and aligning again to this # new template space. # -from nilearn.image import index_img - from fmralign.template_alignment import TemplateAlignment +# We use Procrustes/scaled orthogonal alignment method template_estim = TemplateAlignment( - n_pieces=150, alignment_method="ridge_cv", mask=masker + n_pieces=50, + alignment_method="scaled_orthogonal", + mask=masker, ) template_estim.fit(template_train) +procrustes_template = template_estim.template ############################################################################### # Predict new data for left-out subject # ------------------------------------- -# We use target_train data to fit the transform, indicating it corresponds to -# the contrasts indexed by train_index and predict from this learnt alignment -# contrasts corresponding to template test_index numbers. -# For each train subject and for the template, the AP contrasts are sorted from -# 0, to 53, and then the PA contrasts from 53 to 106. -# - -train_index = range(53) -test_index = range(53, 106) - -# We input the mapping image target_train in a list, we could have input more -# than one subject for which we'd want to predict : [train_1, train_2 ...] +# We predict the contrasts of the left-out subject using the template we just +# created. We use the transform method of the estimator. This method takes the +# left-out subject as input, computes a pairwise alignment with the template +# and returns the aligned data. -prediction_from_template = template_estim.transform( - [target_train], train_index, test_index -) - -# As a baseline prediction, let's just take the average of activations across subjects. 
- -prediction_from_average = index_img(average_subject, test_index) +predictions_from_template = template_estim.transform(left_out_subject) ############################################################################### # Score the baseline and the prediction # ------------------------------------- # We use a utility scoring function to measure the voxelwise correlation -# between the prediction and the ground truth. That is, for each voxel, we -# measure the correlation between its profile of activation without and with -# alignment, to see if alignment was able to predict a signal more alike the ground truth. -# +# between the images. That is, for each voxel, we measure the correlation between +# its profile of activation without and with alignment, to see if template-based +# alignment was able to improve inter-subject similarity. from fmralign.metrics import score_voxelwise -# Now we use this scoring function to compare the correlation of predictions -# made from group average and from template with the real PA contrasts of sub-07 - average_score = masker.inverse_transform( - score_voxelwise(target_test, prediction_from_average, masker, loss="corr") + score_voxelwise(left_out_subject, average_subject, masker, loss="corr") ) template_score = masker.inverse_transform( score_voxelwise( - target_test, prediction_from_template[0], masker, loss="corr" + predictions_from_template, procrustes_template, masker, loss="corr" ) ) - ############################################################################### # Plotting the measures # --------------------- @@ -167,13 +146,12 @@ baseline_display = plotting.plot_stat_map( average_score, display_mode="z", vmax=1, cut_coords=[-15, -5] ) -baseline_display.title("Group average correlation wt ground truth") +baseline_display.title("Left-out subject correlation with group average") display = plotting.plot_stat_map( template_score, display_mode="z", cut_coords=[-15, -5], vmax=1 ) -display.title("Template-based prediction correlation wt ground truth") +display.title("Aligned subject correlation with Procrustes template") ############################################################################### # We observe that creating a template and aligning a new subject to it yields -# a prediction that is better correlated with the ground truth than just using -# the average activations of subjects. +# better inter-subject similarity than regular euclidean averaging. 
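Note: the transform call above fits a fresh pairwise alignment because sub-07 never entered template_train. For a subject that did contribute to the template, the per-parcel estimators stored during fit can be reused instead via the subject_index argument introduced in PATCH 08; a short sketch reusing the tutorial's variables (the name aligned_first_subject is illustrative):

# Project the first training subject through its stored per-parcel estimators,
# rather than fitting a new alignment to the template.
aligned_first_subject = template_estim.transform(template_train[0], subject_index=0)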
From e22e7f863a5fe9014a1085b1528c02063375c009 Mon Sep 17 00:00:00 2001 From: Pierre-Louis Barbarant Date: Tue, 3 Dec 2024 16:09:57 +0100 Subject: [PATCH 19/23] Add test for ParcellationMasker to handle 3D and 4D images with one contrast --- fmralign/tests/test_preprocessing.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/fmralign/tests/test_preprocessing.py b/fmralign/tests/test_preprocessing.py index 570abeb..fb77627 100644 --- a/fmralign/tests/test_preprocessing.py +++ b/fmralign/tests/test_preprocessing.py @@ -208,3 +208,11 @@ def test_get_parcellation_img(): assert np.allclose(data, labels) assert len(np.unique(data)) == n_pieces + +def test_one_contrast(): + """Test that ParcellationMasker handles both 3D and\n + 4D images in the case of one contrast""" + img1, _ = random_niimg((8, 7, 6)) + img2, _ = random_niimg((8, 7, 6, 1)) + pmasker = ParcellationMasker() + pmasker.fit([img1, img2]) From 2379b811231a3c4083b60c492f3d413e4da26ecf Mon Sep 17 00:00:00 2001 From: Pierre-Louis Barbarant Date: Tue, 10 Dec 2024 15:33:27 +0100 Subject: [PATCH 20/23] Rebase with main --- fmralign/tests/test_preprocessing.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/fmralign/tests/test_preprocessing.py b/fmralign/tests/test_preprocessing.py index fb77627..0efa96f 100644 --- a/fmralign/tests/test_preprocessing.py +++ b/fmralign/tests/test_preprocessing.py @@ -207,12 +207,3 @@ def test_get_parcellation_img(): assert np.allclose(data, labels) assert len(np.unique(data)) == n_pieces - - -def test_one_contrast(): - """Test that ParcellationMasker handles both 3D and\n - 4D images in the case of one contrast""" - img1, _ = random_niimg((8, 7, 6)) - img2, _ = random_niimg((8, 7, 6, 1)) - pmasker = ParcellationMasker() - pmasker.fit([img1, img2]) From 78ff3b34cf407e396b5095a271ba823aa08ba921 Mon Sep 17 00:00:00 2001 From: Pierre-Louis Barbarant Date: Tue, 10 Dec 2024 15:38:54 +0100 Subject: [PATCH 21/23] Fix shape assertion in parcellation retrieval test --- fmralign/tests/test_template_alignment.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/fmralign/tests/test_template_alignment.py b/fmralign/tests/test_template_alignment.py index 9349371..7b914ec 100644 --- a/fmralign/tests/test_template_alignment.py +++ b/fmralign/tests/test_template_alignment.py @@ -243,7 +243,7 @@ def test_parcellation_retrieval(): assert isinstance(labels, np.ndarray) assert len(np.unique(labels)) == n_pieces assert isinstance(parcellation_image, Nifti1Image) - assert parcellation_image.shape == imgs[0].shape + assert parcellation_image.shape == imgs[0].shape[:-1] def test_parcellation_before_fit(): @@ -256,3 +256,7 @@ def test_parcellation_before_fit(): match="Parcellation has not been computed yet", ): alignment.get_parcellation() + + +if __name__ == "__main__": + test_parcellation_retrieval() From 25b39ee21255ac43e9f408880851da1cb1c15c96 Mon Sep 17 00:00:00 2001 From: Pierre-Louis Barbarant <104081777+pbarbarant@users.noreply.github.com> Date: Wed, 18 Dec 2024 00:17:42 +0100 Subject: [PATCH 22/23] Update examples/plot_template_alignment.py Co-authored-by: bthirion --- examples/plot_template_alignment.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/plot_template_alignment.py b/examples/plot_template_alignment.py index f865021..bdbdbd2 100644 --- a/examples/plot_template_alignment.py +++ b/examples/plot_template_alignment.py @@ -8,7 +8,7 @@ using Procrustes alignment (hyperalignment) to which we align the target subject, using shared information. 
We then compare the voxelwise similarity between the target subject and the template to the similarity between the target subject and -the anatomical euclidean average of the source subjects. +the anatomical Euclidean average of the source subjects. We mostly rely on Python common packages and on nilearn to handle functional data in a clean fashion. From 2ba236dceaaddcb666d564c6125da8367ce3f400 Mon Sep 17 00:00:00 2001 From: Pierre-Louis Barbarant <104081777+pbarbarant@users.noreply.github.com> Date: Wed, 18 Dec 2024 00:17:47 +0100 Subject: [PATCH 23/23] Update examples/plot_template_alignment.py Co-authored-by: bthirion --- examples/plot_template_alignment.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/plot_template_alignment.py b/examples/plot_template_alignment.py index bdbdbd2..b1e6154 100644 --- a/examples/plot_template_alignment.py +++ b/examples/plot_template_alignment.py @@ -66,7 +66,7 @@ for i in range(5): template_train.append(concat_imgs(imgs[i])) -# sub-07 will be our left-out subject. +# sub-07 (that is 5th in the list) will be our left-out subject. # We make a single 4D Niimg from our list of 3D filenames. left_out_subject = concat_imgs(imgs[5])
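Taken together, the series replaces TemplateAlignment's train_index/test_index prediction API with a ParcellationMasker-driven fit, a transform(img, subject_index=None) method, and a get_parcellation helper. A minimal end-to-end sketch of the reworked estimator, using synthetic images in place of real data (the shapes, n_pieces value, and variable names below are illustrative and mirror the new tests rather than a documented recipe):

import numpy as np
import nibabel as nib

from fmralign.template_alignment import TemplateAlignment

rng = np.random.default_rng(0)


def random_4d_img(shape=(8, 7, 6, 20)):
    # Small synthetic 4D image standing in for one subject's contrast maps.
    return nib.Nifti1Image(rng.random(shape), affine=np.eye(4))


subject_imgs = [random_4d_img() for _ in range(3)]
left_out_img = random_4d_img()

# Build a piecewise template; scaled_orthogonal mirrors the updated tutorial.
model = TemplateAlignment(n_pieces=2, alignment_method="scaled_orthogonal")
model.fit(subject_imgs)

# A subject seen during fit is projected through its stored per-parcel estimators.
aligned_known = model.transform(subject_imgs[0], subject_index=0)

# A left-out subject gets a new PairwiseAlignment to the template fitted on the fly.
aligned_new = model.transform(left_out_img)

# Labels and parcellation image underlying the piecewise alignments.
labels, parcellation_img = model.get_parcellation()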