Skip to content

Commit

Permalink
Merge pull request #116 from Parietal-INRIA/feat/one-parcellation-tem…
Browse files Browse the repository at this point in the history
…plate

Refactor template alignments to keep a fixed parcellation
  • Loading branch information
pbarbarant authored Dec 18, 2024
2 parents a629598 + 2ba236d commit 51b01ff
Show file tree
Hide file tree
Showing 7 changed files with 440 additions and 359 deletions.
80 changes: 29 additions & 51 deletions examples/plot_template_alignment.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
# -*- coding: utf-8 -*-

"""
Template-based prediction.
==========================
In this tutorial, we show how to better predict new contrasts for a target
subject using many source subjects corresponding contrasts. For this purpose,
we create a template to which we align the target subject, using shared information.
We then predict new images for the target and compare them to a baseline.
In this tutorial, we show how to improve inter-subject similarity using a template
computed across multiple source subjects. For this purpose, we create a template
using Procrustes alignment (hyperalignment) to which we align the target subject,
using shared information. We then compare the voxelwise similarity between the
target subject and the template to the similarity between the target subject and
the anatomical Euclidean average of the source subjects.
We mostly rely on Python common packages and on nilearn to handle
functional data in a clean fashion.
Expand Down Expand Up @@ -36,7 +37,7 @@
)

###############################################################################
# Definine a masker
# Define a masker
# -----------------
# We define a nilearn masker that will be used to handle relevant data.
# For more information, visit :
Expand Down Expand Up @@ -64,22 +65,17 @@
template_train = []
for i in range(5):
template_train.append(concat_imgs(imgs[i]))
target_train = df[df.subject == "sub-07"][df.acquisition == "ap"].path.values

# For subject sub-07, we split it in two folds:
# - target train: sub-07 AP contrasts, used to learn alignment to template
# - target test: sub-07 PA contrasts, used as a ground truth to score predictions
# We make a single 4D Niimg from our list of 3D filenames
# sub-07 (that is 5th in the list) will be our left-out subject.
# We make a single 4D Niimg from our list of 3D filenames.

target_train = concat_imgs(target_train)
target_test = df[df.subject == "sub-07"][df.acquisition == "pa"].path.values
left_out_subject = concat_imgs(imgs[5])

###############################################################################
# Compute a baseline (average of subjects)
# ----------------------------------------
# We create an image with as many contrasts as any subject representing for
# each contrast the average of all train subjects maps.
#

import numpy as np

Expand All @@ -92,70 +88,53 @@
# ---------------------------------------------
# We define an estimator using the class TemplateAlignment:
# * We align the whole brain through 'multiple' local alignments.
# * These alignments are calculated on a parcellation of the brain in 150 pieces,
# * These alignments are calculated on a parcellation of the brain in 50 pieces,
# this parcellation creates group of functionnally similar voxels.
# * The template is created iteratively, aligning all subjects data into a
# common space, from which the template is inferred and aligning again to this
# new template space.
#

from nilearn.image import index_img

from fmralign.template_alignment import TemplateAlignment

# We use Procrustes/scaled orthogonal alignment method
template_estim = TemplateAlignment(
n_pieces=150, alignment_method="ridge_cv", mask=masker
n_pieces=50,
alignment_method="scaled_orthogonal",
mask=masker,
)
template_estim.fit(template_train)
procrustes_template = template_estim.template

###############################################################################
# Predict new data for left-out subject
# -------------------------------------
# We use target_train data to fit the transform, indicating it corresponds to
# the contrasts indexed by train_index and predict from this learnt alignment
# contrasts corresponding to template test_index numbers.
# For each train subject and for the template, the AP contrasts are sorted from
# 0, to 53, and then the PA contrasts from 53 to 106.
#

train_index = range(53)
test_index = range(53, 106)

# We input the mapping image target_train in a list, we could have input more
# than one subject for which we'd want to predict : [train_1, train_2 ...]
# We predict the contrasts of the left-out subject using the template we just
# created. We use the transform method of the estimator. This method takes the
# left-out subject as input, computes a pairwise alignment with the template
# and returns the aligned data.

prediction_from_template = template_estim.transform(
[target_train], train_index, test_index
)

# As a baseline prediction, let's just take the average of activations across subjects.

prediction_from_average = index_img(average_subject, test_index)
predictions_from_template = template_estim.transform(left_out_subject)

###############################################################################
# Score the baseline and the prediction
# -------------------------------------
# We use a utility scoring function to measure the voxelwise correlation
# between the prediction and the ground truth. That is, for each voxel, we
# measure the correlation between its profile of activation without and with
# alignment, to see if alignment was able to predict a signal more alike the ground truth.
#
# between the images. That is, for each voxel, we measure the correlation between
# its profile of activation without and with alignment, to see if template-based
# alignment was able to improve inter-subject similarity.

from fmralign.metrics import score_voxelwise

# Now we use this scoring function to compare the correlation of predictions
# made from group average and from template with the real PA contrasts of sub-07

average_score = masker.inverse_transform(
score_voxelwise(target_test, prediction_from_average, masker, loss="corr")
score_voxelwise(left_out_subject, average_subject, masker, loss="corr")
)
template_score = masker.inverse_transform(
score_voxelwise(
target_test, prediction_from_template[0], masker, loss="corr"
predictions_from_template, procrustes_template, masker, loss="corr"
)
)


###############################################################################
# Plotting the measures
# ---------------------
Expand All @@ -167,13 +146,12 @@
baseline_display = plotting.plot_stat_map(
average_score, display_mode="z", vmax=1, cut_coords=[-15, -5]
)
baseline_display.title("Group average correlation wt ground truth")
baseline_display.title("Left-out subject correlation with group average")
display = plotting.plot_stat_map(
template_score, display_mode="z", cut_coords=[-15, -5], vmax=1
)
display.title("Template-based prediction correlation wt ground truth")
display.title("Aligned subject correlation with Procrustes template")

###############################################################################
# We observe that creating a template and aligning a new subject to it yields
# a prediction that is better correlated with the ground truth than just using
# the average activations of subjects.
# better inter-subject similarity than regular euclidean averaging.
9 changes: 4 additions & 5 deletions fmralign/pairwise_alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,12 +237,12 @@ def fit(self, X, Y):

return self

def transform(self, X):
def transform(self, img):
"""Predict data from X.
Parameters
----------
X: Niimg-like object
img: Niimg-like object
Source data
Returns
Expand All @@ -255,7 +255,7 @@ def transform(self, X):
"This instance has not been fitted yet. "
"Please call 'fit' before 'transform'."
)
parceled_data_list = self.parcel_masker.transform(X)
parceled_data_list = self.parcel_masker.transform(img)
transformed_img = Parallel(
self.n_jobs, prefer="threads", verbose=self.verbose
)(
Expand All @@ -266,7 +266,6 @@ def transform(self, X):
return transformed_img[0]
else:
return transformed_img
return transformed_img

# Make inherited function harmless
def fit_transform(self):
Expand All @@ -291,7 +290,7 @@ def get_parcellation(self):
if hasattr(self, "parcel_masker"):
check_is_fitted(self)
labels = self.parcel_masker.get_labels()
parcellation_img = self.parcel_masker.get_parcellation()
parcellation_img = self.parcel_masker.get_parcellation_img()
return labels, parcellation_img
else:
raise AttributeError(
Expand Down
2 changes: 1 addition & 1 deletion fmralign/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ def get_labels(self):
)
return self.labels

def get_parcellation(self):
def get_parcellation_img(self):
"""Return the parcellation image.
Returns
Expand Down
Loading

0 comments on commit 51b01ff

Please sign in to comment.