Skip to content

Commit

Permalink
add our own dendrogram function:
Browse files Browse the repository at this point in the history
- cluster coloring
- plotly backend
  • Loading branch information
schlegelp committed Mar 14, 2023
1 parent 8b389b8 commit 01a1f81
Show file tree
Hide file tree
Showing 3 changed files with 235 additions and 74 deletions.
3 changes: 2 additions & 1 deletion tanglegram/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
# You should have received a copy of the GNU General Public License
# along

__version__ = "0.2.0"
__version__ = "0.3.0"

from .tangle import *
from .dend import *
from . import utils
233 changes: 233 additions & 0 deletions tanglegram/dend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,233 @@
import networkx as nx
import numpy as np

import matplotlib.colors as mcl
import matplotlib.pyplot as plt

import scipy as scp
import scipy.cluster as sclust
import scipy.spatial.distance as sdist

from .utils import linkage_to_graph


__all__ = ['dendrogram']

# TODO for dendrogram function:
# - add option to label links
# - add option for colouring the leafs (see `row_colors` in seaborn's clustermap)
def dendrogram(Z,
clusters=None,
link_labels=False,
backend='matplotlib',
**kwargs):
"""Extends scipy's dendrogram function with extra functionality.
Currently this adds the option to:
1. Color links explicitly by clusters instead of the simple
`above_threshold_color` in scipy's dendrogram function
2. Label the links by the distance.
3. Use plotly as backend instead of matplotlib.
Parameters
----------
Z : np.ndarray
Linkage.
clusters : list-like, optional
Cluster assignment for each leaf. This will make it so that
the each cluster in the dendrogram will get its own color.
link_labels : bool
Whether to label links with their distance.
backend : "matplotlib" | "plotly"
Which backend to use for plotting.
**kwargs
Keyword arguments are passed through to scipy's `dendrogram`.
Returns
-------
dendrogram : dict
The dendrogram as generated by scipy's dendrogram function.
fig/ax
Depending on the backend either the matplotlib ax or the
plotly figure.
"""
if backend not in ('matplotlib', 'plotly'):
raise ValueError(f'Unknown backend: "{backend}"')

if clusters is not None:
G = linkage_to_graph(Z).to_undirected()
leaf_ids = np.arange(0, sclust.hierarchy.num_obs_linkage(Z))

clusters = np.asarray(clusters).flatten()
if len(clusters) != len(leaf_ids):
raise ValueError(f'Got {len(clusters)} clusters for {len(leaf_ids)} leafs.')

# Prepare colors
colors = {}

# Go over each cluster and find the subgraph that contains it
singletons = []
for i, cl in enumerate(np.unique(clusters)):
this_ids = leaf_ids[clusters == cl]
if len(this_ids) > 1:
subgraph = []
for comb in combinations(this_ids, 2):
subgraph += nx.shortest_path(G, comb[0], comb[1])
subgraph = list(set(subgraph))
else:
subgraph = this_ids
singletons.append(this_ids[0])

# Track colors
c = sclust.hierarchy._link_line_colors[i % len(sclust.hierarchy._link_line_colors)]
colors.update({n: c for n in subgraph})

above_threshold_color = kwargs.get('above_threshold_color', 'lightgrey')
DEFAULTS = dict()
DEFAULTS['link_color_func'] = lambda x: colors.get(x, above_threshold_color)
DEFAULTS.update(kwargs)
DEFAULTS['get_leaves'] = True # we need this for later

if backend == 'matplotlib':
dn = sclust.hierarchy.dendrogram(Z, **DEFAULTS)

# Unfortunately, the link_color_function ignores clusters that
# consists of a single leaf -> here we change the color by simply
# plotting over it
ax = ax if ax else plt.gca()

leaves_pos = dict(zip(dn['leaves'],
np.arange(len(dn['leaves']))))
for s in singletons:
# Find which element corresponds to this singleton
el = [e for i, e in enumerate(Z) if e[0] == s or e[1] == s][0]
height = el[2]
x = (leaves_pos[s] + .5) * 10

if kwargs.get('orientation', 'bottom') in ('left', 'right'):
x, height = height, x

ax.plot([x, x], [0, height], color=colors[s])

return dn, ax
else:
dn = sclust.hierarchy.dendrogram(Z, no_plot=True, **DEFAULTS)
fig = dendrogram2plotly(dn,
orientation=kwargs.get('orientation', 'bottom'),
labels=kwargs.get('labels', None),
no_labels=kwargs.get('no_labels', False),
)
return dn, fig
else:
if backend == 'matplotlib':
dn = sclust.hierarchy.dendrogram(Z, ax=ax, **kwargs)
ax = ax if ax else plt.gca()
return dn, ax
else:
dn = sclust.hierarchy.dendrogram(Z, no_plot=True, **kwargs)
fig = dendrogram2plotly(dn,
orientation=kwargs.get('orientation', 'bottom'),
labels=kwargs.get('labels', None),
no_labels=kwargs.get('no_labels', False),
)
return dn, fig


def dendrogram2plotly(P, orientation='bottom', labels=None, no_labels=False):
"""Generate plotly figure from scipy dendrogram.
Parameters
----------
P : dict
The dendrogram from scipy's `dendrogram()` function.
orientation : 'bottom' | 'top' | 'left' | 'right'
Which orientation to plot the dendrogram in.
labels : list-like, optional
If not provided will use indices for labels.
no_labels : bool
If True, won't add labels to plot.
Returns
-------
fig : go.Figure
Plotly figure object.
"""
try:
import plotly.graph_objects as go
except ImportError:
raise ImportError("Please install the plotly library to use this backend:"
"\n pip3 install plotly")

icoord = np.array(P["icoord"])
dcoord = np.array(P["dcoord"])
color_list = np.array(P["color_list"])

if labels is None:
labels = np.array(P["ivl"])
else:
labels = np.asarray(labels)[P["leaves"]]

# Build traces
# Because we can only give one color per trace we will need to group
# them by colour
traces = []
for c in np.unique(color_list):
this_ic = icoord[color_list == c]
this_dc = dcoord[color_list == c]
xs = np.array([i for l in this_ic for i in l.tolist() + [None]], dtype=float)
ys = np.array([i for l in this_dc for i in l.tolist() + [None]], dtype=float)

if orientation == 'top':
ys = ys * -1
elif orientation == 'left':
xs, ys = ys, xs
elif orientation == 'right':
ys = ys * -1
xs, ys = ys, xs

trace = dict(
type="scatter",
x=xs,
y=ys,
mode="lines",
marker=dict(color=mcl.to_hex(c)),
#text=hovertext_label,
#hoverinfo="text",
)
traces.append(trace)

# Generate layout
layout = dict(showlegend=False,
autosize=True,
hovermode='closest',
)
for axis_key in ('xaxis', 'yaxis'):
layout[axis_key] = {
"type": "linear",
"ticks": "outside",
"rangemode": "tozero",
"showticklabels": True,
"zeroline": False,
"showgrid": False,
"showline": True,
}

if not no_labels:
label_axis = {'left': 'yaxis',
'right': 'yaxis',
'top': 'xaxis',
'bottom': 'xaxis'
}[orientation]
layout[label_axis]["tickvals"] = (np.arange(0, len(labels)) * 10 + 5).tolist()
layout[label_axis]["ticktext"] = labels.tolist()
layout[label_axis]["tickmode"] = "array"

if orientation == 'top':
layout['xaxis']['side'] = 'top'
elif orientation == 'right':
layout['yaxis']['side'] = 'right'

return go.Figure(data=traces, layout=layout)
73 changes: 0 additions & 73 deletions tanglegram/tangle.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
from tqdm import tqdm
from itertools import combinations

from .utils import linkage_to_graph

__all__ = ['tanglegram', 'tanglegram_many', 'entanglement', 'untangle']

Expand Down Expand Up @@ -1060,75 +1059,3 @@ def refine(best_linkage1, best_linkage2, min_entang, labels1, labels2, edges, L=

improved = min_entang < org_entang
return best_linkage1, best_linkage2, min_entang, improved


def cluster_dendrogram(Z, clusters, **kwargs):
"""Thin wrapper around scipy's dendrogram function.
This function adds the option to color explicitly by clusters instead of the
simple `above_threshold_color` in scipy's dendrogram function.
Parameters
----------
Z : np.ndarray
Linkage.
clusters : np.ndarray
Cluster assignment for each leaf.
**kwargs
Keyword arguments are passed through to scipy's `dendrogram`.
Returns
-------
dict
"""
G = linkage_to_graph(Z).to_undirected()
leaf_ids = np.arange(0, sclust.hierarchy.num_obs_linkage(Z))

clusters = np.asarray(clusters).flatten()
if len(clusters) != len(leaf_ids):
raise ValueError(f'Got {len(clusters)} clusters for {len(leaf_ids)} leafs.')

# Prepare colors
colors = {}

# Go over each cluster and find the subgraph that contains it
singletons = []
for i, cl in enumerate(np.unique(clusters)):
this_ids = leaf_ids[clusters == cl]
if len(this_ids) > 1:
subgraph = []
for comb in combinations(this_ids, 2):
subgraph += nx.shortest_path(G, comb[0], comb[1])
subgraph = list(set(subgraph))
else:
subgraph = this_ids
singletons.append(this_ids[0])

# Track colors
c = sclust.hierarchy._link_line_colors[i % len(sclust.hierarchy._link_line_colors)]
colors.update({n: c for n in subgraph})

above_threshold_color = kwargs.get('above_threshold_color', 'lightgrey')
DEFAULTS = dict(link_color_func=lambda x: colors.get(x, above_threshold_color))
DEFAULTS.update(kwargs)
DEFAULTS['get_leaves'] = True # we need this for later

dn = sclust.hierarchy.dendrogram(Z, **DEFAULTS)

# Unfortunately, the link_color_function ignores clusters that consists of a single
# leaf -> we'll try changing the color after the fact
ax = kwargs.get('ax', plt.gca())

leaves_pos = dict(zip(dn['leaves'],
np.arange(len(dn['leaves']))))
for s in singletons:
# Find which element corresponds to this singleton
el = [e for i, e in enumerate(Z) if e[0] == s or e[1] == s][0]
height = el[2]
x = (leaves_pos[s] + .5) * 10

ax.plot([x, x], [0, height], color=colors[s])

return dn

0 comments on commit 01a1f81

Please sign in to comment.