Mapping human lymph node cell types to 10X Visium with Cell2location

This tutorial shows how to use the cell2location method for spatially resolving fine-grained cell types by integrating 10X Visium data with a scRNA-seq reference of cell types. Cell2location is a principled Bayesian model that estimates which combination of cell types, at which abundance, could have given rise to the mRNA counts in the spatial data, while modelling technical effects (platform/technology effect, contaminating RNA, unexplained variance).


Cell2location is an independent package, but is powered by scvi-tools. If you have questions about cell2location, Visium data or scvi-tools please visit, or correspondingly.


In this tutorial, we analyse a publicly available Visium dataset of the human lymph node from 10X Genomics, and spatially map a comprehensive atlas of 34 reference cell types derived by integration of scRNA-seq datasets from human secondary lymphoid organs.

  • Cell2location provides high sensitivity and resolution by borrowing statistical strength across locations. This is achieved by modelling similarity of location patterns between cell types using a hierarchical factorisation of cell abundance into tissue zones as a prior (see paper methods).

  • Using our statistical method based on Negative Binomial regression to robustly combine scRNA-seq reference data across technologies and batches improves spatial mapping accuracy. Given a cell type annotation for each cell, the corresponding reference cell type signatures \(g_{f,g}\), which represent the average mRNA count of each gene \(g\) in each cell type \(f\), can be estimated from sc/snRNA-seq data using either 1) NB regression or 2) a hard-coded computation of per-cluster average mRNA counts for individual genes. We generally recommend using NB regression, and this notebook shows its use on a dataset composed of multiple batches and technologies. When the batch effects are small, the faster hard-coded method of computing per-cluster averages provides similarly high accuracy. We also recommend the hard-coded method for non-UMI technologies such as Smart-Seq 2.

  • Cell2location needs untransformed unnormalised spatial mRNA counts as input.

  • You also need to provide cell2location with the expected average cell abundance per location, which is used as a prior to guide estimation of absolute cell abundance. This value depends on the tissue. It can be estimated by counting nuclei in a few locations of the paired histology image, and an approximate value is acceptable (see paper methods for more guidance).
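The hard-coded alternative (option 2 above) can be illustrated with a minimal sketch of per-cluster average mRNA counts. It assumes a dense cells x genes array of raw counts and one label per cell. The function name `cluster_average_signatures` and the toy gene names are hypothetical, and this is not necessarily how cell2location implements the computation. The result uses the same genes x cell types layout as the signature table used later for spatial mapping.

```python
import numpy as np
import pandas as pd


def cluster_average_signatures(counts, labels, gene_names):
    """Average raw mRNA count of every gene in every cell type (genes x cell types)."""
    df = pd.DataFrame(np.asarray(counts, dtype=float), columns=gene_names)
    # mean over all cells sharing a label gives cell types x genes; transpose to genes x cell types
    return df.groupby(np.asarray(labels)).mean().T


# toy raw UMI counts: 4 cells x 3 hypothetical genes
counts = np.array([[1, 0, 4],
                   [3, 2, 0],
                   [0, 5, 1],
                   [2, 1, 3]])
labels = ["B_naive", "B_naive", "T_CD4+_naive", "T_CD4+_naive"]
sig = cluster_average_signatures(counts, labels, ["GENE_A", "GENE_B", "GENE_C"])
print(sig)
```

Each column is one cell type, and each entry is that cell type's average count for a gene. Batch effects are ignored entirely, which is why this shortcut is only recommended when they are small.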

Figure 1. Workflow diagram

Loading packages

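import sys
IN_COLAB = "google.colab" in sys.modules
if IN_COLAB:
    # install scvi-tools via scvi-colab, then cell2location with its tutorial dependencies
    !pip install --quiet scvi-colab
    from scvi_colab import install
    install()
    !pip install --quiet git+[tutorials]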
import scanpy as sc
import anndata
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl

import cell2location
import scvi

from matplotlib import rcParams
rcParams['pdf.fonttype'] = 42 # enables correct plotting of text for PDFs
Global seed set to 0

First, let’s define where we save the results of our analysis:

results_folder = './results/lymph_nodes_analysis/'

# create paths and names to results folders for reference regression and cell2location models
ref_run_name = f'{results_folder}/reference_signatures'
run_name = f'{results_folder}/cell2location_map'

Loading Visium and scRNA-seq reference data

First let’s read spatial Visium data from the 10X Space Ranger output. Here we use lymph node data generated by 10X and presented in Kleshchevnikov et al. (section 4, Fig 4). This dataset can be conveniently downloaded and imported using scanpy. See this tutorial for a more extensive and practical example of data loading (multiple Visium samples).

adata_vis = sc.datasets.visium_sge(sample_id="V1_Human_Lymph_Node")
adata_vis.obs['sample'] = list(adata_vis.uns['spatial'].keys())[0]
UserWarning: Variable names are not unique. To make them unique, call `.var_names_make_unique`.


Here we rename genes to Ensembl IDs for correct matching between the single-cell and spatial data. Because Ensembl IDs are unique, you can ignore the scanpy suggestion to call .var_names_make_unique.

adata_vis.var['SYMBOL'] = adata_vis.var_names
adata_vis.var.set_index('gene_ids', drop=True, inplace=True)
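The following toy pandas sketch shows why indexing by gene ID removes the duplicate-name problem while keeping symbols usable. The gene IDs and symbols are hypothetical placeholders. The sketch applies the same two steps as above to a small `.var`-like table and then looks up the IDs that share one symbol.

```python
import pandas as pd

# hypothetical .var table in which two different genes share the symbol GENE_X
var = pd.DataFrame(
    {"gene_ids": ["ENSG_A", "ENSG_B", "ENSG_C"]},
    index=["PTPRC", "GENE_X", "GENE_X"],
)
symbols_unique = var.index.is_unique  # symbols as the index are duplicated

# same two steps as above: keep symbols in a column, then index by gene ID
var["SYMBOL"] = var.index
var = var.set_index("gene_ids", drop=True)

# IDs are unique, and symbols remain available for look-ups by name
ids_for_symbol = var.index[(var["SYMBOL"] == "GENE_X").to_numpy()].tolist()
print(var.index.is_unique, ids_for_symbol)
```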

You can still plot gene expression by name using standard scanpy functions as follows: sc.pl.spatial(adata_vis, color='PTPRC', gene_symbols='SYMBOL', ...)


Mitochondria-encoded genes (gene names starting with the prefix mt- or MT-) are irrelevant for spatial mapping because their expression represents technical artifacts in single-cell and single-nucleus data rather than the biological abundance of mitochondria. Yet these genes make up 15-40% of mRNA in each location. Hence, to avoid mapping artifacts, we strongly recommend removing mitochondrial genes.

# find mitochondria-encoded (MT) genes
adata_vis.var['MT_gene'] = [gene.startswith('MT-') for gene in adata_vis.var['SYMBOL']]

# remove MT genes for spatial mapping (keeping their counts in the object)
adata_vis.obsm['MT'] = adata_vis[:, adata_vis.var['MT_gene'].values].X.toarray()
adata_vis = adata_vis[:, ~adata_vis.var['MT_gene'].values]
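You may want to check how much of each location's signal came from mitochondrial genes before discarding them. This is a minimal numpy sketch that assumes a dense locations x genes count array and a list of gene symbols. `mt_fraction` is a hypothetical helper, not part of cell2location.

```python
import numpy as np


def mt_fraction(counts, symbols):
    """Fraction of each location's total counts that comes from genes whose symbol starts with 'MT-'."""
    counts = np.asarray(counts, dtype=float)
    is_mt = np.array([s.startswith("MT-") for s in symbols])
    total = counts.sum(axis=1)
    mt = counts[:, is_mt].sum(axis=1)
    # locations with zero counts get a fraction of 0 instead of a division-by-zero NaN
    return np.divide(mt, total, out=np.zeros_like(total), where=total > 0)


# hypothetical 3 locations x 4 genes
symbols = ["MT-CO1", "MT-ND1", "PTPRC", "CD19"]
counts = np.array([[30, 10, 40, 20],
                   [0, 0, 0, 0],
                   [5, 5, 60, 30]])
frac = mt_fraction(counts, symbols)
print(frac)
```

In this tutorial, the MT counts are kept in adata_vis.obsm['MT']. The same ratio can therefore be formed from those counts together with the remaining counts in adata_vis.X.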

Published scRNA-seq datasets of lymph nodes have typically lacked an adequate representation of germinal centre-associated immune cell populations due to the age of the donors. We therefore include scRNA-seq datasets spanning lymph nodes, spleen and tonsils in our single-cell reference to ensure that we capture the full diversity of immune cell states likely to exist in the spatial transcriptomic dataset.

Here we download this dataset, import it into anndata and change the variable names to Ensembl gene identifiers.

# Read data
adata_ref =


Here we rename genes to Ensembl IDs for correct matching between the single-cell and spatial data.

adata_ref.var['SYMBOL'] = adata_ref.var.index
# rename 'GeneID-2' as necessary for your data
adata_ref.var.set_index('GeneID-2', drop=True, inplace=True)

# delete unnecessary raw slot (to be removed in a future version of the tutorial)
del adata_ref.raw


Before estimating the reference cell type signatures, we recommend performing a very permissive gene selection. We prefer this to standard highly-variable-gene selection because our procedure keeps markers of rare cell types while removing most of the uninformative genes.

The default parameters cell_count_cutoff=5, cell_percentage_cutoff2=0.03 and nonz_mean_cutoff=1.12 are a good starting point, but you can increase the cut-offs to exclude more genes. To preserve marker genes of rare cell types, we recommend a low cell_count_cutoff=5. You can increase cell_percentage_cutoff2 and nonz_mean_cutoff to select between 8k and 16k genes.

In the 2D histogram produced by this step, the orange rectangle highlights genes excluded based on the combination of the number of cells expressing that gene (Y-axis) and the average RNA count for cells where the gene was detected (X-axis).

In this case, the downloaded dataset was already filtered using this method, hence there is no density under the orange rectangle (to be changed in a future version of the tutorial).

from cell2location.utils.filtering import filter_genes
selected = filter_genes(adata_ref, cell_count_cutoff=5, cell_percentage_cutoff2=0.03, nonz_mean_cutoff=1.12)

# filter the object
adata_ref = adata_ref[:, selected].copy()
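The selection rule behind the orange rectangle can be written out explicitly. The sketch below reflects our reading of how the three cut-offs combine, and should be treated as an assumption rather than the library's code. Under that reading, a gene is kept if it is detected in more than cell_percentage_cutoff2 of all cells. It is also kept if it is detected in more than cell_count_cutoff cells and has a mean count above nonz_mean_cutoff where detected. The library function also draws the 2D histogram. `permissive_gene_filter` is a hypothetical name.

```python
import numpy as np


def permissive_gene_filter(counts, cell_count_cutoff=5,
                           cell_percentage_cutoff2=0.03, nonz_mean_cutoff=1.12):
    """Boolean mask of genes to keep (assumed re-implementation of the selection rule)."""
    counts = np.asarray(counts, dtype=float)
    n_obs = counts.shape[0]
    n_cells = (counts > 0).sum(axis=0)  # number of cells expressing each gene
    # broadly detected genes are kept regardless of their expression level ...
    widely_detected = n_cells > n_obs * cell_percentage_cutoff2
    with np.errstate(divide="ignore", invalid="ignore"):
        # mean count over the cells where the gene is detected (NaN if never detected)
        nonz_mean = counts.sum(axis=0) / n_cells
        # ... rarer genes only if detected in enough cells at a high enough level
        rare_but_high = (n_cells > cell_count_cutoff) & (nonz_mean > nonz_mean_cutoff)
    return widely_detected | rare_but_high


# hypothetical reference: 300 cells x 4 genes
counts = np.zeros((300, 4))
counts[:12, 0] = 1   # detected in 4% of cells at low level -> kept (widely detected)
counts[:7, 1] = 3    # 7 cells, high level -> kept (rare but high)
counts[:7, 2] = 1    # 7 cells, low level -> removed
counts[:3, 3] = 10   # only 3 cells -> removed
keep = permissive_gene_filter(counts)
print(keep)
```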

Estimation of reference cell type signatures (NB regression)

The signatures are estimated from scRNA-seq data, accounting for batch effect, using a Negative binomial regression model.
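As a rough intuition for the regression, the sketch below simulates counts from a simplified generative model. Each cell's expected count is the signature of its cell type multiplied by a per-batch technical effect, and counts are drawn from a Negative Binomial with that mean. All numbers are hypothetical. cell2location's RegressionModel may contain further terms, and it learns its parameters rather than fixing them. This sketch only illustrates the mean/overdispersion parameterisation.

```python
import numpy as np


def sample_nb(mean, shape, rng):
    """Negative Binomial draws with the given mean and shape (larger shape = less overdispersion)."""
    mean = np.asarray(mean, dtype=float)
    # numpy parameterises NB by (n, p) with mean n * (1 - p) / p; this p makes the mean equal `mean`
    p = shape / (shape + mean)
    return rng.negative_binomial(shape, p)


rng = np.random.default_rng(0)
# hypothetical signatures g_{f,g}: 2 cell types x 3 genes (average mRNA count per cell)
signatures = np.array([[5.0, 0.5, 2.0],
                       [0.5, 4.0, 2.0]])
# hypothetical multiplicative technical effect per batch and gene (e.g. 3' vs 5' chemistry)
batch_effect = np.array([[1.0, 1.0, 1.0],
                         [1.5, 0.8, 1.0]])
cell_type = np.array([0, 0, 1, 1])
batch = np.array([0, 1, 0, 1])

# expected count of each gene in each cell: its type's signature scaled by its batch effect
mu = signatures[cell_type] * batch_effect[batch]
counts = sample_nb(mu, shape=10.0, rng=rng)
print(counts)
```

With this parameterisation, the variance is mu + mu**2 / shape. The shape parameter therefore controls how much extra variability exists beyond Poisson noise.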

Preparing anndata.

First, prepare anndata object for the regression model:

# prepare anndata for the regression model
cell2location.models.RegressionModel.setup_anndata(adata=adata_ref,
                        # 10X reaction / sample / batch
                        batch_key='Sample',
                        # cell type, covariate used for constructing signatures
                        labels_key='Subset',
                        # multiplicative technical effects (platform, 3' vs 5', donor effect)
                        categorical_covariate_keys=['Method']
                       )

# create the regression model
from cell2location.models import RegressionModel
mod = RegressionModel(adata_ref)

# view anndata_setup as a sanity check
mod.view_anndata_setup()
Anndata setup with scvi-tools version 0.16.2.
Setup via `RegressionModel.setup_anndata` with arguments:
'layer': None,
'batch_key': 'Sample',
'labels_key': 'Subset',
'categorical_covariate_keys': ['Method'],
'continuous_covariate_keys': None
         Summary Statistics         
┃     Summary Stat Key      Value ┃
│         n_cells           73260 │
│          n_vars           10237 │
│         n_batch            23   │
│         n_labels           34   │
│ n_extra_categorical_covs    1   │
│ n_extra_continuous_covs     0   │
                             Data Registry                             
┃      Registry Key                  scvi-tools Location             ┃
│           X                              adata.X                   │
│         batch                    adata.obs['_scvi_batch']          │
│         labels                  adata.obs['_scvi_labels']          │
│ extra_categorical_covs  adata.obsm['_scvi_extra_categorical_covs'] │
│         ind_x                     adata.obs['_indices']            │
                         batch State Registry                         
┃   Source Location          Categories        scvi-tools Encoding ┃
│ adata.obs['Sample']     4861STDY7135913               0          │
│                         4861STDY7135914               1          │
│                         4861STDY7208412               2          │
│                         4861STDY7208413               3          │
│                         4861STDY7462253               4          │
│                         4861STDY7462254               5          │
│                         4861STDY7462255               6          │
│                         4861STDY7462256               7          │
│                         4861STDY7528597               8          │
│                         4861STDY7528598               9          │
│                         4861STDY7528599              10          │
│                         4861STDY7528600              11          │
│                           BCP002_Total               12          │
│                           BCP003_Total               13          │
│                           BCP004_Total               14          │
│                           BCP005_Total               15          │
│                           BCP006_Total               16          │
│                           BCP008_Total               17          │
│                           BCP009_Total               18          │
│                      Human_colon_16S7255677          19          │
│                      Human_colon_16S7255678          20          │
│                      Human_colon_16S8000484          21          │
│                           Pan_T7935494               22          │
                     labels State Registry                      
┃   Source Location       Categories     scvi-tools Encoding ┃
│ adata.obs['Subset']     B_Cycling               0          │
│                          B_GC_DZ                1          │
│                          B_GC_LZ                2          │
│                         B_GC_prePB              3          │
│                           B_IFN                 4          │
│                        B_activated              5          │
│                           B_mem                 6          │
│                          B_naive                7          │
│                          B_plasma               8          │
│                          B_preGC                9          │
│                          DC_CCR7+              10          │
│                          DC_cDC1               11          │
│                          DC_cDC2               12          │
│                           DC_pDC               13          │
│                            Endo                14          │
│                            FDC                 15          │
│                            ILC                 16          │
│                       Macrophages_M1           17          │
│                       Macrophages_M2           18          │
│                            Mast                19          │
│                         Monocytes              20          │
│                             NK                 21          │
│                            NKT                 22          │
│                           T_CD4+               23          │
│                         T_CD4+_TfH             24          │
│                       T_CD4+_TfH_GC            25          │
│                        T_CD4+_naive            26          │
│                       T_CD8+_CD161+            27          │
│                      T_CD8+_cytotoxic          28          │
│                        T_CD8+_naive            29          │
│                          T_TIM3+               30          │
│                           T_TfR                31          │
│                           T_Treg               32          │
│                            VSMC                33          │
          extra_categorical_covs State Registry           
┃   Source Location    Categories  scvi-tools Encoding ┃
│ adata.obs['Method']     3GEX              0          │
│                         5GEX              1          │
│                                                      │

Training model.

Now we train the model to estimate the reference cell type signatures.

Note that to achieve convergence on your data (i.e. to get the loss to stabilise), you may need to increase max_epochs beyond 250 (see below).

Also note that here we use batch_size=2500, which is much larger than the scvi-tools default. We also train on all cells in the data (train_size=1). Both values are the defaults.

mod.train(max_epochs=250, use_gpu=True)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
Epoch 250/250: 100%|██████████████████████████████████████████████| 250/250 [13:14<00:00,  3.18s/it, v_num=1, elbo_train=2.88e+8]

Determine if the model needs more training.

Here, we plot the ELBO loss history during training, removing the first 20 epochs from the plot. This plot should have a decreasing trend and level off by the end of training. If it is still decreasing, increase max_epochs.
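"Level off" can also be checked numerically. This minimal heuristic assumes you have the per-epoch training loss as a sequence. It compares the mean loss over the last two windows of epochs. The helper name and the default `window` and `rel_tol` values are hypothetical choices, not cell2location settings.

```python
import numpy as np


def has_plateaued(loss_history, window=20, rel_tol=1e-3):
    """Heuristic: has the mean loss stopped decreasing between the last two windows of epochs?"""
    loss = np.asarray(loss_history, dtype=float)
    if loss.size < 2 * window:
        return False  # too few epochs to judge
    previous = loss[-2 * window:-window].mean()
    recent = loss[-window:].mean()
    # the ELBO loss should decrease; a small relative drop suggests it has levelled off
    return bool((previous - recent) / abs(previous) < rel_tol)


# a loss that flattens out vs one that is still falling steadily
flattening = 1e8 * (1 + np.exp(-np.arange(500) / 20))
still_falling = 1e8 - 1e5 * np.arange(100)
print(has_plateaued(flattening), has_plateaued(still_falling))
```

This heuristic complements the plot rather than replacing it. A noisy loss can require a larger window.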

# plot ELBO loss history during training, removing first 20 epochs from the plot
mod.plot_history(20)

# In this section, we export the estimated reference cell type signatures (summary of the posterior distribution).
adata_ref = mod.export_posterior(
    adata_ref, sample_kwargs={'num_samples': 1000, 'batch_size': 2500, 'use_gpu': True}
)

# Save model
mod.save(f"{ref_run_name}", overwrite=True)

# Save anndata object with results
adata_file = f"{ref_run_name}/sc.h5ad"
adata_ref.write(adata_file)
Sampling local variables, batch:   0%|                                                                    | 0/30 [00:00<?, ?it/s]
Sampling global variables, sample: 100%|██████████████████████████████████████████████████████| 999/999 [00:08<00:00, 121.60it/s]

Examine QC plots.

  1. Reconstruction accuracy, to assess whether there are any issues with inference. This 2D histogram plot should have most observations along a noisy diagonal.

  2. The estimated expression signatures differ from the mean expression in each cluster because of batch effects. For scRNA-seq datasets that do not suffer from batch effects (this dataset does), cluster average expression can be used instead of estimating signatures with a model. When this plot is very different from a diagonal plot (e.g. very low values on the Y-axis, or density everywhere), it indicates problems with signature estimation.
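The second QC plot can be complemented by a number. The sketch below computes, for each cell type, the Pearson correlation on a log scale between estimated signatures and per-cluster average expression. It assumes two genes x cell types arrays with the same gene and column order. The pseudocount and function name are assumptions. Values close to 1 indicate that the points lie close to a straight line on the log scale, but they do not check that this line is the diagonal.

```python
import numpy as np


def per_cell_type_correlation(signatures, cluster_means, pseudocount=1e-3):
    """Pearson correlation (log10 scale) between two genes x cell types matrices, per cell type."""
    a = np.log10(np.asarray(signatures, dtype=float) + pseudocount)
    b = np.log10(np.asarray(cluster_means, dtype=float) + pseudocount)
    return np.array([np.corrcoef(a[:, j], b[:, j])[0, 1] for j in range(a.shape[1])])


# hypothetical 3 genes x 2 cell types; the second matrix is a noisy rescaling of the first
est = np.array([[1.0, 10.0], [2.0, 20.0], [4.0, 1.0]])
means = est * np.array([[1.1, 0.9], [0.9, 1.2], [1.0, 1.0]])
print(per_cell_type_correlation(est, means))
```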


The model and output h5ad can be loaded later like this:

adata_file = f"{ref_run_name}/sc.h5ad"
adata_ref = sc.read_h5ad(adata_file)
mod = cell2location.models.RegressionModel.load(f"{ref_run_name}", adata_ref)

Extracting reference cell types signatures as a pd.DataFrame.

All parameters of the Negative Binomial regression model are exported into the reference anndata object. For spatial mapping, however, we only need the estimated expression of every gene in every cell type. Here we extract that from the standard output:

# export estimated expression in each cluster
if 'means_per_cluster_mu_fg' in adata_ref.varm.keys():
    inf_aver = adata_ref.varm['means_per_cluster_mu_fg'][[f'means_per_cluster_mu_fg_{i}'
                                    for i in adata_ref.uns['mod']['factor_names']]].copy()
else:
    inf_aver = adata_ref.var[[f'means_per_cluster_mu_fg_{i}'
                                    for i in adata_ref.uns['mod']['factor_names']]].copy()
inf_aver.columns = adata_ref.uns['mod']['factor_names']
inf_aver.iloc[0:5, 0:5]
B_Cycling B_GC_DZ B_GC_LZ B_GC_prePB B_IFN
ENSG00000188976 0.422678 0.238245 0.304522 0.341516 0.148354
ENSG00000188290 0.002084 0.000712 0.000780 0.055247 0.040109
ENSG00000187608 0.384230 0.211964 0.274831 0.510336 3.942888
ENSG00000186891 0.019506 0.000771 0.053882 0.067228 0.010869
ENSG00000186827 0.007557 0.000531 0.006252 0.029405 0.011207

Cell2location: spatial mapping

Find shared genes and prepare anndata. Subset both anndata and reference signatures:

# find shared genes and subset both anndata and reference signatures
intersect = np.intersect1d(adata_vis.var_names, inf_aver.index)
adata_vis = adata_vis[:, intersect].copy()
inf_aver = inf_aver.loc[intersect, :].copy()

# prepare anndata for cell2location model
cell2location.models.Cell2location.setup_anndata(adata=adata_vis, batch_key="sample")

To use the cell2location spatial mapping model, you need to specify two user-provided hyperparameters: N_cells_per_location and detection_alpha. For detailed guidance on setting these hyperparameters and their impact, see the flow diagram and the note.

Choosing hyperparameter N_cells_per_location

It is useful to adapt the expected cell abundance N_cells_per_location to every tissue. This value can be estimated from paired histology images, as described in the note above. Change the value used in this tutorial (N_cells_per_location=30) to the value observed in your tissue.
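As a worked example of the histology-based estimate, suppose nuclei were counted in six spots of the paired image. The counts below are hypothetical. Their average gives the value to pass as N_cells_per_location, and their standard deviation indicates how approximate that value is.

```python
import numpy as np

# hypothetical nuclei counts from six Visium spots in the paired H&E image
nuclei_counts = np.array([24, 31, 35, 28, 40, 22])

n_cells_per_location = float(nuclei_counts.mean())  # value to pass as N_cells_per_location
spread = float(nuclei_counts.std(ddof=1))            # how variable the counted spots are
print(n_cells_per_location, round(spread, 1))
```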

Choosing hyperparameter detection_alpha

Some datasets have large technical variability in RNA detection sensitivity within the slide/batch. To improve accuracy and sensitivity on such data, relax the regularisation of per-location normalisation by using detection_alpha=20. Your sample has high technical variability in RNA detection sensitivity when the spatial distribution of total RNA count per location does not match the cell numbers expected from histological examination.

We initially opted for high regularisation (detection_alpha=200) as the default because the mouse brain and human lymph node datasets used in our paper have low technical effects. In those datasets, high regularisation strength improves consistency between the total estimated cell abundance per location and the nuclei count quantified from histology (Fig S8F in the cell2location paper). However, in many collaborations, we see that Visium experiments on human tissues suffer from technical effects. This motivates the new default value of detection_alpha=20 and the recommendation to test both settings on your data (detection_alpha=20 and detection_alpha=200).
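The effect of "relaxing regularisation" can be expressed in numbers. Assume, based on our reading of the model and not a quote from it, that the per-location detection efficiency has a Gamma prior with shape detection_alpha and a fixed mean. Its coefficient of variation is then \(1/\sqrt{\alpha}\). The prior therefore tolerates roughly 22% location-to-location variation at detection_alpha=20, versus roughly 7% at detection_alpha=200. The sketch checks this analytically and by simulation. The helper name is hypothetical.

```python
import numpy as np


def detection_prior_cv(detection_alpha):
    """Coefficient of variation of a Gamma prior with shape alpha (any mean): 1 / sqrt(alpha)."""
    return 1.0 / np.sqrt(detection_alpha)


rng = np.random.default_rng(0)
empirical_cv = {}
for alpha in (20, 200):
    # assumed prior: Gamma(shape=alpha, rate=alpha / mean) with mean 1, i.e. scale = 1 / alpha
    y = rng.gamma(shape=alpha, scale=1.0 / alpha, size=100_000)
    empirical_cv[alpha] = y.std() / y.mean()
    print(alpha, round(detection_prior_cv(alpha), 3), round(empirical_cv[alpha], 3))
```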

# create and train the model
mod = cell2location.models.Cell2location(
    adata_vis, cell_state_df=inf_aver,
    # the expected average cell abundance: tissue-dependent
    # hyper-prior which can be estimated from paired histology:
    N_cells_per_location=30,
    # hyperparameter controlling normalisation of
    # within-experiment variation in RNA detection:
    detection_alpha=20
)
mod.view_anndata_setup()
Anndata setup with scvi-tools version 0.16.2.
Setup via `Cell2location.setup_anndata` with arguments:
'layer': None,
'batch_key': 'sample',
'labels_key': None,
'categorical_covariate_keys': None,
'continuous_covariate_keys': None
         Summary Statistics         
┃     Summary Stat Key      Value ┃
│         n_cells           4035  │
│          n_vars           10217 │
│         n_batch             1   │
│         n_labels            1   │
│ n_extra_categorical_covs    0   │
│ n_extra_continuous_covs     0   │
               Data Registry                
┃ Registry Key     scvi-tools Location    ┃
│      X                 adata.X          │
│    batch      adata.obs['_scvi_batch']  │
│    labels     adata.obs['_scvi_labels'] │
│    ind_x        adata.obs['_indices']   │
                       batch State Registry                        
┃   Source Location        Categories       scvi-tools Encoding ┃
│ adata.obs['sample']  V1_Human_Lymph_Node           0          │
                     labels State Registry                      
┃      Source Location       Categories  scvi-tools Encoding ┃
│ adata.obs['_scvi_labels']      0                0          │

Training cell2location:

mod.train(max_epochs=30000,
          # train using full data (batch_size=None)
          batch_size=None,
          # use all data points in training because
          # we need to estimate cell abundance at all locations
          train_size=1,
          use_gpu=True,
         )

# plot ELBO loss history during training, removing first 100 epochs from the plot
mod.plot_history(100)
plt.legend(labels=['full data training']);
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
Epoch 30000/30000: 100%|██████████████████████████████████████| 30000/30000 [32:37<00:00, 15.32it/s, v_num=1, elbo_train=4.54e+7]

Exporting estimated posterior distributions of cell abundance and saving results:

# In this section, we export the estimated cell abundance (summary of the posterior distribution).
adata_vis = mod.export_posterior(
    adata_vis, sample_kwargs={'num_samples': 1000, 'batch_size': mod.adata.n_obs, 'use_gpu': True}
)

# Save model
mod.save(f"{run_name}", overwrite=True)

# mod = cell2location.models.Cell2location.load(f"{run_name}", adata_vis)

# Save anndata object with results
adata_file = f"{run_name}/sp.h5ad"
adata_vis.write(adata_file)
Sampling local variables, batch: 100%|█████████████████████████████████████████████████████████████| 1/1 [00:14<00:00, 14.20s/it]
Sampling global variables, sample: 100%|███████████████████████████████████████████████████████| 999/999 [00:14<00:00, 71.26it/s]

The model and output h5ad can be loaded later like this:

adata_file = f"{run_name}/sp.h5ad"
adata_vis = sc.read_h5ad(adata_file)
mod = cell2location.models.Cell2location.load(f"{run_name}", adata_vis)

Assessing mapping quality. Examine reconstruction accuracy to assess whether there are any issues with mapping. The plot should be roughly diagonal, and strong deviations signal problems that need to be investigated.


Some datasets combine multiple spatial batches, or show substantial variation of detected RNA within slides that cannot be explained by high cellular density in the histology. For such data, it is important to assess whether cell2location normalised those effects. You should expect similar total cell abundance across batches but distinct RNA detection sensitivity (both estimated by cell2location). Total cell abundance should also mirror high cellular density in the histology.
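The across-batch comparison can also be tabulated. This sketch sums the estimated abundance over cell types for each location and summarises the totals per batch with pandas. The function name is hypothetical. With this tutorial's objects, the inputs would be something like adata_vis.obsm['q05_cell_abundance_w_sf'] and adata_vis.obs['sample'].

```python
import numpy as np
import pandas as pd


def total_abundance_by_batch(abundance, batches):
    """Sum abundance over cell types per location, then summarise the totals within each batch."""
    total = pd.Series(np.asarray(abundance, dtype=float).sum(axis=1), name="total_abundance")
    return total.groupby(np.asarray(batches)).describe()


# hypothetical: 4 locations x 2 cell types from two samples
abundance = np.array([[10.0, 5.0], [20.0, 10.0], [12.0, 3.0], [8.0, 7.0]])
batches = ["s1", "s1", "s2", "s2"]
summary = total_abundance_by_batch(abundance, batches)
print(summary[["mean", "50%"]])
```

Similar means and medians across batches are consistent with technical effects having been normalised. Large differences are worth checking against the histology.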

fig = mod.plot_spatial_QC_across_batches()

Visualising cell abundance in spatial coordinates


We use the 5% quantile of the posterior distribution, which represents the value of cell abundance that the model has high confidence in (i.e. 'at least this amount is present').
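The following sketch shows why the 5% quantile is a conservative summary. It draws hypothetical posterior samples for three locations that share the same expected abundance (10 cells) but differ in uncertainty. The quantile is computed with np.quantile. The more uncertain the estimate, the further its 5% quantile falls below the mean.

```python
import numpy as np

rng = np.random.default_rng(0)
# hypothetical posterior samples (1000 draws) for 3 locations; the expected value is 10 cells
# everywhere, but the Gamma shape (and so the certainty) increases from left to right
samples = rng.gamma(shape=[2.0, 20.0, 200.0], scale=[5.0, 0.5, 0.05], size=(1000, 3))

posterior_mean = samples.mean(axis=0)
# 5% quantile: with 95% posterior probability, at least this many cells are present
q05 = np.quantile(samples, 0.05, axis=0)
print(posterior_mean.round(1), q05.round(1))
```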

# add 5% quantile, representing confident cell abundance, 'at least this amount is present',
# to adata.obs with nice names for plotting
adata_vis.obs[adata_vis.uns['mod']['factor_names']] = adata_vis.obsm['q05_cell_abundance_w_sf']

# select one slide
from cell2location.utils import select_slide
slide = select_slide(adata_vis, 'V1_Human_Lymph_Node')

# plot in spatial coordinates
with mpl.rc_context({'axes.facecolor':  'black',
                     'figure.figsize': [4.5, 5]}):
    sc.pl.spatial(slide, cmap='magma',
                  # show first 8 cell types
                  color=['B_Cycling', 'B_GC_LZ', 'T_CD4+_TfH_GC', 'FDC',
                         'B_naive', 'T_CD4+_naive', 'B_plasma', 'Endo'],
                  ncols=4, size=1.3,
                  # limit color scale at 99.2% quantile of cell abundance
                  vmin=0, vmax='p99.2'
                 )
# Now we use cell2location plotter that allows showing multiple cell types in one panel
from cell2location.plt import plot_spatial

# select up to 6 clusters
clust_labels = ['T_CD4+_naive', 'B_naive', 'FDC']
clust_col = ['' + str(i) for i in clust_labels] # in case column names differ from labels

slide = select_slide(adata_vis, 'V1_Human_Lymph_Node')

with mpl.rc_context({'figure.figsize': (15, 15)}):
    fig = plot_spatial(
        adata=slide,
        # labels to show on a plot
        color=clust_col, labels=clust_labels,
        # 'fast' (white background) or 'dark_background'
        style='fast',
        # limit color scale at 99.2% quantile of cell abundance
        max_color_quantile=0.992,
        # size of locations (adjust depending on figure size)
        circle_diameter=6
    )

Downstream analysis

Identifying discrete tissue regions by Leiden clustering

We identify tissue regions that differ in their cell composition by clustering locations using cell abundance estimated by cell2location.

We find tissue regions by clustering Visium spots using the estimated abundance of each cell type. We construct a K-nearest neighbour (KNN) graph representing the similarity of locations in estimated cell abundance and then apply Leiden clustering. The number of KNN neighbours should be adapted to the size of the dataset and the size of anatomically defined regions. For example, hippocampus regions are rather small compared to the size of the brain, so they could be masked by a large n_neighbors. You can repeat this for a range of KNN neighbours and Leiden clustering resolutions until you obtain a clustering that matches the anatomical structure of the tissue.
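The KNN step can be illustrated without scanpy. This sketch finds, for every location, the indices of its n_neighbors nearest locations in abundance space using exact Euclidean distances. The helper name is hypothetical. scanpy's sc.pp.neighbors additionally converts neighbours into weighted connectivities and may use approximate search on larger datasets, which this sketch does not do.

```python
import numpy as np


def knn_graph(features, n_neighbors=15):
    """For each row, indices of its n_neighbors nearest rows by Euclidean distance (self excluded)."""
    x = np.asarray(features, dtype=float)
    sq = (x ** 2).sum(axis=1)
    # squared pairwise distances via |a|^2 + |b|^2 - 2 a.b
    dist2 = sq[:, None] + sq[None, :] - 2.0 * x @ x.T
    np.fill_diagonal(dist2, np.inf)  # a location is not its own neighbour
    return np.argsort(dist2, axis=1)[:, :n_neighbors]


# hypothetical abundance matrix: 50 locations x 34 cell types
rng = np.random.default_rng(0)
abundance = rng.gamma(2.0, 2.0, size=(50, 34))
neighbours = knn_graph(abundance, n_neighbors=15)
print(neighbours.shape)
```

A smaller n_neighbors keeps the graph local, which helps small regions remain separable, in line with the guidance above.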

The clustering is done jointly across all Visium sections / batches, so the region identities are directly comparable. When there are strong technical effects between multiple batches (not the case here), sc.external.pp.bbknn can in principle be used to account for those effects during the KNN construction.

The resulting clusters are saved in adata_vis.obs['region_cluster'].

# compute KNN using the cell2location output stored in adata.obsm
sc.pp.neighbors(adata_vis, use_rep='q05_cell_abundance_w_sf',
                n_neighbors = 15)

# Cluster spots into regions using scanpy
sc.tl.leiden(adata_vis, resolution=1.1)

# add region as categorical variable
adata_vis.obs["region_cluster"] = adata_vis.obs["leiden"].astype("category")

We can use the location composition similarity graph to build a joint integrated UMAP representation of all section/Visium batches.

# compute UMAP using KNN graph based on the cell2location output
sc.tl.umap(adata_vis, min_dist = 0.3, spread = 1)

# show regions in UMAP coordinates
with mpl.rc_context({'axes.facecolor':  'white',
                     'figure.figsize': [8, 8]}):
    sc.pl.umap(adata_vis, color=['region_cluster'], size=30,
               color_map = 'RdPu', ncols = 2, legend_loc='on data',
               legend_fontsize=20)
    sc.pl.umap(adata_vis, color=['sample'], size=30,
               color_map = 'RdPu', ncols = 2
              )

# plot in spatial coordinates
with mpl.rc_context({'axes.facecolor':  'black',
                     'figure.figsize': [4.5, 5]}):
    sc.pl.spatial(adata_vis, color=['region_cluster'],
                  size=1.3, img_key='hires', alpha=0.5)

Identifying cellular compartments / tissue zones using matrix factorisation (NMF)

Here, we use the cell2location mapping results to identify the spatial co-occurrence of cell types, in order to better understand tissue organisation and predict cellular interactions. We performed non-negative matrix factorization (NMF) of the cell type abundance estimates from cell2location (paper section 4, Fig 4D). Similar to the established benefits of applying NMF to conventional scRNA-seq, the additive NMF decomposition yielded a grouping of spatial cell type abundance profiles into components that capture co-localised cell types (Supplementary Methods section 4.2, p. 60). This NMF-based decomposition naturally accounts for the fact that multiple cell types and microenvironments can co-exist at the same Visium locations (see paper Fig S20, p. 34), while sharing information across tissue areas (e.g. individual germinal centres).


In practice, it is better to train NMF for a range of factors \(R \in \{5, \dots, 30\}\) and select \(R\) as a balance between capturing fine-grained tissue zones and splitting known, well-established ones.
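A factor sweep can be sketched with a dependency-free NMF. The analysis name printed by run_colocation suggests it uses scikit-learn's NMF. The code here instead uses the classic Lee and Seung multiplicative updates for the Frobenius loss, applied to a hypothetical locations x cell types matrix, and records the reconstruction error for each \(R\). The error typically keeps falling as \(R\) grows, so it cannot pick \(R\) on its own. This is why \(R\) is chosen as a balance between fine-grained zones and splitting known ones.

```python
import numpy as np


def nmf(v, n_factors, n_iter=500, seed=0, eps=1e-10):
    """Basic NMF, V ~ W @ H, via Lee & Seung multiplicative updates for the Frobenius loss."""
    rng = np.random.default_rng(seed)
    v = np.asarray(v, dtype=float)
    w = rng.random((v.shape[0], n_factors)) + eps
    h = rng.random((n_factors, v.shape[1])) + eps
    for _ in range(n_iter):
        # multiplicative updates keep W and H non-negative
        h *= (w.T @ v) / (w.T @ w @ h + eps)
        w *= (v @ h.T) / (w @ h @ h.T + eps)
    return w, h


def reconstruction_error(v, w, h):
    return float(np.linalg.norm(v - w @ h))


# hypothetical locations x cell types abundance built from 3 underlying "zones"
rng = np.random.default_rng(1)
abundance = rng.random((200, 3)) @ rng.random((3, 12))

errors = {}
for r in range(1, 6):  # in practice sweep a wider range, e.g. 5..30
    w, h = nmf(abundance, r)
    errors[r] = reconstruction_error(abundance, w, h)
print(errors)
```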

If you want to find a few of the most distinct cellular compartments, use a small number of factors. If you want to find a very strong co-location signal and assume that most cell types don't co-locate, use many factors (> 30, used here).

Below we show how to perform this analysis. To aid it, we wrapped the analysis shown in the notebook on advanced downstream analysis into a pipeline that automates training of the NMF model with a varying number of factors:

from cell2location import run_colocation
res_dict, adata_vis = run_colocation(
    adata_vis,
    model_name='CoLocatedGroupsSklearnNMF',
    train_args={
      'n_fact': np.arange(11, 13), # IMPORTANT: use a wider range of the number of factors (5-30)
      'sample_name_col': 'sample', # columns in adata_vis.obs that identifies sample
      'n_restarts': 3 # number of training restarts
    },
    export_args={'path': f'{run_name}/CoLocatedComb/'}
)
### Analysis name: CoLocatedGroupsSklearnNMF_11combinations_4035locations_34factors
WARNING: saving figure to file results/lymph_nodes_analysis/cell2location_map/CoLocatedComb/CoLocatedGroupsSklearnNMF_4035locations_34factors/spatial/showcell_density_mean_n_fact11_sV1_Human_Lymph_Node_p99.2.pdf
### Analysis name: CoLocatedGroupsSklearnNMF_12combinations_4035locations_34factors
/nfs/team283/vk7/software/miniconda3farm5/envs/test_test_scvi16_cuda113/lib/python3.9/site-packages/sklearn/decomposition/ FutureWarning: `alpha` was deprecated in version 1.0 and will be removed in 1.2. Use `alpha_W` and `alpha_H` instead
WARNING: saving figure to file results/lymph_nodes_analysis/cell2location_map/CoLocatedComb/CoLocatedGroupsSklearnNMF_4035locations_34factors/spatial/showcell_density_mean_n_fact12_sV1_Human_Lymph_Node_p99.2.pdf
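Conceptually, the pipeline fits one NMF model for every value of `n_fact` and repeats training `n_restarts` times. The sketch below mimics that loop with scikit-learn's `NMF` on a synthetic locations × cell types abundance matrix. It is an illustration under these assumptions, not cell2location's implementation; the data, seeds and `max_iter` are hypothetical choices.

```python
import numpy as np
from sklearn.decomposition import NMF

rng = np.random.default_rng(0)
# hypothetical stand-in for estimated cell abundances: 200 locations x 34 cell types
abundance = rng.gamma(shape=2.0, scale=1.0, size=(200, 34))

results = {}
for n_fact in range(11, 13):        # mirrors 'n_fact': np.arange(11, 13)
    for restart in range(3):        # mirrors 'n_restarts': 3
        model = NMF(n_components=n_fact, init='random',
                    random_state=restart, max_iter=500)
        loc_factors = model.fit_transform(abundance)    # locations x factors
        cell_type_loadings = model.components_.T        # cell types x factors
        results[(n_fact, restart)] = (loc_factors, cell_type_loadings,
                                      model.reconstruction_err_)

for (n_fact, restart), (_, _, err) in sorted(results.items()):
    print(f"n_fact={n_fact} restart={restart} reconstruction error={err:.2f}")
```

Different random initialisations can converge to different solutions, which is why comparing restarts (as in `stability_plots/`) is informative.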

For every number of factors, the pipeline produces the following output folders:

cell_type_fractions_heatmap/: a dot plot of the estimated NMF weights of cell types (rows) across NMF components (columns)

cell_type_fractions_mean/: the data used for the dot plot

factor_markers/: tables listing the top 10 cell types most specific to each NMF factor

models/: saved NMF models

predictive_accuracy/: 2D histogram plot showing how well NMF explains cell2location output

spatial/: NMF weights across locations in spatial coordinates

location_factors_mean/: the data used for the plot in spatial coordinates

stability_plots/: stability of NMF weights between training restarts
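The `predictive_accuracy/` output compares the NMF reconstruction with the cell2location abundances it was trained on. A minimal sketch of such a comparison, assuming it amounts to a 2D histogram of input values against reconstructed values, on synthetic data (the pipeline's exact binning and plotting may differ):

```python
import numpy as np
from sklearn.decomposition import NMF

rng = np.random.default_rng(1)
# hypothetical abundance matrix: 300 locations x 20 cell types
abundance = rng.gamma(shape=2.0, scale=1.0, size=(300, 20))

model = NMF(n_components=5, init='random', random_state=0, max_iter=500)
W = model.fit_transform(abundance)          # locations x factors
reconstruction = W @ model.components_      # locations x cell types

# 2D histogram: observed values (x) vs NMF-reconstructed values (y)
counts, x_edges, y_edges = np.histogram2d(
    abundance.ravel(), reconstruction.ravel(), bins=30)

# a single-number summary of the same comparison
r = np.corrcoef(abundance.ravel(), reconstruction.ravel())[0, 1]
print(f"Pearson r between input and reconstruction: {r:.3f}")
```

`counts` can be displayed with `plt.imshow` or `plt.pcolormesh(x_edges, y_edges, counts.T)`; mass concentrated on the diagonal indicates that the NMF factors explain the input well.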

The key output to examine is the set of files in cell_type_fractions_heatmap/, which show a dot plot of the estimated NMF weights of cell types (rows) across NMF components (columns); the components correspond to cellular compartments. The plotted values are relative weights, normalized across components for every cell type.
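One reading of "normalized across components for every cell type" is that each row of the cell type × factor loading matrix is rescaled to sum to one. A minimal pandas sketch with a made-up loading matrix (the values and factor names are hypothetical; the cell type names are borrowed from this tutorial for illustration):

```python
import pandas as pd

# hypothetical NMF loadings: rows = cell types, columns = NMF factors
loadings = pd.DataFrame(
    [[4.0, 1.0, 0.0],
     [0.5, 0.5, 3.0],
     [2.0, 2.0, 0.0]],
    index=['B_GC_LZ', 'T_CD4+_naive', 'T_CD4+_TfH_GC'],
    columns=['fact_0', 'fact_1', 'fact_2'],
)

# divide each row by its sum so that every cell type's weights add up to 1
relative = loadings.div(loadings.sum(axis=1), axis=0)
print(relative.round(3))
```

After this normalization a cell type that is split evenly between two compartments (here `T_CD4+_TfH_GC`) shows 0.5 in each, regardless of how abundant it is overall.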


The NMF model outputs, such as factor loadings, are stored in adata.uns[f"mod_coloc_n_fact{n_fact}"] in a format similar to the main cell2location results in adata.uns['mod'].

# Here we plot the NMF weights (Same as saved to `cell_type_fractions_heatmap`)

Estimate cell-type specific expression of every gene in the spatial data (needed for NCEM)

The cell-type specific expression of every gene at every spatial location enables learning cell communication with the NCEM model using Visium data.

To derive this, we adapt the approach for estimating conditional expected expression proposed by the RCTD method (Cable et al.).
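The idea behind a conditional expected expression can be illustrated with a simple Poisson model: if the count of gene \(g\) at location \(s\) is a sum of independent cell-type contributions with rates \(\lambda_{s,f,g} = w_{s,f}\,g_{f,g}\), then, conditioned on the observed total, the contributions are multinomial and their expectation splits the observed count in proportion to the rates. The NumPy sketch below shows this split on toy arrays. It is an assumption-laden illustration of the RCTD-style idea, not the code behind `compute_expected_per_cell_type`, which works with cell2location's own posterior quantities.

```python
import numpy as np

rng = np.random.default_rng(0)
n_locations, n_cell_types, n_genes = 4, 3, 5
# hypothetical inputs: abundances w (locations x cell types),
# reference signatures g (cell types x genes), observed counts y (locations x genes)
w = rng.gamma(2.0, 1.0, size=(n_locations, n_cell_types))
g = rng.gamma(2.0, 1.0, size=(n_cell_types, n_genes))
y = rng.poisson(w @ g)

# per-cell-type rates: lam[f, s, gene] = w[s, f] * g[f, gene]
lam = w.T[:, :, None] * g[:, None, :]           # cell types x locations x genes

# conditional expectation: split each observed count in proportion to the rates
share = lam / lam.sum(axis=0, keepdims=True)
expected_per_type = share * y[None, :, :]       # cell types x locations x genes

print(expected_per_type.shape)
```

The result is arranged cell type first, so `expected_per_type[i]` is a locations × genes matrix, and summing over cell types recovers the observed counts exactly.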

With cell2location, we can look at the full posterior distribution rather than just point estimates of cell-type specific expression (see mod.samples.keys() and the next section on using the full distribution).

Note that this analysis requires a substantial amount of RAM and therefore doesn’t work on free Google Colab (12 GB limit).

# Compute expected expression per cell type
expected_dict = mod.module.model.compute_expected_per_cell_type(
    mod.samples["post_sample_q05"], mod.adata_manager
)

# Add to anndata layers (one layer per cell type)
for i, n in enumerate(mod.factor_names_):
    adata_vis.layers[n] = expected_dict['mu'][i]

# Save anndata object with results
adata_file = f"{run_name}/sp.h5ad"
adata_vis.write(adata_file)

Plotting cell-type specific expression of genes in spatial coordinates.

Below we plot the cell-type specific expression of genes (rows: genes; second through last columns: cell types) compared with the total expression of those genes (first column).

Here we highlight CD3D, a pan-T-cell marker expressed by two T cell subtypes in distinct locations but not by the co-located B cells, which instead express the CR2 gene.

import matplotlib as mpl
from cell2location.utils import select_slide
from tutorial_utils import plot_genes_per_cell_type  # not part of cell2location; copy it locally

# list cell types and genes for plotting
ctypes = ['T_CD4+_TfH_GC', 'T_CD4+_naive', 'B_GC_LZ']
genes = ['CD3D', 'CR2']

with mpl.rc_context({'axes.facecolor': 'black'}):
    # select one slide
    slide = select_slide(adata_vis, 'V1_Human_Lymph_Node')
    plot_genes_per_cell_type(slide, genes, ctypes);

Note that the plot_genes_per_cell_type function often needs customization, so it is not included in the cell2location package; you need to copy it from to use it on your system.

Advanced use

Working with the posterior distribution and computing arbitrary quantiles

In addition to the posterior mean, standard deviation and quantiles presented earlier in the notebook, you can fetch an arbitrary number of samples from the posterior distribution. To limit memory use, it can help to select only the model variables you need.

Note that this analysis requires a substantial amount of RAM and therefore doesn’t work on Google Colab.

# Get posterior distribution samples for specific variables
samples_w_sf = mod.sample_posterior(num_samples=1000, use_gpu=True, return_samples=True,
                                    return_sites=['w_sf', 'm_g', 'u_sf_mRNA_factors'])
# samples_w_sf['posterior_samples'] contains 1000 samples as arrays with dim=(num_samples, ...)
samples_w_sf['posterior_samples']['w_sf'].shape
Sampling local variables, batch: 100%|█████████████████████████████████████████████████████████████| 2/2 [00:25<00:00, 12.72s/it]
Sampling global variables, sample: 100%|███████████████████████████████████████████████████████| 999/999 [00:12<00:00, 80.04it/s]
(1000, 4035, 34)

Finally, it could be useful to compute arbitrary quantiles of the posterior distribution.

# Compute any quantile of the posterior distribution
medians = mod.posterior_quantile(q=0.5, batch_size=mod.adata.n_obs, use_gpu=True)

with mpl.rc_context({'axes.facecolor':  'white',
                     'figure.figsize': [5, 5]}):
    plt.scatter(medians['w_sf'].flatten(), mod.samples['post_sample_means']['w_sf'].flatten());
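If posterior samples have already been drawn, as with `sample_posterior` above, quantiles and other summaries can also be computed directly from the sample array along its first (sample) axis. A NumPy sketch, using a synthetic array in place of `samples_w_sf['posterior_samples']['w_sf']` (the gamma-distributed values and the reduced shape are hypothetical; the real array above has shape (1000, 4035, 34)):

```python
import numpy as np

rng = np.random.default_rng(0)
# synthetic stand-in for posterior samples with dim=(num_samples, locations, cell types)
samples = rng.gamma(shape=3.0, scale=1.0, size=(1000, 50, 34))

post_mean = samples.mean(axis=0)
post_sd = samples.std(axis=0)
# several quantiles at once; np.quantile returns dim=(n_quantiles, locations, cell types)
q05, q50, q95 = np.quantile(samples, [0.05, 0.5, 0.95], axis=0)

print(post_mean.shape, q50.shape)
```

These are Monte Carlo estimates, so they carry sampling noise that shrinks as `num_samples` grows.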

Modules and their versions used for this analysis

Useful for debugging and reporting issues.

sys 3.9.12 (main, Apr  5 2022, 06:56:58)
[GCC 7.5.0]
re 2.2.1
ipykernel._version 6.13.0
json 2.0.9
jupyter_client._version 7.3.1
traitlets._version 5.2.1.post0
traitlets 5.2.1.post0
platform 1.0.8
_ctypes 1.1.0
ctypes 1.1.0
zmq.backend.cython.constants 40304
zmq.backend.cython 40304
zmq.sugar.constants 40304
zmq.sugar.version 22.3.0
zmq.sugar 22.3.0
zmq 22.3.0
socketserver 0.4
argparse 1.1
tornado 6.1
zlib 1.0
_curses b'2.2'
dateutil._version 2.8.2
dateutil 2.8.2
six 1.16.0
_decimal 1.70
decimal 1.70
jupyter_core.version 4.10.0
jupyter_core 4.10.0
entrypoints 0.4
jupyter_client 7.3.1
ipykernel 6.13.0
IPython.core.release 8.3.0
executing.version 0.8.3
executing 0.8.3
pure_eval.version 0.2.2
pure_eval 0.2.2
stack_data.version 0.2.0
stack_data 0.2.0
pygments 2.12.0
ptyprocess 0.7.0
pexpect 4.8.0
IPython.core.crashhandler 8.3.0
pickleshare 0.7.5
backcall 0.2.0
decorator 5.1.1
_sqlite3 2.6.0
sqlite3.dbapi2 2.6.0
sqlite3 2.6.0
wcwidth 0.2.5
prompt_toolkit 3.0.29
parso 0.8.3
jedi 0.18.1
urllib.request 3.9
IPython.core.magics.code 8.3.0
IPython 8.3.0
psutil 5.9.0
debugpy 1.6.0
xmlrpc.client 3.9
http.server 0.6
pkg_resources._vendor.appdirs 1.4.3
pkg_resources.extern.appdirs 1.4.3
pkg_resources._vendor.packaging.__about__ 21.2
pkg_resources._vendor.packaging 21.2
pkg_resources.extern.packaging 21.2
pkg_resources._vendor.pyparsing 2.2.1
pkg_resources.extern.pyparsing 2.2.1
_pydev_bundle.fsnotify 0.1.5
pydevd 2.8.0
packaging.__about__ 21.3
packaging 21.3
_csv 1.0
csv 1.0
scanpy._metadata 1.9.1
mkl 2.4.0
numpy.version 1.21.5
numpy.core._multiarray_umath 3.1
numpy.core 1.21.5
numpy.linalg._umath_linalg 0.1.5
numpy.lib 1.21.5
numpy 1.21.5
scipy.version 1.8.0
scipy 1.8.0
anndata._metadata 0.8.0
h5py.version 3.6.0
h5py 3.6.0
natsort 8.1.0
pytz 2022.1
numexpr.version 2.8.1
numexpr 2.8.1
pandas 1.4.2
anndata 0.8.0
yaml 6.0
llvmlite 0.38.0
numba.cloudpickle 1.6.0
numba.misc.appdirs 1.4.1
numba 0.55.1
distutils 3.9.12
joblib.externals.cloudpickle 2.0.0
joblib.externals.loky 3.0.0
joblib 1.1.0
sklearn.utils._joblib 1.1.0
scipy._lib.decorator 4.0.5
scipy.linalg._fblas b'$Revision: $'
scipy.linalg._flapack b'$Revision: $'
scipy.linalg._flinalg b'$Revision: $'
scipy.special._specfun b'$Revision: $'
scipy.sparse.linalg._isolve._iterative b'$Revision: $'
scipy.sparse.linalg._eigen.arpack._arpack b'$Revision: $'
scipy.optimize._minpack2 b'$Revision: $'
scipy.optimize._lbfgsb b'$Revision: $'
scipy.optimize._cobyla b'$Revision: $'
scipy.optimize._slsqp b'$Revision: $'
scipy.optimize._minpack  1.10
scipy.optimize.__nnls b'$Revision: $'
scipy.linalg._interpolative b'$Revision: $'
scipy.integrate._odepack  1.9
scipy.integrate._quadpack  1.13
scipy.integrate._vode b'$Revision: $'
scipy.integrate._dop b'$Revision: $'
scipy.integrate._lsoda b'$Revision: $'
scipy.integrate._ode $Id$
scipy.interpolate._fitpack  1.7
scipy.interpolate.dfitpack b'$Revision: $'
scipy.stats._statlib b'$Revision: $'
scipy.stats._mvn b'$Revision: $'
threadpoolctl 3.1.0
sklearn.base 1.1.0
sklearn.utils._show_versions 1.1.0
sklearn 1.1.0
texttable 1.6.4
igraph.version 0.9.10
igraph 0.9.10
leidenalg.version 0.8.10
leidenalg 0.8.10
matplotlib._version 3.5.2
PIL._version 9.1.1
PIL 9.1.1
defusedxml 0.7.1
xml.etree.ElementTree 1.3.0
cffi 1.15.0
PIL.Image 9.1.1
pyparsing 3.0.4
cycler 0.10.0
kiwisolver._cext 1.4.2
kiwisolver 1.4.2
matplotlib 3.5.2
scanpy 1.9.1
torch.version 1.11.0+cu113
torch.torch_version 1.11.0+cu113
tarfile 0.9.0
torch.cuda.nccl (2, 10, 3)
torch.backends.cudnn 8200
tqdm._dist_ver 4.64.0
tqdm.version 4.64.0
tqdm.cli 4.64.0
tqdm 4.64.0
ipywidgets._version 7.7.0
ipython_genutils._version 0.2.0
ipython_genutils 0.2.0
ipywidgets 7.7.0
torch 1.11.0+cu113
opt_einsum v3.3.0
pyro._version 1.8.1
pyro 1.8.1
pytorch_lightning.__about__ 1.5.10
torchmetrics.__about__ 0.8.2
urllib3.packages.six 1.16.0
urllib3._version 1.26.9
ipaddress 1.0
urllib3.connection 1.26.9
urllib3 1.26.9
charset_normalizer.version 2.0.12
charset_normalizer 2.0.12
requests.__version__ 2.27.1
certifi 2021.10.08
requests.utils 2.27.1
requests.packages.urllib3.packages.six 1.16.0
requests.packages.urllib3._version 1.26.9
requests.packages.urllib3.connection 1.26.9
requests.packages.urllib3 1.26.9
idna.package_data 3.3
idna.idnadata 14.0.0
idna 3.3
requests.packages.idna.package_data 3.3
requests.packages.idna.idnadata 14.0.0
requests.packages.idna 3.3
requests.packages.chardet 2.0.12
requests 2.27.1
torchvision.version 0.12.0+cu113
torchvision 0.12.0+cu113
torchmetrics 0.8.2
fsspec 2022.3.0
attr 21.4.0
tensorboard.version 2.9.0
tensorboard 2.9.0
google.protobuf 3.20.1
tensorboard.compat.tensorflow_stub.pywrap_tensorflow 0
tensorboard.compat.tensorflow_stub stub
pytorch_lightning.loggers.neptune 1.5.10
deprecate 0.3.1
pytorch_lightning 1.5.10
docrep 0.3.2
scipy._lib._uarray 0.8.2+14.gaf53966.scipy
scipy.signal._spline 0.2
jaxlib.version 0.3.10
jaxlib 0.3.10
jax.version 0.3.13
flatbuffers._version 2.0
flatbuffers 2.0
jax.lib 0.3.10
jax 0.3.13
flax.version 0.4.2
flax 0.4.2
multipledispatch 0.6.0
numpyro.version 0.9.2
numpyro 0.9.2
xml.sax.handler 2.0beta
tree 0.1.7
toolz 0.11.2
chex 0.1.3
optax 0.1.2
scvi 0.16.2
_cffi_backend 1.15.0
pycparser.ply 3.9
pycparser.ply.yacc 3.10
pycparser.ply.lex 3.10
pycparser 2.21
pynndescent 0.5.7
umap 0.5.3
fontTools 4.33.3
fontTools.misc.sstruct 1.2
fontTools.ttLib.tables._g_l_y_f 4.33
xml.etree.cElementTree 1.3.0
fontTools.misc.etree 1.3.0