Source code for steenroder.plotting

import numpy as np
import plotly.graph_objs as gobj
from plotly.colors import DEFAULT_PLOTLY_COLORS as COLS


def _compute_bounds(barcode):
    barcode_no_dims = np.concatenate(barcode)
    posinfinite_mask = np.isposinf(barcode_no_dims)
    neginfinite_mask = np.isneginf(barcode_no_dims)
    max_val = np.max(np.where(posinfinite_mask, -np.inf, barcode_no_dims))
    min_val = np.min(np.where(neginfinite_mask, np.inf, barcode_no_dims))
    parameter_range = max_val - min_val
    extra_space_factor = 0.02
    has_posinfinite_death = np.any(posinfinite_mask[:, 1])
    if has_posinfinite_death:
        posinfinity_val = max_val + 0.1 * parameter_range
        extra_space_factor += 0.1
    else:
        posinfinity_val = None
    has_neginfinite_birth = np.any(neginfinite_mask[:, 0])
    if has_neginfinite_birth:
        neginfinity_val = min_val - 0.1 * parameter_range
        extra_space_factor += 0.1
    else:
        neginfinity_val = None
    extra_space = extra_space_factor * parameter_range
    min_val_display = min_val - extra_space
    max_val_display = max_val + extra_space

    return min_val_display, max_val_display, posinfinity_val, neginfinity_val


[docs]def plot_diagrams(barcode, steenrod_barcode, k=None, kind=None, tex=False, plotly_params=None): """Plot a regular persistence barcode and a Steenrod barcode as diagrams on a common birth-death plane. Parameters ---------- barcode : list of ndarray The persistence barcode to plot. For each dimension ``d``, a 2D array with 2 columns, containing birth-death pairs in degree ``d``, to be used as coordinates in the two-dimensional plot. steenrod_barcode : list of ndarray The (relative) Sq^k-barcode to plot. For each dimension ``d``, a 2D array with 2 columns, containing the birth-death pairs of Steenrod bars in degree ``d``, to be used as coordinates in the two-dimensional plot. k : int or None, optional, default: ``None`` Positive integer defining the cohomology operation Sq^k that was performed to obtain `steenrod_barcode`. Only used for labelling. kind : ``"R"`` | ``"A"`` or None, optional, default: ``None`` Whether the barcodes to be plotted come from absolute or relative cohomology barcodes. tex : bool, optional, default: ``False`` (Experimental!) Whether to display a version of the legend rendered with LaTeX and MathJax. plotly_params : dict or None, optional, default: ``None`` Custom parameters to configure the plotly figure. Allowed keys are ``"traces"`` and ``"layout"``, and the corresponding values should be dictionaries containing keyword arguments as would be fed to the :meth:`update_traces` and :meth:`update_layout` methods of :class:`plotly.graph_objects.Figure`. Returns ------- fig : :class:`plotly.graph_objects.Figure` object Figure representing the persistence diagram and Steenrod diagram. """ def _connect_st_label(st_label, h_label, tex): if tex: st_h_label = r"$" + st_label + r" \cap " + h_label + r"$" h_label = r"$" + h_label + r"$" else: st_h_label = st_label + " in " + h_label h_label = h_label return h_label, st_h_label if k is not None: st_label = r"\mathrm{img}" + rf"(Sq^{{{k}}})" if tex else f"im(Sq^{k})" else: st_label = r"\mathrm{img}(Sq^{k})" if tex else "im(Sq^k)" kind = kind.lower() if kind == "a": legend_title = "Absolute Cohomology" h_subscript = r"_{A}" if tex else "" elif kind == "r": legend_title = "Relative Cohomology" h_subscript = r"_{R}" if tex else "" else: legend_title = "Cohomology" h_subscript = r"" if tex else "" homology_dimensions = range(max(len(barcode), len(steenrod_barcode))) min_val_display, max_val_display, posinfinity_val, neginfinity_val = \ _compute_bounds(barcode + steenrod_barcode) fig = gobj.Figure() fig.add_trace(gobj.Scatter( x=[min_val_display, max_val_display], y=[min_val_display, max_val_display], mode="lines", line={"dash": "dash", "width": 1, "color": "black"}, showlegend=False, hoverinfo="none" )) for i, dim in enumerate(homology_dimensions): h_label = (rf"\mathcal{{H}}^{{{dim}}}" if tex else f"H^{dim}") + \ h_subscript h_label, st_h_label = _connect_st_label(st_label, h_label, tex) for label, symbol, ms, bc in ([st_h_label, "diamond", 10, steenrod_barcode], [h_label, "circle", 8, barcode]): subbc = bc[dim].copy() unique, inverse, counts = np.unique( subbc, axis=0, return_inverse=True, return_counts=True ) hovertext = [ f"{tuple(unique[unique_row_index][:2])}" + ( f", multiplicity: {counts[unique_row_index]}" if counts[unique_row_index] > 1 else "" ) for unique_row_index in inverse ] births = subbc[:, 0] if neginfinity_val is not None: births[np.isneginf(births)] = neginfinity_val deaths = subbc[:, 1] if posinfinity_val is not None: deaths[np.isposinf(deaths)] = posinfinity_val fig.add_trace(gobj.Scatter( x=births, y=deaths, name=label, mode="markers", marker_color=COLS[i], marker_symbol=symbol, marker_line={"width": 1}, marker_size=ms, hoverinfo="text", hovertext=hovertext, showlegend=True )) fig.update_layout( width=500, height=500, xaxis1={ "title": "Birth", "side": "bottom", "type": "linear", "range": [min_val_display, max_val_display], "autorange": False, "ticks": "outside", "showline": True, "zeroline": True, "linewidth": 1, "linecolor": "black", "mirror": False, "showexponent": "all", "exponentformat": "e" }, yaxis1={ "title": "Death", "side": "left", "type": "linear", "range": [min_val_display, max_val_display], "autorange": False, "scaleanchor": "x", "scaleratio": 1, "ticks": "outside", "showline": True, "zeroline": True, "linewidth": 1, "linecolor": "black", "mirror": False, "showexponent": "all", "exponentformat": "e" }, plot_bgcolor="white", legend={ "y": 0.01, "x": 0.6, }, legend_title=legend_title if not tex else None, font_family="Serif", font_size=14 ) # Add a horizontal dashed line for points with infinite death if posinfinity_val is not None: fig.add_trace(gobj.Scatter( x=[min_val_display, max_val_display], y=[posinfinity_val, posinfinity_val], mode="lines", line={"dash": "dash", "width": 0.5, "color": "black"}, showlegend=True, name=r"+" + u"\u221E" + 25 * " ", hoverinfo="none" )) # Add a vertical dashed line for points with negative infinite birth if neginfinity_val is not None: fig.add_trace(gobj.Scatter( x=[neginfinity_val, neginfinity_val], y=[min_val_display, max_val_display], mode="lines", line={"dash": "dash", "width": 0.5, "color": "black"}, showlegend=True, name="-" + u"\u221E" + 25 * " ", hoverinfo="none" )) # Update traces and layout according to user input if plotly_params: fig.update_traces(plotly_params.get("traces", None)) fig.update_layout(plotly_params.get("layout", None)) return fig