Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Creating data splitters for moabb evaluation #624

Open
wants to merge 29 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
bacedc5
Creating new splitters and base evaluation
brunaafl Jun 6, 2024
419b2ca
Adding metasplitters
brunaafl Jun 7, 2024
d6e795d
Fixing LazyEvaluation
brunaafl Jun 10, 2024
140670c
Merge branch 'NeuroTechX:develop' into eval_splitters
brunaafl Jun 10, 2024
d724674
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 10, 2024
a278026
More optimized version of TimeSeriesSplit
brunaafl Jun 10, 2024
300a6b9
More optimized version of TimeSeriesSplit
brunaafl Jun 10, 2024
7cb79f6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 10, 2024
55db70f
Addressing some comments: documentation, types, inconsistencies
brunaafl Jun 10, 2024
2851a15
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 10, 2024
c73dd1a
Addressing some comments: optimizing code, adjusts
brunaafl Jun 12, 2024
2b0e735
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 12, 2024
cf4b709
Adding examples
brunaafl Jun 26, 2024
177bf65
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 26, 2024
a6b5772
Adding: Pytests for evaluation splitters, and examples for meta split…
brunaafl Aug 15, 2024
26b13d5
Changing: name of TimeSeriesSplit to PseudoOnlineSplit
brunaafl Sep 30, 2024
e6661c4
Merge branch 'develop' into eval_splitters
brunaafl Sep 30, 2024
430e3a8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 30, 2024
698e539
Fixing pre-commit
brunaafl Sep 30, 2024
0fff053
Merge remote-tracking branch 'origin/eval_splitters' into eval_splitters
brunaafl Sep 30, 2024
98d12ac
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 30, 2024
558d27b
Adding some tests for metasplitters
brunaafl Oct 1, 2024
34ea645
Merge remote-tracking branch 'origin/eval_splitters' into eval_splitters
brunaafl Oct 1, 2024
b435bf8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 1, 2024
d8f26a3
Fixing pre-commit
brunaafl Oct 1, 2024
eaf0fb9
Merge remote-tracking branch 'origin/eval_splitters' into eval_splitters
brunaafl Oct 1, 2024
e5159f2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 1, 2024
516a5e8
Fixing pre-commit
brunaafl Oct 1, 2024
b29ecd2
Merge remote-tracking branch 'origin/eval_splitters' into eval_splitters
brunaafl Oct 1, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
187 changes: 187 additions & 0 deletions moabb/evaluations/metasplitters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
import numpy as np
from sklearn.model_selection import (
BaseCrossValidator,
LeaveOneGroupOut,
StratifiedKFold,
StratifiedShuffleSplit,
)


class OfflineSplit(BaseCrossValidator):
    """Offline evaluation split: one test block per (subject, session).

    For every subject, yields the indices of each of that subject's
    sessions, relative to the subject's own subset of the data.

    NOTE(review): unlike a standard scikit-learn splitter, ``split`` yields
    only test indices (no train set) — confirm this matches how the
    evaluation pipeline consumes it.

    Parameters
    ----------
    n_folds : int
        Not used by this splitter; kept for interface symmetry with the
        other splitters.
    """

    def __init__(self, n_folds):
        self.n_folds = n_folds

    def get_n_splits(self, metadata):
        """Return the number of (subject, session) test blocks.

        Counts actual (subject, session) pairs rather than multiplying the
        global unique counts, which over-counts when subjects have
        different numbers of sessions.
        """
        return len(metadata.groupby(["subject", "session"]).first())

    def split(self, X, y, metadata, **kwargs):
        """Yield test indices, one array per (subject, session).

        Indices are relative to each subject's subset of ``X``.
        """
        subjects = metadata.subject

        for subject in subjects.unique():
            X_, y_, meta_ = (
                X[subjects == subject],
                y[subjects == subject],
                metadata[subjects == subject],
            )
            sessions = meta_.session.values

            # Iterate each session exactly once; iterating the raw
            # per-trial values (as the original did) yields a duplicate
            # split for every trial in the session.
            for session in np.unique(sessions):
                ix_test = np.nonzero(sessions == session)[0]

                yield ix_test


class TimeSeriesSplit(BaseCrossValidator):
    """Pseudo-online split: one calibration/test split per session.

    For every (subject, session), only the first fold of the inner CV is
    used: it defines a calibration set and a test set. With several runs,
    LeaveOneGroupOut makes the first run the calibration set; with a single
    run, a non-shuffled stratified k-fold supplies the calibration fold.

    NOTE(review): when the dataset has more than one run, ``n_folds`` is
    ignored (see PR discussion) — this must be stated clearly in user docs.

    Parameters
    ----------
    n_folds : int
        Number of stratified folds used when the data has a single run.
    """

    def __init__(self, n_folds):
        self.n_folds = n_folds

    def get_n_splits(self, metadata):
        """Return one split per run if multi-run, else ``n_folds``."""
        runs = metadata.run.unique()

        if len(runs) > 1:
            splits = len(runs)
        else:
            splits = self.n_folds

        return splits

    def split(self, X, y, metadata, **kwargs):
        """Yield one (test, calibration) index pair per (subject, session).

        Indices are relative to each session's subset of ``X``.
        """
        runs = metadata.run.unique()

        # The CV strategy is chosen once, globally: group-wise on runs for
        # multi-run datasets, stratified k-fold otherwise.
        if len(runs) > 1:
            cv = LeaveOneGroupOut()
        else:
            cv = StratifiedKFold(n_splits=self.n_folds, shuffle=False)

        subjects = metadata.subject

        for subject in subjects.unique():
            X_, y_, meta_ = (
                X[subjects == subject],
                y[subjects == subject],
                metadata[subjects == subject],
            )
            sessions = meta_.session.values

            # ``sessions`` is a numpy array, so use np.unique (ndarray has
            # no ``.unique`` method). Mask the *subject-level* arrays: the
            # original indexed the full X/y/metadata with mismatched masks
            # (e.g. ``y[subjects == session]``).
            for session in np.unique(sessions):
                session_mask = sessions == session
                X_s, y_s, meta_s = (
                    X_[session_mask],
                    y_[session_mask],
                    meta_[session_mask],
                )
                session_runs = meta_s.run.values

                # Pass run labels as groups only when this session really
                # has several distinct runs; the original tested the length
                # of the raw per-trial array, which is > 1 for any session
                # with more than one trial.
                if len(np.unique(session_runs)) > 1:
                    param = (X_s, y_s, session_runs)
                else:
                    param = (X_s, y_s)

                # Only the first fold is consumed — it is the single
                # calibration/test split for this session.
                for ix_test, ix_calib in cv.split(*param):
                    yield ix_test, ix_calib
                    break


class SamplerSplit(BaseCrossValidator):
    """Learning-curve splitter applied per subject.

    Delegates each subject's data to :class:`IndividualSamplerSplit`,
    which draws stratified permutations per session and yields growing
    training subsets per ``data_size``.

    Parameters
    ----------
    test_size : float
        Fraction of each session reserved for testing.
    n_perms : sequence of int
        Numbers of permutations; only the first entry is used.
    data_size : dict | None
        ``{"policy": "ratio" | "per_class", "value": iterable}`` describing
        how the training subsets grow.
    """

    def __init__(self, test_size, n_perms, data_size=None):
        self.data_size = data_size
        self.test_size = test_size
        self.n_perms = n_perms

        # Store the per-subject splitter under its own name: assigning it
        # to ``self.split`` (as the original did) shadows this class's
        # ``split`` method and makes the instance unusable as a CV object.
        self.sampler = IndividualSamplerSplit(
            self.test_size, self.n_perms, data_size=self.data_size
        )

    def get_n_splits(self, y=None):
        """Total splits per subject: permutations x data-size steps."""
        return self.n_perms[0] * len(self.sampler.get_data_size_subsets(y))

    def split(self, X, y, metadata, **kwargs):
        """Yield (train, test) index pairs, subject by subject.

        Indices are relative to each subject's subset of ``X``.
        """
        subjects = metadata.subject.values

        for subject in np.unique(subjects):
            X_, y_, meta_ = (
                X[subjects == subject],
                y[subjects == subject],
                metadata[subjects == subject],
            )

            # Flatten the per-subject generator so callers iterate
            # (train, test) pairs directly, per the scikit-learn contract.
            yield from self.sampler.split(X_, y_, meta_)


class IndividualSamplerSplit(BaseCrossValidator):
    """Per-subject learning-curve splitter.

    For each session, draws ``n_perms[0]`` stratified shuffle splits; for
    each permutation, yields one (train, test) pair per data-size step,
    the train set being a growing subset of the permuted training indices.
    """

    def __init__(self, test_size, n_perms, data_size=None):
        self.data_size = data_size
        self.test_size = test_size
        self.n_perms = n_perms

    def get_n_splits(self, y=None):
        """Splits per session: permutations x data-size steps."""
        return self.n_perms[0] * len(self.get_data_size_subsets(y))

    def get_data_size_subsets(self, y):
        """Return a list of index arrays, one per data-size step.

        Raises
        ------
        ValueError
            If ``data_size`` is missing, has an unknown policy, or
            requests more samples than the smallest class provides.
        """
        if self.data_size is None:
            raise ValueError(
                "Cannot create data subsets without valid policy for data_size."
            )
        if self.data_size["policy"] == "ratio":
            vals = np.array(self.data_size["value"])
            if np.any(vals < 0) or np.any(vals > 1):
                raise ValueError("Data subset ratios must be in range [0, 1]")
            upto = np.ceil(vals * len(y)).astype(int)
            indices = [np.arange(i) for i in upto]
        elif self.data_size["policy"] == "per_class":
            classwise_indices = dict()
            n_smallest_class = np.inf
            for cl in np.unique(y):
                cl_i = np.where(cl == y)[0]
                classwise_indices[cl] = cl_i
                # Track the smallest class so requests can be validated.
                n_smallest_class = min(len(cl_i), n_smallest_class)
            indices = []
            for ds in self.data_size["value"]:
                if ds > n_smallest_class:
                    raise ValueError(
                        f"Smallest class has {n_smallest_class} samples. "
                        f"Desired samples per class {ds} is too large."
                    )
                indices.append(
                    np.concatenate(
                        [classwise_indices[cl][:ds] for cl in classwise_indices]
                    )
                )
        else:
            raise ValueError(f"Unknown policy {self.data_size['policy']}")
        return indices

    def split(self, X, y, metadata, **kwargs):
        """Yield (train, test) pairs per session, permutation and size step.

        Indices are relative to each session's subset of ``X``.
        """
        # Row-level session labels: the original compared the *unique*
        # session values against each session, producing a mask over the
        # wrong axis (the array of labels, not the rows).
        sessions = metadata.session.values

        cv = StratifiedShuffleSplit(n_splits=self.n_perms[0], test_size=self.test_size)

        for session in np.unique(sessions):
            X_, y_, meta_ = (
                X[sessions == session],
                y[sessions == session],
                metadata[sessions == session],
            )

            for ix_train, ix_test in cv.split(X_, y_):
                y_split = y_[ix_train]
                data_size_steps = self.get_data_size_subsets(y_split)
                for subset_indices in data_size_steps:
                    # Do not rebind ``ix_train``: shrinking it in place (as
                    # the original did) corrupts the subsequent, larger
                    # subsets of the same permutation.
                    yield ix_train[subset_indices], ix_test
126 changes: 126 additions & 0 deletions moabb/evaluations/splitters.py
brunaafl marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
import numpy as np
from sklearn.model_selection import (
BaseCrossValidator,
GroupKFold,
LeaveOneGroupOut,
StratifiedKFold,
)


class WithinSubjectSplitter(BaseCrossValidator):
    """Within-subject, within-session cross-validation.

    Delegates each subject's data to
    :class:`IndividualWithinSubjectSplitter`, which runs a stratified
    k-fold inside every session.

    Parameters
    ----------
    n_folds : int
        Number of stratified folds per session.
    """

    def __init__(self, n_folds):
        self.n_folds = n_folds

    def get_n_splits(self, metadata):
        """Total folds: ``n_folds`` for every (subject, session) pair."""
        sessions_subjects = len(metadata.groupby(["subject", "session"]).first())
        return self.n_folds * sessions_subjects

    def split(self, X, y, metadata, **kwargs):
        """Yield (train, test) index pairs, subject by subject.

        Indices are relative to each subject's subset of ``X``.
        """
        subjects = metadata.subject.values

        split = IndividualWithinSubjectSplitter(self.n_folds)

        for subject in np.unique(subjects):
            X_, y_, meta_ = (
                X[subjects == subject],
                y[subjects == subject],
                metadata[subjects == subject],
            )

            # ``yield from`` flattens the per-subject splits; the original
            # yielded the inner generator object itself, which breaks the
            # scikit-learn convention of iterating (train, test) pairs.
            yield from split.split(X_, y_, meta_)


class IndividualWithinSubjectSplitter(BaseCrossValidator):
    """Stratified k-fold within each session of one subject's data.

    Parameters
    ----------
    n_folds : int
        Number of stratified folds per session.
    """

    def __init__(self, n_folds):
        self.n_folds = n_folds

    def get_n_splits(self, metadata):
        """Return the number of folds per session."""
        return self.n_folds

    def split(self, X, y, metadata, **kwargs):
        """Yield (train, test) pairs, ``n_folds`` per session.

        Indices are relative to each session's subset of ``X``.
        """
        # Group by session labels; the original read ``metadata.subject``
        # here, which folds across sessions instead of within them.
        sessions = metadata.session.values

        cv = StratifiedKFold(self.n_folds, **kwargs)

        for session in np.unique(sessions):
            X_, y_, meta_ = (
                X[sessions == session],
                y[sessions == session],
                metadata[sessions == session],
            )

            for ix_train, ix_test in cv.split(X_, y_):
                yield ix_train, ix_test


class CrossSessionSplitter(BaseCrossValidator):
    """Leave-one-session-out evaluation, run independently per subject.

    Parameters
    ----------
    n_folds : int
        Not used — leave-one-session-out fixes the number of folds; kept
        for interface symmetry with the other splitters.
    """

    def __init__(self, n_folds):
        self.n_folds = n_folds

    def get_n_splits(self, metadata):
        """One split per (subject, session) pair."""
        sessions_subjects = len(metadata.groupby(["subject", "session"]).first())
        return sessions_subjects

    def split(self, X, y, metadata, **kwargs):
        """Yield (train, test) pairs, leave-one-session-out per subject.

        Indices are relative to each subject's subset of ``X``.
        """
        subjects = metadata.subject.values
        split = IndividualCrossSessionSplitter(self.n_folds)

        for subject in np.unique(subjects):
            X_, y_, meta_ = (
                X[subjects == subject],
                y[subjects == subject],
                metadata[subjects == subject],
            )

            # Flatten the per-subject generator so callers iterate
            # (train, test) pairs directly, per the scikit-learn contract;
            # the original yielded the generator object itself.
            yield from split.split(X_, y_, meta_)


class IndividualCrossSessionSplitter(BaseCrossValidator):
    """Leave-one-session-out for a single subject's data.

    Parameters
    ----------
    n_folds : int
        Not used (there is one fold per session); kept for interface
        symmetry with the other splitters.
    """

    def __init__(self, n_folds):
        self.n_folds = n_folds

    def get_n_splits(self, metadata):
        """Return the number of sessions (one split each).

        Returns a count; the original returned the array of unique session
        labels, which is not a number of splits.
        """
        sessions = metadata.session.values
        return len(np.unique(sessions))

    def split(self, X, y, metadata, **kwargs):
        """Yield (train, test) pairs, holding out one session per fold."""
        cv = LeaveOneGroupOut()

        sessions = metadata.session.values

        for ix_train, ix_test in cv.split(X, y, groups=sessions):
            yield ix_train, ix_test


class CrossSubjectSplitter(BaseCrossValidator):
    """Leave-subjects-out evaluation splitter.

    Each fold holds out one subject (when ``n_groups`` is ``None``) or one
    group of subjects (when ``n_groups`` is set), training on the rest.

    Parameters
    ----------
    n_groups : int | None
        Number of subject groups for ``GroupKFold``; ``None`` selects
        leave-one-subject-out.
    """

    def __init__(self, n_groups=None):
        self.n_groups = n_groups

    def get_n_splits(self, dataset=None):
        # NOTE(review): returns None in the leave-one-subject-out case —
        # the actual count then depends on the metadata; verify callers
        # handle this.
        return self.n_groups

    def split(self, X, y, metadata, **kwargs):
        """Yield (train, test) index pairs over whole subjects."""
        subject_labels = metadata.subject.values

        # Leave-one-subject-out unless a fixed number of groups was given.
        splitter = (
            LeaveOneGroupOut()
            if self.n_groups is None
            else GroupKFold(n_splits=self.n_groups)
        )

        yield from splitter.split(metadata, groups=subject_labels)
Loading
Loading