Skip to content

Commit

Permalink
Refactor variable names from 'pmasker' to 'parcel_masker'
Browse files Browse the repository at this point in the history
  • Loading branch information
Pierre-Louis Barbarant committed Dec 5, 2024
1 parent c2a0841 commit fd0c2f3
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 74 deletions.
22 changes: 12 additions & 10 deletions fmralign/pairwise_alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ def fit(self, X, Y):
-------
self
"""
self.pmasker = ParcellationMasker(
self.parcel_masker = ParcellationMasker(
n_pieces=self.n_pieces,
clustering=self.clustering,
mask=self.mask,
Expand All @@ -218,11 +218,13 @@ def fit(self, X, Y):
verbose=self.verbose,
)

parceled_source, parceled_target = self.pmasker.fit_transform([X, Y])
self.masker = self.pmasker.masker_
self.mask = self.pmasker.masker_.mask_img_
self.labels_ = self.pmasker.labels
self.n_pieces = self.pmasker.n_pieces
parceled_source, parceled_target = self.parcel_masker.fit_transform(
[X, Y]
)
self.masker = self.parcel_masker.masker_
self.mask = self.parcel_masker.masker_.mask_img_
self.labels_ = self.parcel_masker.labels
self.n_pieces = self.parcel_masker.n_pieces

self.fit_ = Parallel(
self.n_jobs, prefer="threads", verbose=self.verbose
Expand Down Expand Up @@ -253,7 +255,7 @@ def transform(self, X):
"This instance has not been fitted yet. "
"Please call 'fit' before 'transform'."
)
parceled_data_list = self.pmasker.transform(X)
parceled_data_list = self.parcel_masker.transform(X)
transformed_img = Parallel(
self.n_jobs, prefer="threads", verbose=self.verbose
)(
Expand Down Expand Up @@ -286,10 +288,10 @@ def get_parcellation(self):
parcellation_img: Niimg-like object
Parcellation image.
"""
if hasattr(self, "pmasker"):
if hasattr(self, "parcel_masker"):
check_is_fitted(self)
labels = self.pmasker.get_labels()
parcellation_img = self.pmasker.get_parcellation()
labels = self.parcel_masker.get_labels()
parcellation_img = self.parcel_masker.get_parcellation()
return labels, parcellation_img
else:
raise AttributeError(
Expand Down
20 changes: 11 additions & 9 deletions fmralign/srm.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ def fit(self, imgs):
Length : n_samples
"""
self.pmasker = ParcellationMasker(
self.parcel_masker = ParcellationMasker(
n_pieces=self.n_pieces,
clustering=self.clustering,
mask=self.mask,
Expand All @@ -217,11 +217,11 @@ def fit(self, imgs):
n_jobs=self.n_jobs,
verbose=self.verbose,
)
parceled_data = self.pmasker.fit_transform(imgs)
self.masker_ = self.pmasker.masker_
self.mask = self.pmasker.masker_.mask_img_
self.labels_ = self.pmasker.labels
self.n_pieces = self.pmasker.n_pieces
parceled_data = self.parcel_masker.fit_transform(imgs)
self.masker_ = self.parcel_masker.masker_
self.mask = self.parcel_masker.masker_.mask_img_
self.labels_ = self.parcel_masker.labels
self.n_pieces = self.parcel_masker.n_pieces

outputs = Parallel(
n_jobs=self.n_jobs, prefer="threads", verbose=self.verbose
Expand All @@ -233,7 +233,7 @@ def fit(self, imgs):
for i in range(self.n_pieces)
)

self.labels_ = self.pmasker.labels
self.labels_ = self.parcel_masker.labels
self.fit_ = [output[0] for output in outputs]
self.reduced_sr = [output[1] for output in outputs]
return self
Expand All @@ -242,7 +242,9 @@ def add_subjects(self, imgs):
"""Add subject without recalculating SR"""
for i in range(self.n_pieces):
self.fit_[i]
X_i = _get_parcel_across_subjects(self.pmasker.transform(imgs), i)
X_i = _get_parcel_across_subjects(
self.parcel_masker.transform(imgs), i
)
srm = self.fit_[i]
srm.add_subjects(X_i, self.reduced_sr[i])
return self
Expand All @@ -262,7 +264,7 @@ def transform(self, imgs):

n_comps = self.srm.n_components
aligned_imgs = []
imgs_prep = self.pmasker.transform(imgs)
imgs_prep = self.parcel_masker.transform(imgs)
bag_align = []
for i, piece_srm in enumerate(self.fit_):
X_i = [parceled_data[i].T for parceled_data in imgs_prep]
Expand Down
112 changes: 57 additions & 55 deletions fmralign/tests/test_preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,56 +10,56 @@

def test_init_default_params():
"""Test that ParcellationMasker initializes with default parameters"""
pmasker = ParcellationMasker()
assert pmasker.n_pieces == 1
assert pmasker.clustering == "kmeans"
assert pmasker.mask is None
assert pmasker.smoothing_fwhm is None
assert pmasker.standardize is False
assert pmasker.detrend is False
assert pmasker.labels is None
parcel_masker = ParcellationMasker()
assert parcel_masker.n_pieces == 1
assert parcel_masker.clustering == "kmeans"
assert parcel_masker.mask is None
assert parcel_masker.smoothing_fwhm is None
assert parcel_masker.standardize is False
assert parcel_masker.detrend is False
assert parcel_masker.labels is None


def test_init_custom_params():
"""Test that ParcellationMasker initializes with custom parameters"""
pmasker = ParcellationMasker(
parcel_masker = ParcellationMasker(
n_pieces=2, clustering="ward", standardize=True, detrend=True, n_jobs=2
)
assert pmasker.n_pieces == 2
assert pmasker.clustering == "ward"
assert pmasker.standardize is True
assert pmasker.detrend is True
assert pmasker.n_jobs == 2
assert parcel_masker.n_pieces == 2
assert parcel_masker.clustering == "ward"
assert parcel_masker.standardize is True
assert parcel_masker.detrend is True
assert parcel_masker.n_jobs == 2


def test_fit_single_image():
"""Test that ParcellationMasker fits a single image"""
img, _ = random_niimg((8, 7, 6))
pmasker = ParcellationMasker(n_pieces=2)
fitted_pmasker = pmasker.fit(img)
parcel_masker = ParcellationMasker(n_pieces=2)
fitted_parcel_masker = parcel_masker.fit(img)

assert hasattr(fitted_pmasker, "masker_")
assert fitted_pmasker.labels is not None
assert isinstance(fitted_pmasker.labels, np.ndarray)
assert len(np.unique(fitted_pmasker.labels)) == 2 # n_pieces=2
assert hasattr(fitted_parcel_masker, "masker_")
assert fitted_parcel_masker.labels is not None
assert isinstance(fitted_parcel_masker.labels, np.ndarray)
assert len(np.unique(fitted_parcel_masker.labels)) == 2 # n_pieces=2


def test_fit_multiple_images():
"""Test that ParcellationMasker fits multiple images"""
imgs = [random_niimg((8, 7, 6))[0]] * 3
pmasker = ParcellationMasker(n_pieces=2)
fitted_pmasker = pmasker.fit(imgs)
parcel_masker = ParcellationMasker(n_pieces=2)
fitted_parcel_masker = parcel_masker.fit(imgs)

assert hasattr(fitted_pmasker, "masker_")
assert fitted_pmasker.labels is not None
assert hasattr(fitted_parcel_masker, "masker_")
assert fitted_parcel_masker.labels is not None


def test_transform_single_image():
"""Test that ParcellationMasker transforms a single image"""
img, _ = random_niimg((8, 7, 6))
pmasker = ParcellationMasker(n_pieces=2)
pmasker.fit(img)
transformed_data = pmasker.transform(img)
parcel_masker = ParcellationMasker(n_pieces=2)
parcel_masker.fit(img)
transformed_data = parcel_masker.transform(img)

assert isinstance(transformed_data, list)
assert len(transformed_data) == 1
Expand All @@ -69,9 +69,9 @@ def test_transform_single_image():
def test_transform_multiple_images():
"""Test that ParcellationMasker transforms multiple images"""
imgs = [random_niimg((8, 7, 6))[0]] * 3
pmasker = ParcellationMasker(n_pieces=2)
pmasker.fit(imgs)
transformed_data = pmasker.transform(imgs)
parcel_masker = ParcellationMasker(n_pieces=2)
parcel_masker.fit(imgs)
transformed_data = parcel_masker.transform(imgs)

assert isinstance(transformed_data, list)
assert len(transformed_data) == 3
Expand All @@ -83,16 +83,16 @@ def test_transform_multiple_images():

def test_get_labels_before_fit():
"""Test that ParcellationMasker raises ValueError if get_labels is called before fit"""
pmasker = ParcellationMasker()
parcel_masker = ParcellationMasker()
with pytest.raises(ValueError, match="Labels have not been computed yet"):
pmasker.get_labels()
parcel_masker.get_labels()


def test_get_labels_after_fit():
img, _ = random_niimg((8, 7, 6))
pmasker = ParcellationMasker(n_pieces=2)
pmasker.fit(img)
labels = pmasker.get_labels()
parcel_masker = ParcellationMasker(n_pieces=2)
parcel_masker.fit(img)
labels = parcel_masker.get_labels()

assert labels is not None
assert isinstance(labels, np.ndarray)
Expand All @@ -108,13 +108,13 @@ def test_different_shaped_images():
different_img = Nifti1Image(different_data, np.eye(4))

imgs = [img, different_img]
pmasker = ParcellationMasker()
parcel_masker = ParcellationMasker()

with pytest.raises(
NotImplementedError,
match="fmralign does not support images of different shapes",
):
pmasker.fit(imgs)
parcel_masker.fit(imgs)


def test_clustering_with_mask():
Expand All @@ -124,20 +124,22 @@ def test_clustering_with_mask():
clustering_data[5:, :, :] = 0
clustering_img = Nifti1Image(clustering_data, np.eye(4))
img, dummy_mask = random_niimg((8, 7, 6))
pmasker = ParcellationMasker(clustering=clustering_img, mask=dummy_mask)
parcel_masker = ParcellationMasker(
clustering=clustering_img, mask=dummy_mask
)
with pytest.warns(
UserWarning, match="Mask used was bigger than clustering provided"
):
pmasker.fit(img)
parcel_masker.fit(img)


def test_memory_caching(tmp_path):
"""Test that ParcellationMasker can use joblib memory caching"""
img, _ = random_niimg((8, 7, 6))
# Test that memory caching works
memory = Memory(location=str(tmp_path), verbose=0)
pmasker = ParcellationMasker(memory=memory, memory_level=1)
pmasker.fit(img)
parcel_masker = ParcellationMasker(memory=memory, memory_level=1)
parcel_masker.fit(img)

# Check that cache directory is not empty
cache_files = list(tmp_path.glob("joblib/*"))
Expand All @@ -148,19 +150,19 @@ def test_memory_caching(tmp_path):
def test_parallel_processing(n_jobs):
"""Test parallel processing with joblib"""
imgs = [random_niimg((8, 7, 6))[0]] * 3
pmasker = ParcellationMasker(n_pieces=2, n_jobs=n_jobs)
pmasker.fit(imgs)
transformed_data = pmasker.transform(imgs)
parcel_masker = ParcellationMasker(n_pieces=2, n_jobs=n_jobs)
parcel_masker.fit(imgs)
transformed_data = parcel_masker.transform(imgs)

assert len(transformed_data) == 3


def test_smoothing_parameter():
"""Test that ParcellationMasker applies smoothing"""
img, _ = random_niimg((8, 7, 6))
pmasker = ParcellationMasker(smoothing_fwhm=4.0)
pmasker.fit(img)
transformed_data = pmasker.transform(img)
parcel_masker = ParcellationMasker(smoothing_fwhm=4.0)
parcel_masker.fit(img)
transformed_data = parcel_masker.transform(img)

assert isinstance(transformed_data, list)
assert len(transformed_data) == 1
Expand All @@ -169,9 +171,9 @@ def test_smoothing_parameter():
def test_standardization():
"""Test that ParcellationMasker standardizes data"""
img, _ = random_niimg((8, 7, 6, 20))
pmasker = ParcellationMasker(standardize=True)
pmasker.fit(img)
transformed_data = pmasker.transform(img)
parcel_masker = ParcellationMasker(standardize=True)
parcel_masker.fit(img)
transformed_data = parcel_masker.transform(img)

# Check if data is standardized (mean ≈ 0, std ≈ 1)
data_array = transformed_data[0].data
Expand All @@ -183,15 +185,15 @@ def test_get_parcellation():
"""Test that ParcellationMasker returns the parcellation mask"""
n_pieces = 2
img, _ = random_niimg((8, 7, 6))
pmasker = ParcellationMasker(n_pieces=n_pieces)
pmasker.fit(img)
parcellation_img = pmasker.get_parcellation()
labels = pmasker.get_labels()
parcel_masker = ParcellationMasker(n_pieces=n_pieces)
parcel_masker.fit(img)
parcellation_img = parcel_masker.get_parcellation()
labels = parcel_masker.get_labels()

assert isinstance(parcellation_img, Nifti1Image)
assert parcellation_img.shape == img.shape

masker = pmasker.masker_
masker = parcel_masker.masker_
data = masker.transform(parcellation_img)

assert np.allclose(data, labels)
Expand Down

0 comments on commit fd0c2f3

Please sign in to comment.