Reference signatures (NB regression)

User-facing reference signature estimation model class (scvi-tools BaseModelClass)

class cell2location.models.RegressionModel(adata: anndata._core.anndata.AnnData, model_class=None, use_average_as_initial: bool = True, **model_kwargs)

Bases: cell2location.models.base._pyro_mixin.QuantileMixin, scvi.model.base._pyromixin.PyroSampleMixin, scvi.model.base._pyromixin.PyroSviTrainMixin, cell2location.models.base._pyro_mixin.PltExportMixin, scvi.model.base._base_model.BaseModelClass

Model which estimates the per-cluster average mRNA count, accounting for batch effects. User-end model class.

https://github.com/BayraktarLab/cell2location

Parameters
  • adata – single-cell AnnData object that has been registered via setup_anndata().

  • use_gpu – Use the GPU?

  • **model_kwargs – Keyword args for the model class (RegressionBackgroundDetectionTechPyroModel by default)
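For orientation, the naive, batch-unaware estimate of the same quantity is simply the average raw count of each gene across the cells of each cluster. It is also the baseline that plot_QC compares the estimated signatures against. The sketch below computes it with pandas on a dense count matrix. The function name cluster_averages is hypothetical and not part of the cell2location API.

```python
import numpy as np
import pandas as pd


def cluster_averages(counts, labels, var_names):
    """Average raw count of each gene in each cluster, returned as genes x clusters.

    counts: dense (n_cells, n_genes) array of raw counts.
    labels: length-n_cells sequence of cluster / cell type labels.
    """
    df = pd.DataFrame(np.asarray(counts, dtype=float), columns=list(var_names))
    # Group cells by their label and average each gene within each group
    return df.groupby(np.asarray(labels)).mean().T


counts = np.array([[1, 0], [3, 2], [10, 10]])
avg = cluster_averages(counts, ["B", "B", "T"], ["gene1", "gene2"])
print(avg)  # gene1: B=2, T=10; gene2: B=1, T=10
```

RegressionModel differs from this plain average because it estimates the per-cluster signatures while accounting for batch and technology effects. The two are therefore expected to diverge when such effects are present.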


classmethod setup_anndata(adata: anndata._core.anndata.AnnData, layer: Optional[str] = None, batch_key: Optional[str] = None, labels_key: Optional[str] = None, categorical_covariate_keys: Optional[List[str]] = None, continuous_covariate_keys: Optional[List[str]] = None, **kwargs)

Sets up the AnnData object for this model.

A mapping is created between the data fields used by this model and their respective locations in adata. None of the existing data in adata are modified; the method only adds fields to adata.

Parameters
  • layer – if not None, uses this as the key in adata.layers for raw count data.

  • batch_key – key in adata.obs for batch information. Categories will automatically be converted into integer categories and saved to adata.obs[‘_scvi_batch’]. If None, assigns the same batch to all the data.

  • labels_key – key in adata.obs for label information. Categories will automatically be converted into integer categories and saved to adata.obs[‘_scvi_labels’]. If None, assigns the same label to all the data.

  • categorical_covariate_keys – keys in adata.obs that correspond to categorical data. These covariates can be added in addition to the batch covariate and are also treated as nuisance factors (i.e., the model tries to minimize their effects on the latent space). Thus, these should not be used for biologically-relevant factors that you do _not_ want to correct for.

  • continuous_covariate_keys – keys in adata.obs that correspond to continuous data. These covariates can be added in addition to the batch covariate and are also treated as nuisance factors (i.e., the model tries to minimize their effects on the latent space). Thus, these should not be used for biologically-relevant factors that you do _not_ want to correct for.
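The conversion of categories into integer codes described for batch_key and labels_key can be illustrated with plain pandas. This is only a sketch. It does not call scvi-tools, and the helper encode_categories is hypothetical. It also relies on pandas' default lexical ordering of categories, which is an assumption about how the integer codes are assigned.

```python
import numpy as np
import pandas as pd


def encode_categories(obs: pd.DataFrame, key=None) -> np.ndarray:
    """Map the categories in obs[key] to integer codes (hypothetical helper)."""
    if key is None:
        # No key given: every cell is assigned to the same category (code 0)
        return np.zeros(len(obs), dtype=np.int64)
    # pandas sorts the categories (lexically here) and numbers them from 0
    return obs[key].astype("category").cat.codes.to_numpy().astype(np.int64)


obs = pd.DataFrame({"sample": ["s2", "s1", "s2", "s3"]})
print(encode_categories(obs, "sample"))  # categories s1, s2, s3 -> codes 1, 0, 1, 2
print(encode_categories(obs, None))      # all zeros
```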

train(max_epochs: Optional[int] = None, batch_size: int = 2500, train_size: float = 1, lr: float = 0.002, **kwargs)

Train the model with useful defaults.

Parameters
  • max_epochs – Number of passes through the dataset. If None, defaults to np.min([round((20000 / n_cells) * 400), 400]), where n_cells is the number of cells in adata.

  • train_size – Size of training set in the range [0.0, 1.0].

  • batch_size – Minibatch size to use during training. If None, no minibatching occurs and all data is copied to device (e.g., GPU).

  • lr – Optimiser learning rate (default optimiser is ClippedAdam). Specifying optimiser via plan_kwargs overrides this choice of lr.

  • kwargs – Other arguments passed to the scvi.model.base.PyroSviTrainMixin.train() method
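The documented default for max_epochs can be checked with a short worked example that also counts minibatches. The helper names are hypothetical. minibatches_per_epoch assumes two things about scvi-tools' data loading rather than documented behaviour: that with train_size=1 every cell is used for training, and that the final, partial minibatch is kept.

```python
import math

import numpy as np


def default_max_epochs(n_cells: int) -> int:
    """Documented default: np.min([round((20000 / n_cells) * 400), 400])."""
    return int(np.min([round((20000 / n_cells) * 400), 400]))


def minibatches_per_epoch(n_cells: int, batch_size: int = 2500, train_size: float = 1.0) -> int:
    """Number of minibatches per epoch (hypothetical helper)."""
    n_train = int(round(n_cells * train_size))
    # Assumption: the last partial minibatch is kept (not dropped)
    return math.ceil(n_train / batch_size)


for n_cells in (5_000, 20_000, 50_000, 100_000):
    epochs = default_max_epochs(n_cells)
    steps = epochs * minibatches_per_epoch(n_cells)
    print(n_cells, epochs, steps)
```

With the default batch_size=2500, this gives 800 minibatch updates for 5,000 cells and 3,200 updates for 20,000, 50,000 or 100,000 cells. For datasets of 20,000 cells or more, (20000 / n_cells) * 400 epochs times n_cells / 2500 minibatches is exactly 3,200. So under the stated assumptions, the default keeps the total number of updates roughly constant as the dataset grows.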

export_posterior(adata, sample_kwargs: Optional[dict] = None, export_slot: str = 'mod', add_to_varm: list = ['means', 'stds', 'q05', 'q95'], scale_average_detection: bool = True, use_quantiles: bool = False)

Summarise the posterior distribution and export results (reference expression signatures) to the anndata object:

  1. adata.varm: estimated reference expression signatures (average mRNA count in each cell type), as pd.DataFrames for each posterior distribution summary in add_to_varm: posterior mean, sd, 5% and 95% quantiles ([‘means’, ‘stds’, ‘q05’, ‘q95’]). If export to adata.varm fails with an error, results are saved to adata.var instead.

  2. adata.uns: posterior of all parameters, model name, date, cell type names (‘factor_names’), obs and var names.

Parameters
  • adata – anndata object where results should be saved

  • sample_kwargs – arguments for self.sample_posterior (generating and summarising posterior samples), namely:

    num_samples – number of samples to use (default = 1000).
    batch_size – data batch size (keep low enough to fit on GPU; default 2048).
    use_gpu – whether to use the GPU for generating samples.

  • export_slot – adata.uns slot where to export results

  • add_to_varm – posterior distribution summary to export in adata.varm ([‘means’, ‘stds’, ‘q05’, ‘q95’]).

  • use_quantiles – compute quantiles directly (True, more memory efficient) or use samples (False, default). If True, means and stds cannot be computed, so they are neither exported nor returned.
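The four summaries listed in add_to_varm can be illustrated with NumPy on an array of posterior samples. This sketch rests on several assumptions. The samples are assumed to have shape (num_samples, n_factors, n_vars), and the helper summarise_posterior is hypothetical. The resulting DataFrames, with genes as rows and cell types as columns, only mimic a varm-style layout; the exact keys and column names written by export_posterior are not reproduced here.

```python
import numpy as np
import pandas as pd


def summarise_posterior(samples, var_names, factor_names):
    """Summarise posterior samples of shape (num_samples, n_factors, n_vars)."""
    summaries = {
        "means": samples.mean(axis=0),
        "stds": samples.std(axis=0),
        "q05": np.quantile(samples, 0.05, axis=0),
        "q95": np.quantile(samples, 0.95, axis=0),
    }
    # Transpose to genes x cell types so that rows align with the var names
    return {
        name: pd.DataFrame(values.T, index=list(var_names), columns=list(factor_names))
        for name, values in summaries.items()
    }


rng = np.random.default_rng(0)
samples = rng.gamma(shape=2.0, scale=1.0, size=(1000, 3, 5))  # simulated toy samples
tables = summarise_posterior(samples, [f"gene{i}" for i in range(5)], ["B", "T", "NK"])
print(tables["means"].round(2))
```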

plot_QC(summary_name: str = 'means', use_n_obs: int = 1000, scale_average_detection: bool = True)

Show quality control plots:

  1. Reconstruction accuracy, to assess whether there are any issues with model training. The plot should be roughly diagonal; strong deviations signal problems that need to be investigated. Plotting is slow because the expected value of the mRNA count needs to be computed from model parameters, so a random subset of observations is used to speed up the computation.

  2. Estimated reference expression signatures (accounting for batch effects) compared to the average expression in each cluster. We expect the signatures to differ from the averages when batch effects are present. However, when this plot is very different from a perfect diagonal (for example, very low values on the Y-axis or non-zero density everywhere), it indicates problems with signature estimation.

Parameters
  • summary_name – posterior distribution summary to use (‘means’, ‘stds’, ‘q05’, ‘q95’)
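The reconstruction-accuracy check in item 1 amounts to comparing observed counts with model-expected counts on a random subset of cells. The sketch below condenses that comparison into a single Pearson correlation on the log1p scale rather than drawing a plot. The function reconstruction_correlation is hypothetical, and the expected counts are assumed to be already computed (for example, from posterior means). The toy data are simulated, not model output.

```python
import numpy as np


def reconstruction_correlation(observed, expected, use_n_obs=1000, seed=0):
    """Pearson correlation between log1p(observed) and log1p(expected) counts.

    observed, expected: (n_obs, n_genes) arrays. Only a random subset of
    use_n_obs observations is compared, to keep the computation cheap.
    """
    rng = np.random.default_rng(seed)
    n_obs = observed.shape[0]
    idx = rng.choice(n_obs, size=min(use_n_obs, n_obs), replace=False)
    x = np.log1p(expected[idx].ravel())
    y = np.log1p(observed[idx].ravel())
    return float(np.corrcoef(x, y)[0, 1])


# Simulated toy data: Poisson noise around known expected counts
rng = np.random.default_rng(1)
expected = rng.gamma(shape=2.0, scale=5.0, size=(2000, 50))
observed = rng.poisson(expected)
print(reconstruction_correlation(observed, expected))  # high, but below 1
```

A value close to 1 loosely corresponds to the roughly diagonal plot described in item 1. A markedly lower value would point to the kind of deviation worth investigating.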

Pyro and scvi-tools Module classes

Pyro Module class (defining the model using pyro)

class cell2location.models.reference._reference_module.RegressionBackgroundDetectionTechPyroModel(n_obs, n_vars, n_factors, n_batch, n_extra_categoricals=None, alpha_g_phi_hyp_prior={'alpha': 9.0, 'beta': 3.0}, gene_add_alpha_hyp_prior={'alpha': 9.0, 'beta': 3.0}, gene_add_mean_hyp_prior={'alpha': 1.0, 'beta': 100.0}, detection_hyp_prior={'mean_alpha': 1.0, 'mean_beta': 1.0}, gene_tech_prior={'alpha': 200, 'mean': 1}, init_vals: Optional[dict] = None)

Bases: pyro.nn.module.PyroModule

Given a cell type annotation for each cell, the corresponding reference cell type signatures \(\mu_{f,g}\), which represent the average mRNA count of each gene g in each cell type f={1, .., F}, are estimated from sc/snRNA-seq data using Negative Binomial regression. This allows data to be combined robustly across technologies and batches.

This model combines batches and treats the data \(D\) as Negative Binomial distributed, given mean \(\mu\) and overdispersion \(\alpha\):

\[D_{c,g} \sim \mathtt{NB}(alpha=\alpha_{g}, mu=\mu_{c,g})\]
\[\mu_{c,g} = (\mu_{f,g} + s_{e,g}) \cdot y_e \cdot y_{t,g}\]

Which is equivalent to:

\[D_{c,g} \sim \mathtt{Poisson}(\mathtt{Gamma}(\alpha_{g}, \alpha_{g} / \mu_{c,g}))\]

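Here, \(\mu_{f,g}\) denotes the average mRNA count in each cell type \(f\) for each gene \(g\); \(s_{e,g}\) denotes an additive gene-specific background term for each experiment \(e\); \(y_e\) denotes normalisation for each experiment \(e\) to account for sequencing depth; and \(y_{t,g}\) denotes per-gene \(g\) detection efficiency normalisation for each technology \(t\). For each cell \(c\), \(f\), \(e\) and \(t\) are the cell type, experiment and technology that cell \(c\) belongs to.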
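The generative model above can be sketched in NumPy. nb_mean assembles \(\mu_{c,g}\) by looking up each cell's cell type \(f\), experiment \(e\) and technology \(t\). sample_gamma_poisson then draws counts through the Gamma-Poisson mixture. With shape \(\alpha_g\) and rate \(\alpha_g / \mu_{c,g}\), the mixture has mean \(\mu_{c,g}\) and variance \(\mu_{c,g} + \mu_{c,g}^2 / \alpha_g\), which the simulation checks empirically. Function and argument names are hypothetical. The sketch assumes a single technology covariate per cell and is not the module's PyTorch/Pyro implementation.

```python
import numpy as np


def nb_mean(mu_fg, s_eg, y_e, y_tg, label_idx, batch_idx, tech_idx):
    """mu_{c,g} = (mu_{f,g} + s_{e,g}) * y_e * y_{t,g} for every cell c.

    mu_fg: (F, G), s_eg: (E, G), y_e: (E,), y_tg: (T, G);
    label_idx, batch_idx, tech_idx: integer arrays of length C.
    """
    return (mu_fg[label_idx] + s_eg[batch_idx]) * y_e[batch_idx, None] * y_tg[tech_idx]


def sample_gamma_poisson(mu_cg, alpha_g, rng):
    """Draw D ~ Poisson(Gamma(shape=alpha_g, rate=alpha_g / mu_cg))."""
    # NumPy parameterises the Gamma by scale = 1 / rate = mu / alpha
    lam = rng.gamma(shape=alpha_g, scale=mu_cg / alpha_g)
    return rng.poisson(lam)


mu_fg = np.array([[1.0, 2.0], [3.0, 4.0]])  # 2 cell types x 2 genes
s_eg = np.array([[0.5, 0.5], [0.0, 1.0]])   # 2 experiments x 2 genes
y_e = np.array([1.0, 2.0])                   # per-experiment depth factor
y_tg = np.array([[1.0, 1.0], [2.0, 0.5]])   # 2 technologies x 2 genes
mu_cg = nb_mean(mu_fg, s_eg, y_e, y_tg,
                label_idx=np.array([0, 1]), batch_idx=np.array([1, 0]),
                tech_idx=np.array([0, 1]))
print(mu_cg)  # rows [2, 6] and [7, 2.25]

# Empirical check of the Gamma-Poisson moments for mu = 5, alpha = 2
rng = np.random.default_rng(0)
d = sample_gamma_poisson(np.full((200_000, 1), 5.0), np.array([2.0]), rng)
print(d.mean(), d.var())  # close to 5 and 5 + 5**2 / 2 = 17.5
```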

create_plates(x_data, idx, batch_index, label_index, extra_categoricals)

list_obs_plate_vars()

Create a dictionary with the name of the observation/minibatch plate, the indexes of model args to provide to the encoder, the variable names that belong to the observation plate, and the number of dimensions in the non-plate axis of each variable.

forward(x_data, idx, batch_index, label_index, extra_categoricals)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

compute_expected(samples, adata_manager, ind_x=None)

Compute the expected expression of each gene in each cell. Useful for evaluating how well the model learned the expression pattern of all genes in the data.

Parameters
  • samples – dictionary with values of the posterior

  • adata_manager – AnnDataManager of the registered anndata

  • ind_x – indices of cells to use (to reduce data size)

compute_expected_subset(samples, adata_manager, fact_ind, cell_ind)

Compute the expected expression of each gene in each cell that comes from a subset of factors (cell types) or cells.

Useful for evaluating how well the model learned the expression pattern of all genes in the data.

Parameters
  • samples – dictionary with values of the posterior

  • adata_manager – AnnDataManager of the registered anndata

  • fact_ind – indices of factors/cell types to use

  • cell_ind – indices of cells to use

normalise(samples, adata_manager, adata)

Normalise expression data by estimated technical variables.

Parameters
  • samples – dictionary with values of the posterior

  • adata_manager – AnnDataManager of the registered anndata

  • adata – registered anndata
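One way to picture normalisation by the estimated technical variables is to divide the observed counts by the multiplicative technical factors \(y_e\) and \(y_{t,g}\) from the mean formula. The sketch below does exactly that and nothing more. This is an assumption about what normalisation involves; for instance, the sketch leaves the additive term \(s_{e,g}\) untouched. normalise_counts and its arguments are hypothetical and do not reproduce the method's implementation.

```python
import numpy as np


def normalise_counts(counts, y_e, y_tg, batch_idx, tech_idx):
    """Divide each cell's counts by y_e (its experiment) and y_{t,g} (its technology)."""
    # Assumption: only the multiplicative technical factors are removed
    scale = y_e[batch_idx, None] * y_tg[tech_idx]  # (n_cells, n_genes)
    return counts / scale


counts = np.array([[2.0, 6.0], [7.0, 2.25]])
y_e = np.array([1.0, 2.0])
y_tg = np.array([[1.0, 1.0], [2.0, 0.5]])
norm = normalise_counts(counts, y_e, y_tg,
                        batch_idx=np.array([1, 0]), tech_idx=np.array([0, 1]))
print(norm)  # rows [1, 3] and [3.5, 4.5]
```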

scvi-tools Module class (initialising the model and the guide, PyroBaseModuleClass)
