# Source code for cell2location.plt.plot_spatial

# +
import warnings

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import ListedColormap
from matplotlib.gridspec import GridSpec


def get_rgb_function(cmap, min_value, max_value):
    r"""Generate a function to map continuous values to RGB values using colormap between min_value & max_value.

    Values are clipped to ``[min_value, max_value]`` and linearly rescaled to ``[0, 1]``
    before being passed to ``cmap``.

    :param cmap: callable colormap (e.g. a matplotlib ``Colormap``) accepting values in [0, 1]
    :param min_value: value mapped to the bottom of the colormap
    :param max_value: value mapped to the top of the colormap
    :return: function ``f(x)`` returning ``cmap`` applied to the rescaled ``x``
    :raises ValueError: if ``min_value > max_value``
    """

    if min_value > max_value:
        raise ValueError("Max_value should be greater than or equal to min_value.")

    if min_value == max_value:
        warnings.warn(
            "Max_color is equal to min_color. It might be because of the data or bad parameter choice. "
            "If you are using plot_contours function try increasing max_color_quantile parameter and "
            "removing cell types with all zero values."
        )

        # Degenerate range: nothing to interpolate over, so every value gets one colour --
        # the bottom of the colormap when the data is all zero, the midpoint otherwise.
        factor = 0 if max_value == 0 else 0.5

        def func_equal(x):
            return cmap(np.ones_like(x) * factor)

        return func_equal

    value_range = max_value - min_value

    def func(x):
        return cmap((np.clip(x, min_value, max_value) - min_value) / value_range)

    return func


def rgb_to_ryb(rgb):
    """
    Converts colours from RGB colorspace to RYB

    The white component (``min(r, g, b)``) is removed, the remaining colour is split into
    red, yellow and blue, the result is rescaled so its largest channel matches the largest
    remaining RGB channel, and finally the black component (``min(1 - r, 1 - g, 1 - b)``)
    is added to every channel. Each row is converted independently.

    Parameters
    ----------

    rgb
        array-like Nx3 (or a single colour of length 3). Integer input is converted to
        float so the halving and rescaling steps are not truncated; floating-point input
        keeps its dtype.

    Returns
    -------
    Numpy array Nx3 (a single colour is returned as a 1x3 array)
    """
    rgb = np.array(rgb)
    if not np.issubdtype(rgb.dtype, np.floating):
        # np.zeros_like below would otherwise allocate an integer buffer and silently
        # truncate the /2 and the rescaling (e.g. pure green would come out as white).
        rgb = rgb.astype(float)
    if rgb.ndim == 1:
        rgb = rgb[np.newaxis, :]

    white = rgb.min(axis=1)
    black = (1 - rgb).min(axis=1)
    rgb = rgb - white[:, np.newaxis]

    # yellow is the part shared by the red and green channels
    yellow = rgb[:, :2].min(axis=1)
    ryb = np.zeros_like(rgb)
    ryb[:, 0] = rgb[:, 0] - yellow
    ryb[:, 1] = (yellow + rgb[:, 1]) / 2
    ryb[:, 2] = (rgb[:, 2] + rgb[:, 1] - yellow) / 2

    # Rescale only non-zero rows; all-zero rows (input greys) are left untouched,
    # which also avoids dividing by a zero norm.
    mask = ~(ryb == 0).all(axis=1)
    if mask.any():
        norm = ryb[mask].max(axis=1) / rgb[mask].max(axis=1)
        ryb[mask] = ryb[mask] / norm[:, np.newaxis]

    return ryb + black[:, np.newaxis]


def ryb_to_rgb(ryb):
    """
    Converts colours from RYB colorspace to RGB

    The black component (``min(r, y, b)``) is removed, the remaining colour is split into
    red, green and blue, the result is rescaled so its largest channel matches the largest
    remaining RYB channel, and finally the white component (``min(1 - r, 1 - y, 1 - b)``)
    is added to every channel. Each row is converted independently.

    Parameters
    ----------

    ryb
        array-like Nx3 (or a single colour of length 3). Integer input is converted to
        float so the rescaling step is not truncated; floating-point input keeps its dtype.

    Returns
    -------
    Numpy array Nx3 (a single colour is returned as a 1x3 array)
    """
    ryb = np.array(ryb)
    if not np.issubdtype(ryb.dtype, np.floating):
        # np.zeros_like below would otherwise allocate an integer buffer and silently
        # truncate the rescaling (e.g. RYB orange [1, 1, 0] would come out as pure red).
        ryb = ryb.astype(float)
    if ryb.ndim == 1:
        ryb = ryb[np.newaxis, :]

    black = ryb.min(axis=1)
    white = (1 - ryb).min(axis=1)
    ryb = ryb - black[:, np.newaxis]

    # green is the part shared by the yellow and blue channels
    green = ryb[:, 1:].min(axis=1)
    rgb = np.zeros_like(ryb)
    rgb[:, 0] = ryb[:, 0] + ryb[:, 1] - green
    rgb[:, 1] = green + ryb[:, 1]
    rgb[:, 2] = (ryb[:, 2] - green) * 2

    # Rescale only non-zero rows; all-zero rows (input greys) are left untouched,
    # which also avoids dividing by a zero norm.
    mask = ~(ryb == 0).all(axis=1)
    if mask.any():
        norm = rgb[mask].max(axis=1) / ryb[mask].max(axis=1)
        rgb[mask] = rgb[mask] / norm[:, np.newaxis]

    return rgb + white[:, np.newaxis]


def plot_spatial_general(
    value_df,
    coords,
    labels,
    text=None,
    circle_diameter=4.0,
    alpha_scaling=1.0,
    max_col=(np.inf, np.inf, np.inf, np.inf, np.inf, np.inf, np.inf),
    max_color_quantile=0.98,
    show_img=True,
    img=None,
    img_alpha=1.0,
    adjust_text=False,
    plt_axis="off",
    axis_y_flipped=True,
    x_y_labels=("", ""),
    crop_x=None,
    crop_y=None,
    text_box_alpha=0.9,
    reorder_cmap=range(7),
    style="fast",
    colorbar_position="bottom",
    colorbar_label_kw=None,
    colorbar_shape=None,
    colorbar_tick_size=12,
    colorbar_grid=None,
    image_cmap="Greys_r",
    white_spacing=20,
):
    r"""Plot spatial abundance of cell types (regulatory programmes) with colour gradient and interpolation.

    This method supports only 7 cell types with these colours (in order, which can be changed using reorder_cmap).
    'yellow' 'orange' 'blue' 'green' 'purple' 'grey' 'white' (the last one is drawn as dark grey, #323232)

    The colours of all cell types present in a spot are blended in RYB space, weighted by the
    squared, colour-scale-normalised abundance of each cell type.

    :param value_df: pd.DataFrame - with cell abundance or other features (only 7 allowed, columns) across locations (rows)
    :param coords: np.ndarray - x and y coordinates (in columns) to be used for plotting spots
    :param labels: list of strings, labels of cell types (used as colorbar titles)
    :param text: pd.DataFrame - with x, y coordinates, text to be printed
    :param circle_diameter: diameter of circles (the scatter marker size is ``circle_diameter ** 2``)
    :param alpha_scaling: adjust color alpha
    :param max_col: crops the colorscale maximum value for each column in value_df.
    :param max_color_quantile: crops the colorscale at x quantile of the data.
    :param show_img: show image?
    :param img: numpy array representing a tissue image. If not provided, no image is drawn
        (the background then comes from ``style``).
    :param img_alpha: transparency of the image
    :param adjust_text: move text labels to prevent overlap (requires the ``adjustText`` package)
    :param plt_axis: 'off' hides the axis spines, ticks and tick labels; any other value leaves them visible
    :param axis_y_flipped: flip y axis to match coordinates of the plotted image
    :param x_y_labels: tuple of (x axis label, y axis label)
    :param crop_x: (min, max) x limits of the plot, or None for no cropping
    :param crop_y: (min, max) y limits of the plot, or None for no cropping
    :param text_box_alpha: transparency of the boxes drawn behind ``text`` labels
    :param reorder_cmap: reorder colors to make sure you get the right color for each category
    :param style: plot style (matplotlib.style.context):
        'fast' - white background & dark text;
        'dark_background' - black background & white text;
    :param colorbar_position: 'bottom', 'right' or None (no colorbars)
    :param colorbar_label_kw: dict that will be forwarded to the colorbar ``ax.set_title()`` call
    :param colorbar_shape: dict {'vertical_gaps': 1.5, 'horizontal_gaps': 1.5, 'width': 0.2, 'height': 0.2},
        not obligatory to contain all params
    :param colorbar_tick_size: colorbar ticks label size
    :param colorbar_grid: tuple of colorbar grid (rows, columns)
    :param image_cmap: matplotlib colormap for grayscale image
    :param white_spacing: percent of colorbars to be hidden
    :return: matplotlib figure
    :raises ValueError: if ``value_df`` has more than 7 columns or ``colorbar_position`` is not supported
    """
    if value_df.shape[1] > 7:
        raise ValueError("Maximum of 7 cell types / factors can be plotted at the moment")
    if colorbar_position not in ("bottom", "right", None):
        # previously this fell through and failed later with an unbound `ax`
        raise ValueError(f"colorbar_position must be 'bottom', 'right' or None, got {colorbar_position!r}")

    # `None` defaults instead of `{}` avoid shared mutable default arguments
    colorbar_label_kw = {} if colorbar_label_kw is None else colorbar_label_kw
    colorbar_shape = {} if colorbar_shape is None else colorbar_shape

    def create_colormap(R, G, B):
        # Single-hue colormap whose alpha ramps from transparent to opaque; roughly the first
        # `white_spacing` percent of the entries are fully transparent so low values are hidden.
        spacing = int(white_spacing * 2.55)
        N = 255
        M = 3
        alphas = np.concatenate([[0] * spacing * M, np.linspace(0, 1.0, (N - spacing) * M)])
        vals = np.ones((N * M, 4))
        vals[:, :3] = np.array([R, G, B]) / 255
        vals[:, 3] = alphas
        return ListedColormap(vals)

    # Create linearly scaled colormaps (hex code of each hue in the comment)
    cmaps = [
        create_colormap(240, 228, 66),  # yellow  #F0E442
        create_colormap(213, 94, 0),  # orange  #D55E00
        create_colormap(86, 180, 233),  # blue    #56B4E9
        create_colormap(0, 158, 115),  # green   #009E73
        create_colormap(90, 20, 165),  # purple  #5A14A5
        create_colormap(200, 200, 200),  # grey    #C8C8C8
        create_colormap(50, 50, 50),  # "white" #323232
    ]
    cmaps = [cmaps[i] for i in reorder_cmap]

    with mpl.style.context(style):
        fig = plt.figure()

        # --- layout: main axis plus an optional grid of colorbar axes ---
        if colorbar_position == "right":
            if colorbar_grid is None:
                colorbar_grid = (len(labels), 1)
            shape = {"vertical_gaps": 1.5, "horizontal_gaps": 0, "width": 0.15, "height": 0.2, **colorbar_shape}
            gs = GridSpec(
                nrows=colorbar_grid[0] + 2,
                ncols=colorbar_grid[1] + 1,
                width_ratios=[1, *[shape["width"]] * colorbar_grid[1]],
                height_ratios=[1, *[shape["height"]] * colorbar_grid[0], 1],
                hspace=shape["vertical_gaps"],
                wspace=shape["horizontal_gaps"],
            )
            ax = fig.add_subplot(gs[:, 0], aspect="equal", rasterized=True)
        elif colorbar_position == "bottom":
            if colorbar_grid is None:
                if len(labels) <= 3:
                    colorbar_grid = (1, len(labels))
                else:
                    # ceil(len(labels) / 3) rows of three colorbars
                    colorbar_grid = ((len(labels) + 2) // 3, 3)
            shape = {"vertical_gaps": 0.3, "horizontal_gaps": 0.6, "width": 0.2, "height": 0.035, **colorbar_shape}
            gs = GridSpec(
                nrows=colorbar_grid[0] + 1,
                ncols=colorbar_grid[1] + 2,
                width_ratios=[0.3, *[shape["width"]] * colorbar_grid[1], 0.3],
                height_ratios=[1, *[shape["height"]] * colorbar_grid[0]],
                hspace=shape["vertical_gaps"],
                wspace=shape["horizontal_gaps"],
            )
            ax = fig.add_subplot(gs[0, :], aspect="equal", rasterized=True)
        else:
            ax = fig.add_subplot(aspect="equal", rasterized=True)

        if colorbar_position is not None:
            # Colorbar axes fill the grid row by row; grid row/column 0 (and the trailing
            # row/column) belong to the main plot or padding.
            cbar_axes = [
                fig.add_subplot(gs[row, column])
                for row in range(1, colorbar_grid[0] + 1)
                for column in range(1, colorbar_grid[1] + 1)
            ]
            # hide grid cells that have no cell type assigned
            for unused_ax in cbar_axes[len(labels):]:
                unused_ax.set_visible(False)

        ax.set_xlabel(x_y_labels[0])
        ax.set_ylabel(x_y_labels[1])

        if img is not None and show_img:
            ax.imshow(img, aspect="equal", alpha=img_alpha, origin="lower", cmap=image_cmap)

        # crop images if needed
        if crop_x is not None:
            ax.set_xlim(crop_x[0], crop_x[1])
        if crop_y is not None:
            ax.set_ylim(crop_y[0], crop_y[1])

        if axis_y_flipped:
            ax.invert_yaxis()

        if plt_axis == "off":
            for spine in ax.spines.values():
                spine.set_visible(False)
            ax.tick_params(bottom=False, labelbottom=False, left=False, labelleft=False)

        # --- per-cell-type colours and blending weights ---
        counts = value_df.values.copy()
        n_spots, n_types = counts.shape
        colors = np.zeros((n_spots, n_types, 4))
        weights = np.zeros((n_spots, n_types))

        for c in range(n_types):
            min_color_intensity = counts[:, c].min()
            max_color_intensity = np.min([np.quantile(counts[:, c], max_color_quantile), max_col[c]])

            rgb_function = get_rgb_function(cmap=cmaps[c], min_value=min_color_intensity, max_value=max_color_intensity)

            color = rgb_function(counts[:, c])
            color[:, 3] = color[:, 3] * alpha_scaling

            if colorbar_position is not None:
                norm = mpl.colors.Normalize(vmin=min_color_intensity, vmax=max_color_intensity)
                cbar_ticks = np.array(
                    [min_color_intensity, np.mean([min_color_intensity, max_color_intensity]), max_color_intensity]
                )
                # integer tick labels for large ranges, two decimals otherwise
                cbar_ticks = cbar_ticks.astype(np.int32) if max_color_intensity > 13 else cbar_ticks.round(2)
                cbar = fig.colorbar(
                    mpl.cm.ScalarMappable(norm=norm, cmap=cmaps[c]),
                    cax=cbar_axes[c],
                    orientation="horizontal",
                    extend="both",
                    ticks=cbar_ticks,
                )
                cbar.ax.tick_params(labelsize=colorbar_tick_size)
                # colorbar title is drawn in the cell type's own colour
                max_color = rgb_function(max_color_intensity / 1.5)
                cbar.ax.set_title(labels[c], **{"size": 20, "color": max_color, "alpha": 1, **colorbar_label_kw})

            colors[:, c] = color
            weights[:, c] = np.clip(counts[:, c] / (max_color_intensity + 1e-10), 0, 1)

        # --- blend the cell-type colours of each spot in RYB space ---
        # rgb_to_ryb converts each row independently, so all spots x cell types go in one call.
        colors_ryb = rgb_to_ryb(colors[:, :, :3].reshape(-1, 3)).reshape(n_spots, n_types, 3)

        # squared weights emphasise the most abundant cell type(s) in the blend
        kernel_weights = weights[:, :, np.newaxis] ** 2
        weighted_sum = (colors_ryb * kernel_weights).sum(axis=1)
        total_weight = kernel_weights.sum(axis=1)
        # spots where every weight is zero would give 0/0 = NaN; leave them at 0 instead
        weighted_colors_ryb = np.divide(
            weighted_sum, total_weight, out=np.zeros_like(weighted_sum), where=total_weight > 0
        )

        weighted_colors = np.zeros((n_spots, 4))
        weighted_colors[:, :3] = ryb_to_rgb(weighted_colors_ryb)
        weighted_colors[:, 3] = colors[:, :, 3].max(axis=1)

        ax.scatter(x=coords[:, 0], y=coords[:, 1], c=weighted_colors, s=circle_diameter**2)

        # add text
        if text is not None:
            bbox_props = dict(boxstyle="round", ec="0.5", alpha=text_box_alpha, fc="w")
            texts = [
                ax.text(x, y, s, ha="center", va="bottom", bbox=bbox_props)
                for x, y, s in zip(
                    np.array(text.iloc[:, 0].values).flatten(),
                    np.array(text.iloc[:, 1].values).flatten(),
                    text.iloc[:, 2].tolist(),
                )
            ]

            if adjust_text:
                # aliased so the `adjust_text` parameter is not shadowed
                from adjustText import adjust_text as _adjust_text

                _adjust_text(texts, arrowprops=dict(arrowstyle="->", color="w", lw=0.5))

        return fig


def plot_spatial(adata, color, img_key="hires", show_img=True, **kwargs):
    """Plot spatial abundance of cell types (regulatory programmes) with colour gradient and interpolation (from Visium anndata).

    This method supports only 7 cell types with these colours (in order, which can be changed using reorder_cmap).
    'yellow' 'orange' 'blue' 'green' 'purple' 'grey' 'white'

    :param adata: adata object with spatial coordinates in adata.obsm['spatial']
    :param color: list of adata.obs column names to be plotted
    :param img_key: key of the image in ``adata.uns['spatial'][<library>]['images']``; also selects the
        ``tissue_<img_key>_scalef`` scale factor applied to the coordinates
    :param show_img: show the tissue image? Needs ``adata.uns['spatial']``; when it is missing a warning is
        issued and the spots are plotted without an image.
    :param kwargs: arguments to plot_spatial_general
    :return: matplotlib figure
    """
    # Only the first library stored in adata.uns['spatial'] is used.
    library = list(adata.uns["spatial"].values())[0] if "spatial" in adata.uns.keys() else None

    # Always forward show_img so that an `img` passed via kwargs also respects show_img=False.
    kwargs["show_img"] = show_img
    if show_img:
        if library is not None:
            kwargs["img"] = library["images"][img_key]
        else:
            warnings.warn("show_img=True but adata.uns has no 'spatial' entry; plotting spots without a tissue image.")

    # location coordinates, rescaled to the selected image's resolution when scale factors are available
    if library is not None:
        kwargs["coords"] = adata.obsm["spatial"] * library["scalefactors"][f"tissue_{img_key}_scalef"]
    else:
        kwargs["coords"] = adata.obsm["spatial"]

    fig = plot_spatial_general(value_df=adata.obs[color], **kwargs)  # cell abundance values

    return fig