From eb83ba324e18eae7e933d3daf658d9b6c039d83a Mon Sep 17 00:00:00 2001 From: ilkilic <10600022+ilkilic@users.noreply.github.com> Date: Fri, 2 Feb 2024 14:38:14 +0100 Subject: [PATCH] Add pagination to plot_recordings (#172) --- bluepyefe/cell.py | 71 ++++++++++++++++++++++++----------------------- 1 file changed, 37 insertions(+), 34 deletions(-) diff --git a/bluepyefe/cell.py b/bluepyefe/cell.py index b2eae03..90d594c 100644 --- a/bluepyefe/cell.py +++ b/bluepyefe/cell.py @@ -224,55 +224,58 @@ def plot_recordings(self, protocol_name, output_dir=None, show=False): recordings = self.get_recordings_by_protocol_name(protocol_name) if not len(recordings): - return None, None + return recordings_amp = [rec.amp for rec in recordings] - recordings = [recordings[k] for k in numpy.argsort(recordings_amp)] + recordings_sorted = [recordings[k] for k in numpy.argsort(recordings_amp)] n_cols = 6 - n_rows = int(2 * numpy.ceil(len(recordings) / n_cols)) + max_plots_per_page = 24 + total_pages = int(numpy.ceil(len(recordings_sorted) / max_plots_per_page)) - fig, axs = plt.subplots( - n_rows, n_cols, - figsize=[3.0 + 3.0 * int(n_cols), 2.5 * n_rows], - squeeze=False - ) - - for i, rec in enumerate(recordings): - - col = i % int(n_cols) - row = 2 * int(i / n_cols) + for page in range(total_pages): + start_idx = page * max_plots_per_page + end_idx = start_idx + max_plots_per_page + page_recordings = recordings_sorted[start_idx:end_idx] - display_ylabel = col == 0 - display_xlabel = row + 1 == axs.shape[0] + n_rows = int(numpy.ceil(len(page_recordings) / n_cols)) * 2 - _, _ = rec.plot( - axis_current=axs[row][col], - axis_voltage=axs[row + 1][col], - display_xlabel=display_xlabel, - display_ylabel=display_ylabel + fig, axs = plt.subplots( + n_rows, n_cols, + figsize=[3.0 * n_cols, 2.5 * n_rows], + squeeze=False ) - fig.suptitle("Cell: {}, Experiment: {}".format(self.name, protocol_name)) + for i, rec in enumerate(page_recordings): + col = i % n_cols + row = (i // n_cols) * 2 - plt.subplots_adjust(wspace=0.53, hspace=0.7) + display_ylabel = col == 0 + display_xlabel = (row // 2) + 1 == n_rows // 2 + + rec.plot( + axis_current=axs[row][col], + axis_voltage=axs[row + 1][col], + display_xlabel=display_xlabel, + display_ylabel=display_ylabel + ) - for ax in axs.flatten(): - if not ax.lines: - ax.set_visible(False) + fig.suptitle(f"Cell: {self.name}, Experiment: {protocol_name}, Page: {page + 1}") + plt.subplots_adjust(wspace=0.53, hspace=0.7) - # Do not use tight-layout, it significantly increases the runtime - plt.margins(0, 0) + for ax in axs.flatten(): + if not ax.lines: + ax.set_visible(False) - if show: - fig.show() + plt.margins(0, 0) - if output_dir is not None: - filename = "{}_{}_recordings.pdf".format(self.name, protocol_name) - dirname = pathlib.Path(output_dir) / self.name - _save_fig(dirname, filename) + if show: + fig.show() - return fig, axs + if output_dir is not None: + filename = f"{self.name}_{protocol_name}_recordings_page_{page + 1}.pdf" + dirname = pathlib.Path(output_dir) / self.name + _save_fig(dirname, filename) def plot_all_recordings(self, output_dir=None, show=False): """Plot all the recordings of the cell.