Cell2location: spatial mapping (scvi-tools/pyro)

User-facing cell2location spatial cell abundance estimation model class (scvi-tools BaseModelClass)

class cell2location.models.Cell2location(adata: anndata._core.anndata.AnnData, cell_state_df: pandas.core.frame.DataFrame, model_class: Optional[pyro.nn.module.PyroModule, None] = None, detection_mean_per_sample: bool = False, detection_mean_correction: float = 1.0, **model_kwargs)

Bases: cell2location.models.base._pyro_mixin.QuantileMixin, scvi.model.base._pyromixin.PyroSampleMixin, scvi.model.base._pyromixin.PyroSviTrainMixin, cell2location.models.base._pyro_mixin.PltExportMixin, scvi.model.base._base_model.BaseModelClass

Cell2location model: the user-facing model class. See the Module class below for a description of the model, including the math.

Parameters
  • adata – spatial AnnData object that has been registered via setup_anndata().

  • cell_state_df – pd.DataFrame with reference expression signatures for each gene (rows) in each cell type/population (columns).

  • use_gpu – whether to use the GPU.

  • **model_kwargs – Keyword args for LocationModelLinearDependentWMultiExperimentLocationBackgroundNormLevelGeneAlphaPyroModel


classmethod setup_anndata(adata: anndata._core.anndata.AnnData, layer: Optional[str, None] = None, batch_key: Optional[str, None] = None, labels_key: Optional[str, None] = None, categorical_covariate_keys: Optional[List[str], None] = None, continuous_covariate_keys: Optional[List[str], None] = None, **kwargs)
Sets up the AnnData object for this model.

A mapping will be created between data fields used by this model to their respective locations in adata.

None of the existing data in adata are modified; the method only adds fields to adata.

Parameters
  • layer – if not None, uses this as the key in adata.layers for raw count data.

  • batch_key – key in adata.obs for batch information. Categories will automatically be converted into integer categories and saved to adata.obs[‘_scvi_batch’]. If None, assigns the same batch to all the data.

  • labels_key – key in adata.obs for label information. Categories will automatically be converted into integer categories and saved to adata.obs[‘_scvi_labels’]. If None, assigns the same label to all the data.

  • categorical_covariate_keys – keys in adata.obs that correspond to categorical data. These covariates can be added in addition to the batch covariate and are also treated as nuisance factors (i.e., the model tries to minimize their effects on the latent space). Thus, these should not be used for biologically relevant factors that you do not want to correct for.

  • continuous_covariate_keys – keys in adata.obs that correspond to continuous data. These covariates can be added in addition to the batch covariate and are also treated as nuisance factors (i.e., the model tries to minimize their effects on the latent space). Thus, these should not be used for biologically relevant factors that you do not want to correct for.
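To make "converted into integer categories" concrete for batch_key and labels_key, the sketch below encodes an obs column with pandas categorical codes. This is an illustration under assumptions: the helper encode_categories and the column name "sample" are hypothetical, and the mapping scvi-tools actually stores in adata.obs['_scvi_batch'] may order categories differently.

```python
import pandas as pd


def encode_categories(obs: pd.DataFrame, key: str) -> pd.Series:
    """Map each category in obs[key] to an integer code (0, 1, ...).

    Hypothetical helper illustrating the idea behind '_scvi_batch' /
    '_scvi_labels'; it is not the scvi-tools implementation.
    """
    categories = pd.Categorical(obs[key])  # categories are sorted for string data
    return pd.Series(categories.codes, index=obs.index, name=f"{key}_code")


obs = pd.DataFrame(
    {"sample": ["slide_B", "slide_A", "slide_B"]},
    index=["spot1", "spot2", "spot3"],
)
print(encode_categories(obs, "sample").tolist())  # [1, 0, 1]
```

Locations from the same slide share one integer, which is what lets the model index batch-specific parameters such as \(s_{e,g}\) and \(y_e\).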

train(max_epochs: int = 30000, batch_size: int = None, train_size: float = 1, lr: float = 0.002, num_particles: int = 1, scale_elbo: float = 1.0, **kwargs)

Train the model with useful defaults.

Parameters
  • max_epochs – Number of passes through the dataset. If None, defaults to np.min([round((20000 / n_cells) * 400), 400])

  • train_size – Size of the training set in the range [0.0, 1.0]. The default of 1 uses all data points for training, because cell abundance must be estimated at every location.

  • batch_size – Minibatch size to use during training. If None, no minibatching occurs and all data is copied to device (e.g., GPU).

  • lr – Optimiser learning rate (default optimiser is ClippedAdam). Specifying optimiser via plan_kwargs overrides this choice of lr.

  • kwargs – Other arguments passed to the scvi.model.base.PyroSviTrainMixin.train() method.
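The max_epochs fallback formula above caps training at 400 epochs and shortens it for larger datasets; any dataset with at most 20,000 observations gets the full 400. The sketch below simply evaluates the documented expression (here n_cells is the number of observations, i.e. locations). Note that the signature default is max_epochs=30000, so the formula only applies when None is passed; the helper name default_max_epochs is hypothetical.

```python
import numpy as np


def default_max_epochs(n_cells: int) -> int:
    """Evaluate the documented fallback: np.min([round((20000 / n_cells) * 400), 400])."""
    return int(np.min([round((20000 / n_cells) * 400), 400]))


for n in (4_000, 30_000, 40_000, 100_000):
    print(n, default_max_epochs(n))
# 4000 -> 400, 30000 -> 267, 40000 -> 200, 100000 -> 80
```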

train_aggressive(max_epochs: Optional[int, None] = 1000, use_gpu: Optional[Union[str, int, bool]] = None, train_size: float = 1, validation_size: Optional[float, None] = None, batch_size: int = None, early_stopping: bool = False, lr: Optional[float, None] = None, plan_kwargs: Optional[dict, None] = None, **trainer_kwargs)

Train the model.

Parameters
  • max_epochs – Number of passes through the dataset. If None, defaults to np.min([round((20000 / n_cells) * 400), 400])

  • use_gpu – Use default GPU if available (if None or True), or index of GPU to use (if int), or name of GPU (if str, e.g., ‘cuda:0’), or use CPU (if False).

  • train_size – Size of training set in the range [0.0, 1.0].

  • validation_size – Size of the validation set. If None, defaults to 1 - train_size. If train_size + validation_size < 1, the remaining cells belong to a test set.

  • batch_size – Minibatch size to use during training. If None, no minibatching occurs and all data is copied to device (e.g., GPU).

  • early_stopping – Perform early stopping. Additional arguments can be passed in **trainer_kwargs. See Trainer for further options.

  • lr – Optimiser learning rate (default optimiser is ClippedAdam). Specifying optimiser via plan_kwargs overrides this choice of lr.

  • plan_kwargs – Keyword args for TrainingPlan. Keyword arguments passed to train() will overwrite values present in plan_kwargs, when appropriate.

  • **trainer_kwargs – Other keyword args for Trainer.
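The interplay of train_size and validation_size described above can be sketched as follows. This is a simplified illustration: the helper split_sizes is hypothetical, and scvi-tools applies its own rounding when it splits the data, so exact counts may differ by one.

```python
def split_sizes(n_obs: int, train_size: float = 1.0, validation_size=None):
    """Return (n_train, n_validation, n_test) following the documented rules.

    Hypothetical helper; rounding is simplified relative to scvi-tools.
    """
    if validation_size is None:
        validation_size = 1.0 - train_size  # documented default
    if train_size + validation_size > 1.0 + 1e-9:
        raise ValueError("train_size + validation_size must not exceed 1")
    n_train = round(n_obs * train_size)
    n_val = min(round(n_obs * validation_size), n_obs - n_train)
    n_test = n_obs - n_train - n_val  # leftover observations form the test set
    return n_train, n_val, n_test


print(split_sizes(100))            # (100, 0, 0): the default uses every location
print(split_sizes(100, 0.8))       # (80, 20, 0)
print(split_sizes(100, 0.7, 0.2))  # (70, 20, 10)
```

With train_size=1, as in train(), no observations are held out, which matches the requirement that abundance be estimated at every location.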

export_posterior(adata, sample_kwargs: Optional[dict, None] = None, export_slot: str = 'mod', add_to_obsm: list = ['means', 'stds', 'q05', 'q95'], use_quantiles: bool = False)

Summarise the posterior distribution and export results (cell abundance) to the anndata object:

  1. adata.obsm: estimated cell abundance as pd.DataFrames, one for each posterior distribution summary in add_to_obsm: posterior mean, sd, 5% and 95% quantiles ([‘means’, ‘stds’, ‘q05’, ‘q95’]). If export to adata.obsm fails with an error, results are saved to adata.obs instead.

  2. adata.uns: posterior of all parameters, model name, date, cell type names (‘factor_names’), obs and var names.

Parameters
  • adata – anndata object where results should be saved

  • sample_kwargs – arguments for self.sample_posterior (generating and summarising posterior samples), namely:

      • num_samples – number of samples to use (default 1000);

      • batch_size – data batch size (keep low enough to fit on the GPU; default 2048);

      • use_gpu – whether to use the GPU for generating samples.

  • export_slot – adata.uns slot to which results are exported

  • add_to_obsm – posterior distribution summary to export in adata.obsm ([‘means’, ‘stds’, ‘q05’, ‘q95’]).

  • use_quantiles – compute quantiles directly (True, more memory efficient) or from posterior samples (False, default). If True, means and stds cannot be computed, so they are neither exported nor returned.
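Conceptually, the sample-based path (use_quantiles=False) reduces an array of posterior draws to the four summaries listed in add_to_obsm. The sketch below does this with NumPy and pandas; the helper summarise_posterior, the cell type names and the simulated draws are hypothetical, and the key and column naming that export_posterior uses in adata.obsm is not reproduced. Holding the full (num_samples, n_locations, n_cell_types) array in memory is what makes this path more memory-hungry than computing quantiles directly.

```python
import numpy as np
import pandas as pd


def summarise_posterior(samples: np.ndarray, obs_names, factor_names) -> dict:
    """Summarise draws of shape (num_samples, n_locations, n_cell_types).

    Returns one locations x cell types DataFrame per summary.
    Hypothetical helper, not the cell2location implementation.
    """
    summaries = {
        "means": samples.mean(axis=0),
        "stds": samples.std(axis=0),
        "q05": np.quantile(samples, 0.05, axis=0),
        "q95": np.quantile(samples, 0.95, axis=0),
    }
    return {
        name: pd.DataFrame(values, index=obs_names, columns=factor_names)
        for name, values in summaries.items()
    }


rng = np.random.default_rng(0)
# Simulated posterior draws: 1000 samples, 3 locations, 2 cell types
draws = rng.gamma(shape=5.0, scale=1.0, size=(1000, 3, 2))
out = summarise_posterior(draws, ["spot1", "spot2", "spot3"], ["T_cell", "B_cell"])
print(out["q05"].shape)  # (3, 2)
```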

plot_spatial_QC_across_batches()

QC plot: compare total RNA count with estimated total cell abundance and detection sensitivity.
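The QC comparison described above relates two per-location totals. The sketch below computes them and their Pearson correlation as a numeric stand-in for the plot; the helper qc_totals and the toy arrays are hypothetical, and this is not what plot_spatial_QC_across_batches computes internally (detection sensitivity is not included here).

```python
import numpy as np


def qc_totals(counts: np.ndarray, abundance: np.ndarray) -> dict:
    """Per-location totals for a QC comparison.

    counts: (n_locations, n_genes) observed RNA counts
    abundance: (n_locations, n_cell_types) estimated cell abundance
    Hypothetical helper for illustration.
    """
    total_rna = counts.sum(axis=1)
    total_cells = abundance.sum(axis=1)
    pearson_r = np.corrcoef(total_rna, total_cells)[0, 1]
    return {"total_rna": total_rna, "total_cells": total_cells, "pearson_r": pearson_r}


counts = np.array([[1, 2], [2, 4], [3, 6]])
abundance = np.array([[1.0, 0.0], [1.0, 1.0], [2.0, 1.0]])
res = qc_totals(counts, abundance)
print(res["total_rna"], res["total_cells"], res["pearson_r"])  # totals 3, 6, 9 vs 1, 2, 3; r = 1.0
```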

Pyro and scvi-tools Module classes (inc math description)

Pyro Module class (defining the model using pyro, math description)

class cell2location.models._cell2location_module.LocationModelLinearDependentWMultiExperimentLocationBackgroundNormLevelGeneAlphaPyroModel(n_obs, n_vars, n_factors, n_batch, cell_state_mat, n_groups: int = 50, detection_mean=0.5, detection_alpha=20.0, m_g_gene_level_prior={'alpha_mean': 3.0, 'mean': 1, 'mean_var_ratio': 1.0}, N_cells_per_location=8.0, A_factors_per_location=7.0, B_groups_per_location=7.0, N_cells_mean_var_ratio=1.0, alpha_g_phi_hyp_prior={'alpha': 9.0, 'beta': 3.0}, gene_add_alpha_hyp_prior={'alpha': 9.0, 'beta': 3.0}, gene_add_mean_hyp_prior={'alpha': 1.0, 'beta': 100.0}, detection_hyp_prior={'mean_alpha': 10.0}, w_sf_mean_var_ratio=5.0, init_vals: Optional[dict, None] = None, init_alpha=20.0, dropout_p=0.0)

Bases: pyro.nn.module.PyroModule

Cell2location models the elements of the spatial count matrix \(D\) (locations \(s\) by genes \(g\)) as Negative Binomial distributed, given an unobserved gene expression level (rate) \(\mu\) and a gene- and batch-specific over-dispersion parameter \(\alpha_{e,g}\), which accounts for unexplained variance:

\[D_{s,g} \sim \mathtt{NB}(\mu_{s,g}, \alpha_{e,g})\]
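A Negative Binomial with mean \(\mu\) and over-dispersion \(\alpha\) can be simulated as a Gamma-Poisson mixture, which shows how \(\alpha_{e,g}\) inflates variance beyond Poisson. This sketch assumes the common mean/over-dispersion parameterisation (variance \(\mu + \mu^2/\alpha\)); how cell2location parameterises \(\alpha_{e,g}\) internally is not shown here, and the helper sample_nb is hypothetical.

```python
import numpy as np


def sample_nb(mu, alpha, size, rng):
    """Draw NB counts with mean mu and over-dispersion alpha via Gamma-Poisson:
    lam ~ Gamma(shape=alpha, rate=alpha / mu), D ~ Poisson(lam).
    Mean = mu, variance = mu + mu**2 / alpha (larger alpha -> closer to Poisson).
    """
    lam = rng.gamma(shape=alpha, scale=mu / alpha, size=size)  # numpy scale = 1 / rate
    return rng.poisson(lam)


rng = np.random.default_rng(1)
mu, alpha = 5.0, 2.0
d = sample_nb(mu, alpha, size=200_000, rng=rng)
print(d.mean(), d.var())  # close to 5.0 and 5.0 + 25 / 2 = 17.5
```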

The expression level of genes \(\mu_{s,g}\) in the mRNA count space is modelled as a linear function of expression signatures of reference cell types \(g_{f,g}\):

\[\mu_{s,g} = (m_{g} \left (\sum_{f} {w_{s,f} \: g_{f,g}} \right) + s_{e,g}) y_{s}\]

Here, \(w_{s,f}\) denotes the regression weight of each reference signature \(f\) at location \(s\), which can be interpreted as the expected number of cells at location \(s\) that express reference signature \(f\); \(g_{f,g}\) denotes the reference signature of cell type \(f\) for each gene \(g\) (the cell_state_df input); \(m_{g}\) denotes a gene-specific scaling parameter which adjusts for global differences in sensitivity between technologies (platform effect); \(y_{s}\) denotes a location/observation-specific scaling parameter which adjusts for differences in sensitivity between observations and batches; \(s_{e,g}\) is an additive component that accounts for a gene- and batch-specific shift, such as that due to contaminating or free-floating RNA.
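The linear model for \(\mu_{s,g}\) maps directly onto array broadcasting. The sketch below evaluates the formula term by term; the helper expected_expression, the array shapes and the toy values are assumptions chosen to mirror the symbols, not the module's internal code.

```python
import numpy as np


def expected_expression(w_sf, g_fg, m_g, s_eg, batch_index, y_s):
    """mu_{s,g} = (m_g * sum_f w_{s,f} g_{f,g} + s_{e,g}) * y_s

    w_sf: (S, F) cell abundance; g_fg: (F, G) reference signatures;
    m_g: (G,) platform effect; s_eg: (E, G) additive background per batch;
    batch_index: (S,) batch e of each location; y_s: (S,) detection efficiency.
    """
    cell_part = m_g[None, :] * (w_sf @ g_fg)  # (S, G): scaled sum over cell types
    background = s_eg[batch_index]            # (S, G): batch row picked per location
    return (cell_part + background) * y_s[:, None]


w_sf = np.array([[1.0, 0.0], [2.0, 3.0]])             # 2 locations x 2 cell types
g_fg = np.array([[1.0, 2.0, 3.0], [0.0, 1.0, 0.0]])   # 2 cell types x 3 genes
m_g = np.array([1.0, 1.0, 2.0])
s_eg = np.array([[0.1, 0.1, 0.1]])                    # a single batch
batch_index = np.array([0, 0])
y_s = np.array([1.0, 0.5])
mu = expected_expression(w_sf, g_fg, m_g, s_eg, batch_index, y_s)
# expected: approximately [[1.1, 2.1, 6.1], [1.05, 3.55, 6.05]]
```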

To account for the similarity of location patterns across cell types, \(w_{s,f}\) is modelled using another layer of decomposition (factorisation) using \(r \in \{1, \dots, R\}\) groups of cell types, which can be interpreted as cellular compartments or tissue zones. Unless stated otherwise, \(R\) is set to 50 (n_groups).
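One way to picture this extra layer is a prior whose mean for \(w_{s,f}\) is a low-rank product of location-by-group loadings and group-by-cell-type weights. The sketch below assumes exactly that, with a Gamma prior of mean \(\sum_r z_{s,r} x_{r,f}\) and variance mean / w_sf_mean_var_ratio; the names z_sr, x_rf and sample_w_sf are hypothetical, and the module's full hierarchy (including the priors governed by A_factors_per_location and B_groups_per_location) is described in the supplementary methods rather than reproduced here.

```python
import numpy as np


def sample_w_sf(z_sr, x_rf, w_sf_mean_var_ratio, rng):
    """Draw w_{s,f} around a low-rank prior mean (illustrative assumption).

    z_sr: (S, R) loading of each location on each group (tissue zone)
    x_rf: (R, F) contribution of each group to each cell type
    Gamma(shape = mean * k, rate = k) has mean = mean and variance = mean / k.
    """
    mean_sf = z_sr @ x_rf                     # (S, F) prior mean
    k = w_sf_mean_var_ratio
    w_sf = rng.gamma(shape=mean_sf * k, scale=1.0 / k)
    return w_sf, mean_sf


rng = np.random.default_rng(3)
z_sr = rng.gamma(2.0, 1.0, size=(4, 2))   # 4 locations, 2 groups
x_rf = rng.gamma(2.0, 1.0, size=(2, 3))   # 2 groups, 3 cell types
w, mean_sf = sample_w_sf(z_sr, x_rf, w_sf_mean_var_ratio=5.0, rng=rng)
print(w.shape)  # (4, 3)
```

Locations with similar group loadings get similar prior means across all cell types, which is how shared spatial patterns are encouraged.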

The corresponding graphical model can be found in the supplementary methods: https://www.biorxiv.org/content/10.1101/2020.11.15.378125v1.supplementary-material

Variational inference is used to approximate the posterior distribution of all model parameters.

Estimation of absolute cell abundance \(w_{s,f}\) is guided by an informed prior on the number of cells per location (argument N_cells_per_location). This is a tissue-level global estimate, which can be derived from histology images (H&E or DAPI), ideally paired with the spatial expression data or at least representing the same tissue type. It can be estimated by manually counting nuclei in 10-20 locations in the histology image (e.g. using the 10x Genomics Loupe Browser) and computing the average number of cells per location. An appropriate setting of this prior is essential for estimating absolute cell type abundance values; however, the model is robust to a range of similar values. Where suitable histology images are not available, the size of capture regions relative to the expected size of cells can be used to estimate N_cells_per_location.
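Both routes to N_cells_per_location described above are simple arithmetic. In the sketch below, the nuclei counts are made-up example values and the 11 µm cell diameter is a hypothetical assumption; 55 µm is the spot diameter of 10x Visium. The size ratio ignores cell packing and tissue thickness, so it is best read as a rough ceiling rather than an estimate.

```python
import statistics


def n_cells_from_counts(nuclei_counts):
    """Average number of nuclei counted manually in 10-20 locations."""
    return statistics.mean(nuclei_counts)


def n_cells_from_sizes(spot_diameter_um, cell_diameter_um):
    """Crude ceiling: spot area / cell cross-section area (pi cancels),
    ignoring packing and the third dimension. Inputs are user assumptions."""
    return (spot_diameter_um / cell_diameter_um) ** 2


counts = [7, 9, 8, 10, 6, 8, 9, 7, 8, 8]   # hypothetical counts from 10 spots
print(n_cells_from_counts(counts))          # 8
print(n_cells_from_sizes(55.0, 11.0))       # 25.0
```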

The prior on detection efficiency per location \(y_s\) is chosen to discourage over-normalisation: unless the data show evidence of a strong technical effect, the effect is assumed to be small and close to the mean sensitivity of each batch \(y_e\):

\[y_s \sim Gamma(detection\_alpha, detection\_alpha / y_e)\]

where \(y_e\) is the unknown/latent average detection efficiency of each batch/experiment:

\[y_e \sim Gamma(10, 10 / detection\_mean)\]
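The two Gamma priors above imply that \(y_e\) has mean detection_mean and that \(y_s\) has mean \(y_e\) with coefficient of variation \(1/\sqrt{\text{detection\_alpha}}\). The sketch below checks this by simulation (NumPy's gamma takes a scale, which is 1 / rate). With the defaults in the signatures, detection_alpha=20 gives roughly 22% per-location variation, while the simplified models' 200 gives roughly 7%, i.e. much less per-location normalisation. The helper names are hypothetical.

```python
import numpy as np


def sample_y_e(detection_mean, n_batches, rng):
    """y_e ~ Gamma(10, 10 / detection_mean): mean = detection_mean."""
    return rng.gamma(shape=10.0, scale=detection_mean / 10.0, size=n_batches)


def sample_y_s(y_e_of_location, detection_alpha, rng):
    """y_s ~ Gamma(detection_alpha, detection_alpha / y_e): mean = y_e,
    coefficient of variation = 1 / sqrt(detection_alpha)."""
    return rng.gamma(shape=detection_alpha, scale=y_e_of_location / detection_alpha)


def cv_of_y_s(detection_alpha, y_e=0.5, n=100_000, seed=0):
    """Empirical coefficient of variation of y_s around a fixed batch mean y_e."""
    rng = np.random.default_rng(seed)
    y_s = sample_y_s(np.full(n, y_e), detection_alpha, rng)
    return y_s.std() / y_s.mean()


y_e = sample_y_e(0.5, 100_000, np.random.default_rng(2))
print(y_e.mean())                                  # close to 0.5
print(cv_of_y_s(20.0), 1 / np.sqrt(20.0))          # both about 0.224
print(cv_of_y_s(200.0), 1 / np.sqrt(200.0))        # both about 0.071
```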

training_wo_observed = False
training_wo_initial = False
create_plates(x_data, idx, batch_index)
list_obs_plate_vars()

Create a dictionary with:

  1. “name” - the name of the observation/minibatch plate;

  2. “input” - indices of model args to provide to the encoder network when using amortised inference;

  3. “sites” - dictionary with

      • keys - names of variables that belong to the observation plate (used to recognise and merge posterior samples for minibatch variables)

      • values - the dimensions in the non-plate axis of each variable (used to construct the output layer of the encoder network when using amortised inference)
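A minimal sketch of a dictionary with this structure, and of how its fields could be consumed. The plate name "obs_plate" and the site names are hypothetical placeholders, not the module's actual values; "input": [0, 2] is read as positions in forward(x_data, idx, batch_index). Summing the site dimensions is one plausible reading of "used to construct the output layer"; a real encoder may output several parameters per site, making the width a multiple of this sum.

```python
def example_obs_plate_vars(n_factors: int, n_groups: int) -> dict:
    """Hypothetical return value of list_obs_plate_vars(); names are illustrative."""
    return {
        "name": "obs_plate",
        "input": [0, 2],  # positions in forward(x_data, idx, batch_index)
        "sites": {
            "w_sf": n_factors,  # one value per cell type at each location
            "z_sr": n_groups,   # hypothetical group loadings per location
            "y_s": 1,           # scalar detection efficiency per location
        },
    }


def encoder_inputs(plate_vars: dict, *model_args):
    """Select the model args listed under "input"."""
    return [model_args[i] for i in plate_vars["input"]]


def encoder_output_size(plate_vars: dict) -> int:
    """Sum of non-plate dimensions across observation-plate sites."""
    return sum(plate_vars["sites"].values())


pv = example_obs_plate_vars(n_factors=10, n_groups=50)
print(encoder_output_size(pv))                             # 61
print(encoder_inputs(pv, "x_data", "idx", "batch_index"))  # ['x_data', 'batch_index']
```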

forward(x_data, idx, batch_index)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

compute_expected(samples, adata_manager, ind_x=None, hide_ambient=False, hide_cell_type=False)

Compute the expected expression of each gene in each location. Useful for evaluating how well the model has learned the expression pattern of each gene in the data.

compute_expected_per_cell_type(samples, adata_manager, ind_x=None)

Compute expected expression of each gene in each location for each cell type.

Parameters
  • samples – Posterior distribution summary self.samples["post_sample_q05"] (or the corresponding ‘means’, ‘stds’ or ‘q95’ summary) produced by export_posterior().

  • ind_x – Location/observation indices for which to compute expected count (if None all locations are used).

Returns

dictionary with:

  1. list with expected expression counts (sparse, shape=(N locations, N genes)) for each cell type, in the same order as mod.factor_names_;

  2. np.array with location indices

Return type

dict
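Per-cell-type expected counts can be understood as splitting the linear model for \(\mu_{s,g}\) into one term per cell type. The sketch below assumes each cell type's share is \(m_g\, w_{s,f}\, g_{f,g}\, y_s\) and checks that these shares plus the scaled background \(s_{e,g}\, y_s\) add back up to \(\mu_{s,g}\). It returns dense arrays for simplicity, whereas the method returns sparse matrices; whether the library includes \(y_s\) in each share exactly this way is an assumption, and the helper name is hypothetical.

```python
import numpy as np


def expected_per_cell_type(w_sf, g_fg, m_g, y_s):
    """One dense (S, G) array per cell type f: m_g * w_{s,f} * g_{f,g} * y_s."""
    return [
        np.outer(w_sf[:, f] * y_s, g_fg[f] * m_g)  # (S,) outer (G,) -> (S, G)
        for f in range(w_sf.shape[1])
    ]


rng = np.random.default_rng(4)
S, F, G = 5, 2, 4
w_sf = rng.gamma(2.0, 1.0, size=(S, F))
g_fg = rng.gamma(1.0, 1.0, size=(F, G))
m_g = rng.gamma(3.0, 1.0 / 3.0, size=G)
s_eg = rng.gamma(1.0, 0.01, size=(1, G))     # a single batch
batch_index = np.zeros(S, dtype=int)
y_s = rng.gamma(20.0, 0.5 / 20.0, size=S)

per_type = expected_per_cell_type(w_sf, g_fg, m_g, y_s)
ambient = s_eg[batch_index] * y_s[:, None]   # background (ambient RNA) contribution
total = sum(per_type) + ambient
mu = (m_g * (w_sf @ g_fg) + s_eg[batch_index]) * y_s[:, None]  # full linear model
print(np.allclose(total, mu))  # True
```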

Simplified model architectures

No prior factorisation of w_sf (Pyro Module class, math description)

class cell2location.models.simplified._cell2location_v3_no_factorisation_module.LocationModelMultiExperimentLocationBackgroundNormLevelGeneAlphaPyroModel(n_obs, n_vars, n_factors, n_batch, cell_state_mat, n_groups: int = 50, detection_mean=0.5, detection_alpha=200.0, m_g_gene_level_prior={'alpha_mean': 3.0, 'mean': 1, 'mean_var_ratio': 1.0}, N_cells_per_location=8.0, A_factors_per_location=7.0, N_cells_mean_var_ratio=1.0, alpha_g_phi_hyp_prior={'alpha': 9.0, 'beta': 3.0}, gene_add_alpha_hyp_prior={'alpha': 9.0, 'beta': 3.0}, gene_add_mean_hyp_prior={'alpha': 1.0, 'beta': 100.0}, detection_hyp_prior={'mean_alpha': 10.0}, w_sf_mean_var_ratio=5.0)

Bases: pyro.nn.module.PyroModule

Cell2location models the elements of the spatial count matrix \(D\) (locations \(s\) by genes \(g\)) as Negative Binomial distributed, given an unobserved gene expression level (rate) \(\mu\) and a gene- and batch-specific over-dispersion parameter \(\alpha_{e,g}\), which accounts for unexplained variance:

\[D_{s,g} \sim \mathtt{NB}(\mu_{s,g}, \alpha_{e,g})\]

The expression level of genes \(\mu_{s,g}\) in the mRNA count space is modelled as a linear function of expression signatures of reference cell types \(g_{f,g}\):

\[\mu_{s,g} = (m_{g} \left (\sum_{f} {w_{s,f} \: g_{f,g}} \right) + s_{e,g}) y_{s}\]

Here, \(w_{s,f}\) denotes the regression weight of each reference signature \(f\) at location \(s\), which can be interpreted as the expected number of cells at location \(s\) that express reference signature \(f\); \(g_{f,g}\) denotes the reference signature of cell type \(f\) for each gene \(g\) (the cell_state_df input); \(m_{g}\) denotes a gene-specific scaling parameter which adjusts for global differences in sensitivity between technologies (platform effect); \(y_{s}\) denotes a location/observation-specific scaling parameter which adjusts for differences in sensitivity between observations and batches; \(s_{e,g}\) is an additive component that accounts for a gene- and batch-specific shift, such as that due to contaminating or free-floating RNA.

Unlike the full model, this simplified model does not place an additional layer of decomposition (factorisation into groups of cell types) on \(w_{s,f}\), so the similarity of location patterns across cell types is not modelled explicitly.

The corresponding graphical model can be found in the supplementary methods: https://www.biorxiv.org/content/10.1101/2020.11.15.378125v1.supplementary-material

Variational inference is used to approximate the posterior distribution of all model parameters.

Estimation of absolute cell abundance \(w_{s,f}\) is guided by an informed prior on the number of cells per location (argument N_cells_per_location). This is a tissue-level global estimate, which can be derived from histology images (H&E or DAPI), ideally paired with the spatial expression data or at least representing the same tissue type. It can be estimated by manually counting nuclei in 10-20 locations in the histology image (e.g. using the 10x Genomics Loupe Browser) and computing the average number of cells per location. An appropriate setting of this prior is essential for estimating absolute cell type abundance values; however, the model is robust to a range of similar values. Where suitable histology images are not available, the size of capture regions relative to the expected size of cells can be used to estimate N_cells_per_location.

The prior on detection efficiency per location \(y_s\) is chosen to discourage over-normalisation: unless the data show evidence of a strong technical effect, the effect is assumed to be small and close to the mean sensitivity of each batch \(y_e\):

\[y_s \sim Gamma(detection\_alpha, detection\_alpha / y_e)\]

where \(y_e\) is the unknown/latent average detection efficiency of each batch/experiment:

\[y_e \sim Gamma(10, 10 / detection\_mean)\]

create_plates(x_data, idx, batch_index)
list_obs_plate_vars()

Create a dictionary with:

  1. “name” - the name of the observation/minibatch plate;

  2. “input” - indices of model args to provide to the encoder network when using amortised inference;

  3. “sites” - dictionary with

      • keys - names of variables that belong to the observation plate (used to recognise and merge posterior samples for minibatch variables)

      • values - the dimensions in the non-plate axis of each variable (used to construct the output layer of the encoder network when using amortised inference)

forward(x_data, idx, batch_index)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

compute_expected(samples, adata_manager, ind_x=None)

Compute the expected expression of each gene in each location. Useful for evaluating how well the model has learned the expression pattern of each gene in the data.

No gene-specific platform effect m_g (Pyro Module class, math description)

class cell2location.models.simplified._cell2location_v3_no_mg_module.LocationModelLinearDependentWMultiExperimentLocationBackgroundNormLevelNoMGPyroModel(n_obs, n_vars, n_factors, n_batch, cell_state_mat, n_groups: int = 50, detection_mean=0.5, detection_alpha=200.0, N_cells_per_location=8.0, A_factors_per_location=7.0, Y_groups_per_location=7.0, N_cells_mean_var_ratio=1.0, alpha_g_phi_hyp_prior={'alpha': 9.0, 'beta': 3.0}, gene_add_alpha_hyp_prior={'alpha': 9.0, 'beta': 3.0}, gene_add_mean_hyp_prior={'alpha': 1.0, 'beta': 100.0}, detection_hyp_prior={'mean_alpha': 10.0}, w_sf_mean_var_ratio=5.0)

Bases: pyro.nn.module.PyroModule

Cell2location models the elements of the spatial count matrix \(D\) (locations \(s\) by genes \(g\)) as Negative Binomial distributed, given an unobserved gene expression level (rate) \(\mu\) and a gene- and batch-specific over-dispersion parameter \(\alpha_{e,g}\), which accounts for unexplained variance:

\[D_{s,g} \sim \mathtt{NB}(\mu_{s,g}, \alpha_{e,g})\]

The expression level of genes \(\mu_{s,g}\) in the mRNA count space is modelled as a linear function of expression signatures of reference cell types \(g_{f,g}\):

\[\mu_{s,g} = (\left (\sum_{f} {w_{s,f} \: g_{f,g}} \right) + s_{e,g}) y_{s}\]

Here, \(w_{s,f}\) denotes the regression weight of each reference signature \(f\) at location \(s\), which can be interpreted as the expected number of cells at location \(s\) that express reference signature \(f\); \(g_{f,g}\) denotes the reference signature of cell type \(f\) for each gene \(g\) (the cell_state_df input); \(y_{s}\) denotes a location/observation-specific scaling parameter which adjusts for differences in sensitivity between observations and batches; \(s_{e,g}\) is an additive component that accounts for a gene- and batch-specific shift, such as that due to contaminating or free-floating RNA.
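Dropping the platform effect \(m_g\) gives the same expected expression as the full model with \(m_g = 1\) for every gene, so this variant presumably suits cases where no systematic gene-level sensitivity difference between the reference and spatial technologies is expected. The sketch below checks that equivalence numerically; the helper names and simulated arrays are hypothetical.

```python
import numpy as np


def mu_with_mg(w_sf, g_fg, m_g, s_eg, batch_index, y_s):
    """Full model: (m_g * sum_f w_{s,f} g_{f,g} + s_{e,g}) * y_s"""
    return (m_g[None, :] * (w_sf @ g_fg) + s_eg[batch_index]) * y_s[:, None]


def mu_no_mg(w_sf, g_fg, s_eg, batch_index, y_s):
    """Simplified model: (sum_f w_{s,f} g_{f,g} + s_{e,g}) * y_s"""
    return (w_sf @ g_fg + s_eg[batch_index]) * y_s[:, None]


rng = np.random.default_rng(5)
S, F, G, E = 6, 3, 4, 2
w_sf = rng.gamma(2.0, 1.0, size=(S, F))
g_fg = rng.gamma(1.0, 1.0, size=(F, G))
s_eg = rng.gamma(1.0, 0.01, size=(E, G))
batch_index = np.array([0, 0, 0, 1, 1, 1])
y_s = rng.gamma(200.0, 0.5 / 200.0, size=S)

same = np.allclose(
    mu_no_mg(w_sf, g_fg, s_eg, batch_index, y_s),
    mu_with_mg(w_sf, g_fg, np.ones(G), s_eg, batch_index, y_s),
)
print(same)  # True
```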

To account for the similarity of location patterns across cell types, \(w_{s,f}\) is modelled using another layer of decomposition (factorisation) using \(r \in \{1, \dots, R\}\) groups of cell types, which can be interpreted as cellular compartments or tissue zones. Unless stated otherwise, \(R\) is set to 50 (n_groups).

The corresponding graphical model can be found in the supplementary methods: https://www.biorxiv.org/content/10.1101/2020.11.15.378125v1.supplementary-material

Variational inference is used to approximate the posterior distribution of all model parameters.

Estimation of absolute cell abundance \(w_{s,f}\) is guided by an informed prior on the number of cells per location (argument N_cells_per_location). This is a tissue-level global estimate, which can be derived from histology images (H&E or DAPI), ideally paired with the spatial expression data or at least representing the same tissue type. It can be estimated by manually counting nuclei in 10-20 locations in the histology image (e.g. using the 10x Genomics Loupe Browser) and computing the average number of cells per location. An appropriate setting of this prior is essential for estimating absolute cell type abundance values; however, the model is robust to a range of similar values. Where suitable histology images are not available, the size of capture regions relative to the expected size of cells can be used to estimate N_cells_per_location.

The prior on detection efficiency per location \(y_s\) is chosen to discourage over-normalisation: unless the data show evidence of a strong technical effect, the effect is assumed to be small and close to the mean sensitivity of each batch \(y_e\):

\[y_s \sim Gamma(detection\_alpha, detection\_alpha / y_e)\]

where \(y_e\) is the unknown/latent average detection efficiency of each batch/experiment:

\[y_e \sim Gamma(10, 10 / detection\_mean)\]

create_plates(x_data, idx, batch_index)
list_obs_plate_vars()

Create a dictionary with:

  1. “name” - the name of the observation/minibatch plate;

  2. “input” - indices of model args to provide to the encoder network when using amortised inference;

  3. “sites” - dictionary with

      • keys - names of variables that belong to the observation plate (used to recognise and merge posterior samples for minibatch variables)

      • values - the dimensions in the non-plate axis of each variable (used to construct the output layer of the encoder network when using amortised inference)

forward(x_data, idx, batch_index)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

compute_expected(samples, adata_manager, ind_x=None)

Compute the expected expression of each gene in each location. Useful for evaluating how well the model has learned the expression pattern of each gene in the data.