Skip to content

ridgeplot.dotted_heatmap

For the most part, this is adapted from: https://stackoverflow.com/questions/59381273/heatmap-with-circles-indicating-size-of-population

dotted_heatmap(data, ax, cmap='cividis', circle_size=None)

Plotting dotted heatmap

Example
>>> import matplotlib.pyplot as plt
>>> import numpy as np
>>> import pandas as pd
>>> from ridgeplot.dotted_heatmap import dotted_heatmap
>>> n = 5
>>> fig = plt.figure()
>>> ax = fig.add_subplot(111)
>>> data = pd.DataFrame(
...    np.random.randn(n, n),
...    index=[f"feature{i}" for i in range(n)],
...    columns=[f"sample{i}" for i in range(n)],
... )
>>> dotted_heatmap(data=data, ax=ax, cmap="viridis")

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `data` | `DataFrame` | Data to plot | required |
| `ax` | `Axes` | Matplotlib axes object | required |
| `cmap` | `str` | Colormap name, defaults to `"cividis"` | `'cividis'` |
| `circle_size` | `Optional[float]` | Radius of the circles; if `None`, relative sizes are used. Defaults to `None` | `None` |
Source code in `src/ridgeplot/dotted_heatmap.py` (lines 17–69)
def dotted_heatmap(
    data: pd.DataFrame,
    ax: matplotlib.axes._axes.Axes,
    cmap: str = "cividis",
    circle_size: Optional[float] = None,
):
    """
    Plot a dotted heatmap: one circle per cell, colored by the cell value.

    Example:
        ```
        >>> import matplotlib.pyplot as plt
        >>> import numpy as np
        >>> import pandas as pd
        >>> from ridgeplot.dotted_heatmap import dotted_heatmap
        >>> n = 5
        >>> fig = plt.figure()
        >>> ax = fig.add_subplot(111)
        >>> data = pd.DataFrame(
        ...    np.random.randn(n, n),
        ...    index=[f"feature{i}" for i in range(n)],
        ...    columns=[f"sample{i}" for i in range(n)],
        ... )
        >>> dotted_heatmap(data=data, ax=ax, cmap="viridis")
        ```

    Args:
        data: data to plot. Rows are drawn along the y axis (tick labels
            from ``data.index``); columns are drawn along the x axis (tick
            labels from ``data.columns``). Values must be numeric.
        ax: matplotlib ax object to draw on
        cmap: colormap name, defaults to "cividis"
        circle_size: radius of every circle, in data units (adjacent cells
            are 1 apart). If None, each radius is proportional to the
            absolute value of its cell, scaled so that the largest magnitude
            gets radius 0.5. Defaults to None.
    """
    nrows, ncols = data.shape
    values = np.asarray(data.values, dtype=float)
    x, y = np.meshgrid(np.arange(ncols), np.arange(nrows))

    if circle_size is not None:
        radii = np.full(values.shape, circle_size, dtype=float)
    else:
        # Normalise by the largest *magnitude* rather than the max, so that
        # negative values cannot yield radii > 0.5 (overlapping neighbours).
        # For non-negative data this is identical to dividing by the max.
        scale = np.nanmax(np.abs(values)) if values.size else 0.0
        if np.isfinite(scale) and scale > 0:
            # NaN cells get radius 0 instead of an undefined circle.
            radii = np.nan_to_num(np.abs(values) / 2 / scale)
        else:
            # All-zero (or all-NaN) data: avoid dividing by zero.
            radii = np.zeros(values.shape)

    circles = [
        plt.Circle((j, i), radius=r)
        for r, j, i in zip(radii.flat, x.flat, y.flat)
    ]
    # The colour of each circle encodes the raw (signed) cell value.
    col = PatchCollection(circles, array=values.ravel(), cmap=cmap)
    ax.add_collection(col)

    # Major ticks at cell centres carry the row/column labels.
    ax.set_xticks(np.arange(ncols))
    ax.set_xticklabels(data.columns)
    ax.set_yticks(np.arange(nrows))
    ax.set_yticklabels(data.index)

    # Minor ticks at cell borders draw the grid. Setting them also expands the
    # view limits so that every cell is visible.
    ax.set_xticks(np.arange(ncols + 1) - 0.5, minor=True)
    ax.set_yticks(np.arange(nrows + 1) - 0.5, minor=True)
    ax.grid(which="minor", alpha=0.5, color="white")
    ax.tick_params(left=False, bottom=False)
    for d in ["top", "bottom", "left", "right"]:
        ax.spines[d].set(alpha=0.5)
    # Attach the colorbar to the figure that owns `ax`, not to pyplot's
    # "current" figure, which may be a different one.
    ax.figure.colorbar(col, ax=ax)