Skip to content

Commit

Permalink
Merge pull request #1 from aertslab/plotting
Browse files Browse the repository at this point in the history
Plotting + bugfixes + small features
  • Loading branch information
LukasMahieu authored Jun 19, 2024
2 parents 103c370 + 2e4f458 commit 9a71dc8
Show file tree
Hide file tree
Showing 31 changed files with 1,098 additions and 30 deletions.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/_static/img/examples/bar_region.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/_static/img/examples/hist_distribution.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
53 changes: 52 additions & 1 deletion docs/api/plotting.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,55 @@ Plotting description
:toctree: _autosummary
pl.contribution_scores
```
```

## Bar plots

```{eval-rst}
.. autosummary::
:toctree: _autosummary
pl.bar.region
pl.bar.region_predictions
pl.bar.normalization_weights
```

## Distribution plots

```{eval-rst}
.. autosummary::
:toctree: _autosummary
pl.hist.distribution
```

## Heatmap

Correlations

```{eval-rst}
.. autosummary::
:toctree: _autosummary
pl.heatmap.correlations_self
pl.heatmap.correlations_predictions
```

## Scatter plots

```{eval-rst}
.. autosummary::
:toctree: _autosummary
pl.scatter.class_density
```

## Utility functions

```{eval-rst}
.. autosummary::
:toctree: _autosummary
pl.render_plot
```

2 changes: 0 additions & 2 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,6 @@
"python": ("https://docs.python.org/3", None),
"anndata": ("https://anndata.readthedocs.io/en/stable/", None),
"numpy": ("https://numpy.org/doc/stable/", None),
"tensorflow": ("https://www.tensorflow.org/api_docs/python", None),
"keras": ("https://keras.io/api/", None),
}

# List of patterns, relative to source directory, that match files and
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ dependencies = [
"loguru",
"logomaker",
"pybigtools",
"seaborn"
]

[project.optional-dependencies]
Expand Down
4 changes: 2 additions & 2 deletions src/crested/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@

__version__ = version("crested")

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
os.environ["AUTOGRAPH_VERBOSITY"] = "1"
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
os.environ["AUTOGRAPH_VERBOSITY"] = "0"

# Setup loguru logging
setup_logging(log_level="INFO", log_file=None)
4 changes: 2 additions & 2 deletions src/crested/_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def import_topics(
will be included.
Topics should be named after the topics file name without the extension.
chromsizes_file
File path of the chromsizes file.
File path of the chromsizes file. Used for checking if the new regions are within the chromosome boundaries.
remove_empty_regions
Remove regions that are not open in any topic.
compress
Expand Down Expand Up @@ -302,7 +302,7 @@ def import_bigwigs(
regions_file
File name of the consensus regions BED file.
chromsizes_file
File name of the chromsizes file.
File name of the chromsizes file. Used for checking if the new regions are within the chromosome boundaries.
target
Target value to extract from bigwigs. Can be 'mean', 'max', 'count', or 'logcount'
target_region_width
Expand Down
2 changes: 2 additions & 0 deletions src/crested/pl/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
from . import bar, heatmap, hist, scatter
from ._contribution_scores import contribution_scores
from ._utils import render_plot
2 changes: 2 additions & 0 deletions src/crested/pl/_contribution_scores.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ def contribution_scores(
>>> seqs_one_hot = np.random.randint(0, 2, (1, 100, 4))
>>> class_names = ["class1"]
>>> crested.pl.contribution_scores(scores, seqs_one_hot, class_names)
.. image:: ../../../docs/_static/img/examples/contribution_scores.png
"""
# Center and zoom
_check_contrib_params(zoom_n_bases, scores)
Expand Down
55 changes: 54 additions & 1 deletion src/crested/pl/_utils.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,68 @@
"""Utility functions for plotting in CREsted."""

from __future__ import annotations

import logomaker
import matplotlib.pyplot as plt
import numpy as np


def render_plot(
fig,
width: int = 8,
height: int = 8,
title: str | None = None,
xlabel: str | None = None,
ylabel: str | None = None,
save_path: str | None = None,
) -> None:
"""
Render a plot with customization options.
Note
----
This function should never be called directly. Rather, the other plotting functions call this function.
Parameters
----------
fig
The figure object to render.
width
Width of the plot (inches).
height
Height of the plot (inches).
title
Title of the plot.
xlabel
Label for the X-axis.
ylabel
Label for the Y-axis.
fig_path
Optional path to save the figure. If None, the figure is displayed but not saved.
"""
fig.set_size_inches(width, height)
if title:
fig.suptitle(title)
for ax in fig.axes:
if xlabel:
ax.set_xlabel(xlabel)
if ylabel:
ax.set_ylabel(ylabel)

plt.tight_layout()
if save_path:
plt.savefig(save_path)
plt.show()


def grad_times_input_to_df(x, grad, alphabet="ACGT"):
"""Generate pandas dataframe for saliency plot based on grad x inputs"""
x_index = np.argmax(np.squeeze(x), axis=1)
grad = np.squeeze(grad)
L, A = grad.shape

seq = ""
saliency = np.zeros((L))
saliency = np.zeros(L)
for i in range(L):
seq += alphabet[x_index[i]]
saliency[i] = grad[i, x_index[i]]
Expand Down
2 changes: 2 additions & 0 deletions src/crested/pl/bar/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from ._normalization_weights import normalization_weights
from ._region import region, region_predictions
65 changes: 65 additions & 0 deletions src/crested/pl/bar/_normalization_weights.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
"""Bar plot of normalization weights."""

from __future__ import annotations

import matplotlib.pyplot as plt
from anndata import AnnData

from crested._logging import log_and_raise
from crested.pl._utils import render_plot


def normalization_weights(adata: AnnData, **kwargs):
"""
Plot the distribution of normalization scaling factors per cell type.
Parameters
----------
adata
AnnData object containing the normalization weights in `obsm["weights"]`.
kwargs
Additional arguments passed to :func:`~crested.pl.render_plot` to
control the final plot output. Please see :func:`~crested.pl.render_plot`
for details.
See Also
--------
crested.pl.render_plot
Example
-------
>>> crested.pl.bar.normalization_weights(
... adata,
... xlabel="Cell type",
... ylabel="Scaling factor",
... width=20,
... height=3,
... title="Normalization scaling factors",
... )
.. image:: ../../../docs/_static/img/examples/bar_normalization_weights.png
"""

@log_and_raise(ValueError)
def _check_input_params():
if "weights" not in adata.obsm:
raise ValueError("Normalization weights not found in adata.obsm['weights']")

_check_input_params()

weights = adata.obsm["weights"]
classes = list(adata.obs_names)

fig, ax = plt.subplots()
ax.bar(classes, weights)

# default plot size
default_width = 20
default_height = 3

if "width" not in kwargs:
kwargs["width"] = default_width
if "height" not in kwargs:
kwargs["height"] = default_height

render_plot(fig, **kwargs)
Loading

0 comments on commit 9a71dc8

Please sign in to comment.