From 01a1f81c910eb9922954f790bcaf46acc0fead82 Mon Sep 17 00:00:00 2001
From: Philipp Schlegel
Date: Tue, 14 Mar 2023 14:28:14 +0000
Subject: [PATCH] add our own dendrogram function: - cluster coloring - plotly backend

---
 tanglegram/__init__.py |   3 +-
 tanglegram/dend.py     | 263 +++++++++++++++++++++++++++++++++++++++++
 tanglegram/tangle.py   |  73 -----------
 3 files changed, 265 insertions(+), 74 deletions(-)
 create mode 100644 tanglegram/dend.py

diff --git a/tanglegram/__init__.py b/tanglegram/__init__.py
index 88524cf..97d2a36 100644
--- a/tanglegram/__init__.py
+++ b/tanglegram/__init__.py
@@ -15,7 +15,8 @@
 # You should have received a copy of the GNU General Public License
 # along
 
-__version__ = "0.2.0"
+__version__ = "0.3.0"
 
 from .tangle import *
+from .dend import *
 from . import utils
diff --git a/tanglegram/dend.py b/tanglegram/dend.py
new file mode 100644
index 0000000..4517bb1
--- /dev/null
+++ b/tanglegram/dend.py
@@ -0,0 +1,263 @@
+import networkx as nx
+import numpy as np
+
+import matplotlib.colors as mcl
+import matplotlib.pyplot as plt
+
+import scipy as scp
+import scipy.cluster as sclust
+import scipy.spatial.distance as sdist
+
+from itertools import combinations
+
+from .utils import linkage_to_graph
+
+
+__all__ = ['dendrogram']
+
+
+def _cluster_link_colors(Z, clusters):
+    """Assign a color to every linkage node that belongs to a cluster.
+
+    Parameters
+    ----------
+    Z : np.ndarray
+        Linkage.
+    clusters : list-like
+        Cluster assignment for each leaf.
+
+    Returns
+    -------
+    colors : dict
+        Maps node IDs of the linkage graph to a color.
+    singletons : list
+        IDs of leafs that are the only member of their cluster.
+
+    Raises
+    ------
+    ValueError
+        If the number of cluster assignments does not match the number
+        of leafs in the linkage.
+
+    """
+    G = linkage_to_graph(Z).to_undirected()
+    leaf_ids = np.arange(0, sclust.hierarchy.num_obs_linkage(Z))
+
+    clusters = np.asarray(clusters).flatten()
+    if len(clusters) != len(leaf_ids):
+        raise ValueError(f'Got {len(clusters)} clusters for {len(leaf_ids)} leafs.')
+
+    palette = sclust.hierarchy._link_line_colors
+    colors = {}
+    singletons = []
+    for i, cl in enumerate(np.unique(clusters)):
+        this_ids = leaf_ids[clusters == cl]
+        if len(this_ids) > 1:
+            # Collect every node on the paths connecting leafs of this cluster
+            subgraph = set()
+            for source, target in combinations(this_ids, 2):
+                subgraph.update(nx.shortest_path(G, source, target))
+        else:
+            subgraph = this_ids
+            singletons.append(this_ids[0])
+
+        # Cycle through the palette if there are more clusters than colors
+        c = palette[i % len(palette)]
+        colors.update({n: c for n in subgraph})
+
+    return colors, singletons
+
+
+# TODO for dendrogram function:
+# - add option to label links
+# - add option for colouring the leafs (see `row_colors` in seaborn's clustermap)
+def dendrogram(Z,
+               clusters=None,
+               link_labels=False,
+               backend='matplotlib',
+               **kwargs):
+    """Extends scipy's dendrogram function with extra functionality.
+
+    Currently this adds the option to:
+     1. Color links explicitly by clusters instead of the simple
+        `above_threshold_color` in scipy's dendrogram function
+     2. Use plotly as backend instead of matplotlib.
+
+    Parameters
+    ----------
+    Z : np.ndarray
+        Linkage.
+    clusters : list-like, optional
+        Cluster assignment for each leaf. If provided, each cluster in
+        the dendrogram will get its own color.
+    link_labels : bool
+        Reserved for labelling links with their distance. Not
+        implemented yet: the value is currently ignored.
+    backend : "matplotlib" | "plotly"
+        Which backend to use for plotting.
+    **kwargs
+        Keyword arguments are passed through to scipy's `dendrogram`.
+        An `ax` entry is only used by the matplotlib backend.
+
+    Returns
+    -------
+    dendrogram : dict
+        The dendrogram as generated by scipy's dendrogram function.
+    fig/ax
+        Depending on the backend either the matplotlib ax or the
+        plotly figure.
+
+    Raises
+    ------
+    ValueError
+        If `backend` is unknown or the number of `clusters` does not
+        match the number of leafs.
+
+    """
+    if backend not in ('matplotlib', 'plotly'):
+        raise ValueError(f'Unknown backend: "{backend}"')
+
+    # `ax` is only meaningful for matplotlib: take it out of the pass-through
+    ax = kwargs.pop('ax', None)
+
+    colors, singletons = {}, []
+    if clusters is not None:
+        colors, singletons = _cluster_link_colors(Z, clusters)
+        above_threshold_color = kwargs.get('above_threshold_color', 'lightgrey')
+        # An explicitly passed `link_color_func` takes precedence
+        kwargs.setdefault('link_color_func',
+                          lambda x: colors.get(x, above_threshold_color))
+        kwargs['get_leaves'] = True  # we need this for later
+
+    if backend == 'plotly':
+        # Set via kwargs so a user-supplied `no_plot` cannot clash
+        kwargs['no_plot'] = True
+        dn = sclust.hierarchy.dendrogram(Z, **kwargs)
+        fig = dendrogram2plotly(dn,
+                                orientation=kwargs.get('orientation', 'bottom'),
+                                labels=kwargs.get('labels', None),
+                                no_labels=kwargs.get('no_labels', False),
+                                )
+        return dn, fig
+
+    dn = sclust.hierarchy.dendrogram(Z, ax=ax, **kwargs)
+    ax = ax if ax is not None else plt.gca()
+
+    if singletons:
+        # Unfortunately, the link_color_function ignores clusters that
+        # consists of a single leaf -> here we change the color by simply
+        # plotting over it
+        leaves_pos = {leaf: i for i, leaf in enumerate(dn['leaves'])}
+        horizontal = kwargs.get('orientation', 'top') in ('left', 'right')
+        for s in singletons:
+            # Find which element corresponds to this singleton
+            el = next(e for e in Z if e[0] == s or e[1] == s)
+            height = el[2]
+            pos = (leaves_pos[s] + .5) * 10
+
+            xs, ys = [pos, pos], [0, height]
+            if horizontal:
+                # Swap the coordinate lists so the line runs along x
+                xs, ys = ys, xs
+
+            ax.plot(xs, ys, color=colors[s])
+
+    return dn, ax
+
+
+def dendrogram2plotly(P, orientation='bottom', labels=None, no_labels=False):
+    """Generate plotly figure from scipy dendrogram.
+
+    Parameters
+    ----------
+    P : dict
+        The dendrogram from scipy's `dendrogram()` function.
+    orientation : 'bottom' | 'top' | 'left' | 'right'
+        Side of the plot on which the leafs and their labels are placed.
+    labels : list-like, optional
+        Labels indexed by leaf ID. If not provided will use the labels
+        already stored in the dendrogram (`P["ivl"]`).
+    no_labels : bool
+        If True, won't add labels to plot.
+
+    Returns
+    -------
+    fig : go.Figure
+        Plotly figure object.
+
+    """
+    try:
+        import plotly.graph_objects as go
+    except ImportError:
+        raise ImportError("Please install the plotly library to use this backend:"
+                          "\n pip3 install plotly")
+
+    icoord = np.array(P["icoord"])
+    dcoord = np.array(P["dcoord"])
+    color_list = np.array(P["color_list"])
+
+    if labels is None:
+        labels = np.array(P["ivl"])
+    else:
+        labels = np.asarray(labels)[P["leaves"]]
+
+    # Build traces
+    # Because we can only give one color per trace we will need to group
+    # them by colour
+    traces = []
+    for c in np.unique(color_list):
+        this_ic = icoord[color_list == c]
+        this_dc = dcoord[color_list == c]
+        # The None (-> NaN) after each link separates it from the next one
+        xs = np.array([i for l in this_ic for i in l.tolist() + [None]], dtype=float)
+        ys = np.array([i for l in this_dc for i in l.tolist() + [None]], dtype=float)
+
+        if orientation == 'top':
+            ys = ys * -1
+        elif orientation == 'left':
+            xs, ys = ys, xs
+        elif orientation == 'right':
+            ys = ys * -1
+            xs, ys = ys, xs
+
+        trace = dict(
+            type="scatter",
+            x=xs,
+            y=ys,
+            mode="lines",
+            line=dict(color=mcl.to_hex(c)),
+        )
+        traces.append(trace)
+
+    # Generate layout
+    layout = dict(showlegend=False,
+                  autosize=True,
+                  hovermode='closest',
+                  )
+    for axis_key in ('xaxis', 'yaxis'):
+        layout[axis_key] = {
+            "type": "linear",
+            "ticks": "outside",
+            "rangemode": "tozero",
+            "showticklabels": True,
+            "zeroline": False,
+            "showgrid": False,
+            "showline": True,
+        }
+
+    if not no_labels:
+        label_axis = {'left': 'yaxis',
+                      'right': 'yaxis',
+                      'top': 'xaxis',
+                      'bottom': 'xaxis'
+                      }[orientation]
+        layout[label_axis]["tickvals"] = (np.arange(0, len(labels)) * 10 + 5).tolist()
+        layout[label_axis]["ticktext"] = labels.tolist()
+        layout[label_axis]["tickmode"] = "array"
+
+    if orientation == 'top':
+        layout['xaxis']['side'] = 'top'
+    elif orientation == 'right':
+        layout['yaxis']['side'] = 'right'
+
+    return go.Figure(data=traces, layout=layout)
diff --git a/tanglegram/tangle.py b/tanglegram/tangle.py
index 293cfff..5363fc8 100755
--- a/tanglegram/tangle.py
+++ b/tanglegram/tangle.py
@@ -26,7 +26,6 @@ from tqdm import tqdm
 
 from itertools import combinations
 
-from .utils import linkage_to_graph
 
 __all__ = ['tanglegram', 'tanglegram_many', 'entanglement', 'untangle']
 
@@ -1060,75 +1059,3 @@ def refine(best_linkage1, best_linkage2, min_entang, labels1, labels2, edges, L=
     improved = min_entang < org_entang
 
     return best_linkage1, best_linkage2, min_entang, improved
-
-
-def cluster_dendrogram(Z, clusters, **kwargs):
-    """Thin wrapper around scipy's dendrogram function.
-
-    This function adds the option to color explicitly by clusters instead of the
-    simple `above_threshold_color` in scipy's dendrogram function.
-
-    Parameters
-    ----------
-    Z : np.ndarray
-        Linkage.
-    clusters : np.ndarray
-        Cluster assignment for each leaf.
-    **kwargs
-        Keyword arguments are passed through to scipy's `dendrogram`.
-
-
-    Returns
-    -------
-    dict
-
-    """
-    G = linkage_to_graph(Z).to_undirected()
-    leaf_ids = np.arange(0, sclust.hierarchy.num_obs_linkage(Z))
-
-    clusters = np.asarray(clusters).flatten()
-    if len(clusters) != len(leaf_ids):
-        raise ValueError(f'Got {len(clusters)} clusters for {len(leaf_ids)} leafs.')
-
-    # Prepare colors
-    colors = {}
-
-    # Go over each cluster and find the subgraph that contains it
-    singletons = []
-    for i, cl in enumerate(np.unique(clusters)):
-        this_ids = leaf_ids[clusters == cl]
-        if len(this_ids) > 1:
-            subgraph = []
-            for comb in combinations(this_ids, 2):
-                subgraph += nx.shortest_path(G, comb[0], comb[1])
-            subgraph = list(set(subgraph))
-        else:
-            subgraph = this_ids
-            singletons.append(this_ids[0])
-
-        # Track colors
-        c = sclust.hierarchy._link_line_colors[i % len(sclust.hierarchy._link_line_colors)]
-        colors.update({n: c for n in subgraph})
-
-    above_threshold_color = kwargs.get('above_threshold_color', 'lightgrey')
-    DEFAULTS = dict(link_color_func=lambda x: colors.get(x, above_threshold_color))
-    DEFAULTS.update(kwargs)
-    DEFAULTS['get_leaves'] = True  # we need this for later
-
-    dn = sclust.hierarchy.dendrogram(Z, **DEFAULTS)
-
-    # Unfortunately, the link_color_function ignores clusters that consists of a single
-    # leaf -> we'll try changing the color after the fact
-    ax = kwargs.get('ax', plt.gca())
-
-    leaves_pos = dict(zip(dn['leaves'],
-                          np.arange(len(dn['leaves']))))
-    for s in singletons:
-        # Find which element corresponds to this singleton
-        el = [e for i, e in enumerate(Z) if e[0] == s or e[1] == s][0]
-        height = el[2]
-        x = (leaves_pos[s] + .5) * 10
-
-        ax.plot([x, x], [0, height], color=colors[s])
-
-    return dn