-
Notifications
You must be signed in to change notification settings - Fork 185
/
Copy pathplot_Hinss2021_classification.py
193 lines (154 loc) · 5.96 KB
/
plot_Hinss2021_classification.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
"""
================================
Hinss2021 classification example
================================

This example shows how to use the :class:`moabb.datasets.Hinss2021` dataset
with the resting state paradigm, adapted to a P300-like classification task.
We aim to determine the most effective channel selection strategy for this
dataset. The pipelines under consideration are:

- ``Xdawn`` spatial filtering;
- electrode selection on the time epochs (before computing covariances);
- electrode selection on the covariance matrices.
"""
# License: BSD (3-clause)
import warnings

import numpy as np
import seaborn as sns
from matplotlib import pyplot as plt
from pyriemann.channelselection import ElectrodeSelection
from pyriemann.estimation import Covariances
from pyriemann.spatialfilters import Xdawn
from pyriemann.tangentspace import TangentSpace
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA
from sklearn.pipeline import make_pipeline

from moabb import set_log_level
from moabb.datasets import Hinss2021
from moabb.evaluations import CrossSessionEvaluation
from moabb.paradigms import RestingStateToP300Adapter
# Silence future and runtime warnings so that the example output stays readable.
for _category in (FutureWarning, RuntimeWarning):
    warnings.simplefilter(action="ignore", category=_category)
del _category
set_log_level("info")
##############################################################################
# Create a utility transformer
# ----------------------------
#
# Let's create a scikit-learn transformer that selects electrodes
# based on the covariance information of the time epochs.
class EpochSelectChannel(TransformerMixin, BaseEstimator):
    """Select channels based on covariance information.

    During ``fit``, the covariance matrix of every epoch is estimated and the
    matrices are averaged over epochs. The entries of the upper triangle of
    this average (diagonal included, so each channel pair is considered only
    once) are visited from the largest to the smallest value, and the channels
    involved in each entry are kept until ``n_chan`` distinct channels have
    been collected. ``transform`` then keeps only these channels, so the
    selection happens on the time epochs, before the pipeline computes any
    covariance.

    Parameters
    ----------
    n_chan : int
        Number of channels to keep. If it is larger than the number of
        available channels, all channels are kept.
    cov_est : str
        Estimator passed to :class:`pyriemann.estimation.Covariances`
        (e.g. ``"lwf"``).
    """

    def __init__(self, n_chan, cov_est):
        self._chs_idx = None
        self.n_chan = n_chan
        self.cov_est = cov_est

    def fit(self, X, _y=None):
        """Choose the channels to keep.

        Parameters
        ----------
        X : ndarray, shape (n_epochs, n_channels, n_times)
            Epoched signals.
        _y : ignored
            Present for scikit-learn API compatibility.

        Returns
        -------
        self : EpochSelectChannel
            The fitted transformer.
        """
        # Get the covariances of the channels for each epoch, then average
        # them over the epochs.
        covs = Covariances(estimator=self.cov_est).fit_transform(X)
        mean_cov = np.mean(covs, axis=0)
        n_channels = mean_cov.shape[0]
        n_keep = max(0, min(int(self.n_chan), n_channels))

        # The covariance matrix is symmetric: only look at its upper triangle
        # (diagonal included) so that each channel pair is considered once.
        rows, cols = np.triu_indices(n_channels)
        # Stable sort of the negated values: descending order with ties
        # broken deterministically by position.
        order = np.argsort(-mean_cov[rows, cols], kind="stable")

        # Walk the entries from the largest covariance down and collect the
        # channels involved, stopping as soon as `n_keep` channels are found.
        # Every channel has a diagonal entry, so the walk always reaches
        # `n_keep` channels.
        selected = []
        for k in order:
            for ch in (int(rows[k]), int(cols[k])):
                if len(selected) < n_keep and ch not in selected:
                    selected.append(ch)
            if len(selected) >= n_keep:
                break

        # We will keep only these channels (in ascending order) for the
        # transform step.
        self._chs_idx = np.array(sorted(selected), dtype=int)
        return self

    def transform(self, X):
        """Return ``X`` restricted to the channels selected during ``fit``.

        Raises
        ------
        RuntimeError
            If the transformer has not been fitted yet.
        """
        # Indexing with `None` would silently insert a new axis instead of
        # selecting channels, so refuse to transform before fitting.
        if self._chs_idx is None:
            raise RuntimeError(
                "EpochSelectChannel must be fitted before calling transform."
            )
        return X[:, self._chs_idx, :]
##############################################################################
# Initialization Process
# ----------------------
#
# 1) Define the experimental paradigm object (RestingState)
# 2) Load the datasets
# 3) Select a subset of subjects and specific events for analysis
# Here we define the mne events for the RestingState paradigm.
events = dict(easy=2, diff=3)
# The resting state paradigm is adapted to the P300 paradigm, with epochs
# spanning tmin=0 to tmax=0.5 (seconds) relative to each event.
paradigm = RestingStateToP300Adapter(events=events, tmin=0, tmax=0.5)
# We define a list with the dataset to use
datasets = [Hinss2021()]
# To reduce the computation time in the example, we only use the first two
# subjects of each dataset.
n_subjects = 2
for dataset in datasets:
    dataset.subject_list = dataset.subject_list[:n_subjects]
# Plot title listing the codes of the datasets used, separated by commas.
title = "Datasets: " + ", ".join(dataset.code for dataset in datasets)
##############################################################################
# Create Pipelines
# ----------------
#
# Pipelines must be a dict mapping a name to a scikit-learn pipeline.
pipelines = {
    # Xdawn spatial filtering, then covariance, tangent space and LDA.
    "Xdawn+Cov+TS+LDA": make_pipeline(
        Xdawn(nfilter=4), Covariances(estimator="lwf"), TangentSpace(), LDA()
    ),
    # Electrode selection performed on the covariance matrices.
    "Cov+ElSel+TS+LDA": make_pipeline(
        Covariances(estimator="lwf"), ElectrodeSelection(nelec=8), TangentSpace(), LDA()
    ),
    # Electrode selection performed on the time epochs, i.e. *before* the
    # covariances are computed.
    "ElSel+Cov+TS+LDA": make_pipeline(
        EpochSelectChannel(n_chan=8, cov_est="lwf"),
        Covariances(estimator="lwf"),
        TangentSpace(),
        LDA(),
    ),
}
##############################################################################
# Run evaluation
# ----------------
#
# Compare the pipelines using a cross-session evaluation.
# Evaluate every pipeline on the selected datasets; ``overwrite=False`` keeps
# results that were already computed instead of recomputing them.
evaluation = CrossSessionEvaluation(
    datasets=datasets,
    paradigm=paradigm,
    overwrite=False,
)
results = evaluation.process(pipelines)
###############################################################################
# Here, the ElSel+Cov+TS+LDA pipeline reduces the computation time by a
# factor of approximately 8 compared to the Cov+ElSel+TS+LDA pipeline.
print("Averaging the session performance:")
# Select the columns of interest first, then average them per pipeline.
# (``GroupBy.mean`` does not take a column name: its first argument is the
# ``numeric_only`` flag.)
print(results.groupby("pipeline")[["score", "time"]].mean())
###############################################################################
# Plot Results
# -------------
#
# Here, we plot the results to compare the three pipelines.
fig, ax = plt.subplots(facecolor="white", figsize=[8, 4])
# Arguments shared by both seaborn calls.
plot_kwargs = dict(data=results, x="pipeline", y="score", ax=ax, palette="Set1")
# Individual scores, jittered and semi-transparent, drawn underneath.
sns.stripplot(**plot_kwargs, jitter=True, alpha=0.5, zorder=1)
# Point estimate per pipeline, drawn on top, and the figure title.
sns.pointplot(**plot_kwargs).set(title=title)
ax.set_ylabel("ROC AUC")
ax.set_ylim(0.3, 1)
plt.show()
###############################################################################
# Key Observations:
# -----------------
# - ``Xdawn`` is not well suited to the resting state paradigm, because it is specifically designed for event-related potentials (ERPs).
# - The electrode selection strategy based on covariance matrices shows less variability and typically yields better performance.
# - However, this strategy is more time-consuming than the simpler electrode selection based on the time epochs.