Pyro and scvi-tools infrastructure classes

Base mixin classes (AutoGuide setup, posterior quantile computation, plotting & export)

cell2location.models.base._pyro_mixin.init_to_value(site=None, values={}, init_fn=<function init_to_mean>)[source]
class cell2location.models.base._pyro_mixin.AutoGuideMixinModule[source]

Bases: object

This mixin class provides methods for:

  • initialising standard AutoNormal guides

  • initialising amortised guides (AutoNormalEncoder)

  • initialising amortised guides with special additional inputs
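
Conceptually, the non-amortised path corresponds to wrapping the Pyro model in a standard AutoNormal guide initialised via an init function such as init_to_mean. A minimal sketch in plain Pyro (the model below is a stand-in, not a cell2location model; the amortised AutoNormalEncoder path is not shown):

    import pyro
    import pyro.distributions as dist
    from pyro.infer.autoguide import AutoNormal, init_to_mean

    # Stand-in Pyro model: one global scale and per-observation means.
    def model(x):
        sigma = pyro.sample("sigma", dist.Exponential(1.0))
        with pyro.plate("obs_plate", x.shape[0]):
            mu = pyro.sample("mu", dist.Normal(0.0, 1.0))
            pyro.sample("x", dist.Normal(mu, sigma), obs=x)

    # Standard (non-amortised) guide: one Normal factor per latent site,
    # with locations initialised at each site's mean (init_to_mean).
    guide = AutoNormal(model, init_loc_fn=init_to_mean)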

class cell2location.models.base._pyro_mixin.QuantileMixin[source]

Bases: object

This mixin class provides methods for:

  • computing median and quantiles of the posterior distribution using both direct and amortised inference

posterior_quantile(exclude_vars: list = None, batch_size: int = None, **kwargs)[source]

Compute quantiles (e.g. the median) of the posterior distribution of each parameter.

Parameters
  • q – quantile to compute

  • use_gpu – whether to use the GPU

  • use_median – when q=0.5, use the guide's median method rather than the quantile method
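
A hedged usage sketch, assuming mod is a trained model that includes this mixin (q, use_gpu and use_median are passed through **kwargs as documented above):

    # Posterior medians of all parameters (uses the guide's median method when q=0.5).
    medians = mod.posterior_quantile(q=0.5, use_median=True, batch_size=2048)

    # 5% and 95% posterior quantiles.
    q05 = mod.posterior_quantile(q=0.05, batch_size=2048)
    q95 = mod.posterior_quantile(q=0.95, batch_size=2048)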

class cell2location.models.base._pyro_mixin.PltExportMixin[source]

Bases: object

This mixin class provides methods for common plotting tasks and data export.

static plot_posterior_mu_vs_data(mu, data)[source]

Plot expected value of the model (e.g. mean of NB distribution) vs observed data

Parameters
  • mu – expected value

  • data – data value

plot_history(iter_start=0, iter_end=-1, ax=None)[source]

Plot training history.

Parameters
  • iter_start – omit initial iterations from the plot

  • iter_end – omit last iterations from the plot

  • ax – matplotlib axis
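
For example, assuming mod is a trained model exposing this mixin:

    import matplotlib.pyplot as plt

    # Omit the first 1000 iterations, where the ELBO is still far from convergence.
    fig, ax = plt.subplots()
    mod.plot_history(iter_start=1000, iter_end=-1, ax=ax)
    plt.show()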

sample2df_obs(samples: dict, site_name: str = 'w_sf', summary_name: str = 'means', name_prefix: str = 'cell_abundance', factor_names_key: str = '')[source]

Export posterior distribution summary for observation-specific parameters (e.g. spatial cell abundance) as Pandas data frame (means, 5%/95% quantiles or sd of posterior distribution).

Parameters
  • samples – dictionary with posterior mean, 5%/95% quantiles, SD, samples, generated by .sample_posterior()

  • site_name – name of the model parameter to be exported

  • summary_name – posterior distribution summary to return [‘means’, ‘stds’, ‘q05’, ‘q95’]

  • name_prefix – prefix to add to column names (f’{summary_name}{name_prefix}_{site_name}_{self.factor_names_}’)

Returns

Pandas data frame corresponding to either means, 5%/95% quantiles or sd of the posterior distribution

sample2df_vars(samples: dict, site_name: str = 'gene_factors', summary_name: str = 'means', name_prefix: str = '', factor_names_key: str = '')[source]

Export posterior distribution summary for variable-specific parameters as Pandas data frame (means, 5%/95% quantiles or sd of posterior distribution).

Parameters
  • samples – dictionary with posterior mean, 5%/95% quantiles, SD, samples, generated by .sample_posterior()

  • site_name – name of the model parameter to be exported

  • summary_name – posterior distribution summary to return (‘means’, ‘stds’, ‘q05’, ‘q95’)

  • name_prefix – prefix to add to column names (f’{summary_name}{name_prefix}_{site_name}_{self.factor_names_}’)

Returns

Pandas data frame corresponding to either means, 5%/95% quantiles or sd of the posterior distribution
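
A hedged sketch of exporting posterior summaries, assuming samples was produced by mod.sample_posterior() and the model has the default w_sf and gene_factors sites:

    # Per-observation summary, e.g. posterior mean cell abundance per location.
    obs_df = mod.sample2df_obs(
        samples,
        site_name="w_sf",
        summary_name="means",
        name_prefix="cell_abundance",
    )

    # Per-variable summary, e.g. 5% posterior quantile of gene-level factors.
    var_df = mod.sample2df_vars(
        samples,
        site_name="gene_factors",
        summary_name="q05",
    )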

plot_QC(summary_name: str = 'means', use_n_obs: int = 1000)[source]

Show quality control plots:

  1. Reconstruction accuracy, to assess whether there are any issues with model training. The plot should be roughly diagonal; strong deviations signal problems that need to be investigated. Plotting is slow because the expected mRNA count needs to be computed from model parameters; random observations are used to speed up computation.

Parameters

summary_name – posterior distribution summary to use (‘means’, ‘stds’, ‘q05’, ‘q95’)
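
For example, assuming the posterior has already been sampled (e.g. via export_posterior or sample_posterior):

    # Reconstruction-accuracy QC on 1000 random observations; points should lie close to the diagonal.
    mod.plot_QC(summary_name="means", use_n_obs=1000)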

class cell2location.models.base._pyro_mixin.PyroAggressiveConvergence(dataloader: scvi.dataloaders._ann_dataloader.AnnDataLoader = None, patience: int = 10, tolerance: float = 0.0001)[source]

Bases: pytorch_lightning.callbacks.callback.Callback

A callback to compute/apply aggressive training convergence criteria for amortised inference. Motivated by this paper: https://arxiv.org/pdf/1901.05534.pdf

on_train_epoch_end(trainer: pl.Trainer, pl_module: pl.LightningModule, unused: Optional = None) → None[source]

Compute aggressive training convergence criteria for amortised inference.
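
A hedged sketch of attaching the callback to a Lightning Trainer (cell2location normally wires this up internally, together with the aggressive training plan below; patience and tolerance correspond to the constructor arguments above):

    import pytorch_lightning as pl
    from cell2location.models.base._pyro_mixin import PyroAggressiveConvergence

    # Convergence criterion controlled by `patience` and `tolerance` (see signature above).
    convergence_cb = PyroAggressiveConvergence(patience=10, tolerance=1e-4)

    # Standard Lightning pattern: callbacks are passed to the Trainer.
    trainer = pl.Trainer(max_epochs=1000, callbacks=[convergence_cb])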

class cell2location.models.base._pyro_mixin.PyroTrainingPlan(pyro_module: scvi.module.base._base_module.PyroBaseModuleClass, loss_fn: Optional[pyro.infer.elbo.ELBO, None] = None, optim: Optional[pyro.optim.optim.PyroOptim, None] = None, optim_kwargs: Optional[dict, None] = None, n_steps_kl_warmup: Optional[int, None] = None, n_epochs_kl_warmup: Optional[int, None] = 400, scale_elbo: float = 1.0)[source]

Bases: scvi.train._trainingplans.PyroTrainingPlan

training_epoch_end(outputs)[source]

Training epoch end for Pyro training.

class cell2location.models.base._pyro_mixin.PyroAggressiveTrainingPlan1(pyro_module: scvi.module.base._base_module.PyroBaseModuleClass, loss_fn: Optional[pyro.infer.elbo.ELBO, None] = None, optim: Optional[pyro.optim.optim.PyroOptim, None] = None, optim_kwargs: Optional[dict, None] = None, n_aggressive_epochs: int = 1000, n_aggressive_steps: int = 20, n_steps_kl_warmup: Optional[int, None] = None, n_epochs_kl_warmup: Optional[int, None] = 400, aggressive_vars: Optional[list, None] = None, invert_aggressive_selection: bool = False)[source]

Bases: scvi.train._trainingplans.PyroTrainingPlan

Lightning module task to train Pyro scvi-tools modules.

Parameters
  • pyro_module – An instance of PyroBaseModuleClass. This object should have callable model and guide attributes or methods.

  • loss_fn – A Pyro loss. Should be a subclass of ELBO. If None, defaults to Trace_ELBO.

  • optim – A Pyro optimizer instance, e.g., Adam. If None, defaults to pyro.optim.Adam optimizer with a learning rate of 1e-3.

  • optim_kwargs – Keyword arguments for default optimiser pyro.optim.Adam.

  • n_aggressive_epochs – Number of epochs in aggressive optimisation of amortised variables.

  • n_aggressive_steps – Number of steps to spend optimising amortised variables before one step optimising global variables.

  • n_steps_kl_warmup – Number of training steps (minibatches) to scale weight on KL divergences from 0 to 1. Only activated when n_epochs_kl_warmup is set to None.

  • n_epochs_kl_warmup – Number of epochs to scale weight on KL divergences from 0 to 1. Overrides n_steps_kl_warmup when both are not None.

change_requires_grad(aggressive_vars_status, non_aggressive_vars_status)[source]
training_epoch_end(outputs)[source]

Training epoch end for Pyro training.

training_step(batch, batch_idx)[source]

Training step for Pyro training.
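
A hedged construction sketch, assuming module is a PyroBaseModuleClass instance (e.g. a Cell2locationBaseModule):

    from cell2location.models.base._pyro_mixin import PyroAggressiveTrainingPlan1

    training_plan = PyroAggressiveTrainingPlan1(
        pyro_module=module,
        n_aggressive_epochs=1000,  # epochs during which amortised variables receive extra updates
        n_aggressive_steps=20,     # amortised-variable steps per global-variable step
        n_epochs_kl_warmup=400,
    )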

class cell2location.models.base._pyro_mixin.PyroAggressiveTrainingPlan(scale_elbo: Optional[float, None] = 1.0, **kwargs)[source]

Bases: cell2location.models.base._pyro_mixin.PyroAggressiveTrainingPlan1

Lightning module task to train Pyro scvi-tools modules.

Parameters
  • pyro_module – An instance of PyroBaseModuleClass. This object should have callable model and guide attributes or methods.

  • loss_fn – A Pyro loss. Should be a subclass of ELBO. If None, defaults to Trace_ELBO.

  • optim – A Pyro optimizer instance, e.g., Adam. If None, defaults to pyro.optim.Adam optimizer with a learning rate of 1e-3.

  • optim_kwargs – Keyword arguments for default optimiser pyro.optim.Adam.

  • n_steps_kl_warmup – Number of training steps (minibatches) to scale weight on KL divergences from 0 to 1. Only activated when n_epochs_kl_warmup is set to None.

  • n_epochs_kl_warmup – Number of epochs to scale weight on KL divergences from 0 to 1. Overrides n_steps_kl_warmup when both are not None.
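
In practice these options are usually supplied via plan_kwargs when training the scvi-tools model; a hedged sketch, assuming the model's train() forwards plan_kwargs to this training plan (the standard scvi-tools pattern):

    mod.train(
        max_epochs=1000,
        plan_kwargs={
            "n_aggressive_epochs": 1000,
            "n_aggressive_steps": 20,
            "scale_elbo": 1.0,
        },
    )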

scvi-tools Module classes (initialising the model and the guide, PyroBaseModuleClass)

Cell2location spatial cell abundance estimation

class cell2location.models.base._pyro_base_loc_module.Cell2locationBaseModule(model, amortised: bool = False, encoder_mode: Literal['single', 'multiple', 'single-multiple'] = 'single', encoder_kwargs: Optional[dict, None] = None, data_transform='log1p', create_autoguide_kwargs: Optional[dict, None] = None, **kwargs)[source]

Bases: scvi.module.base._base_module.PyroBaseModuleClass, cell2location.models.base._pyro_mixin.AutoGuideMixinModule

Module class which defines an AutoGuide for a given model. Supports multiple model architectures.

Parameters
  • amortised – whether to use a neural network (encoder) to approximate the posterior distribution of location-specific (local) parameters.

  • encoder_mode – Use single encoder for all variables (“single”), one encoder per variable (“multiple”) or a single encoder in the first step and multiple encoders in the second step (“single-multiple”).

  • encoder_kwargs – arguments for Neural Network construction (scvi.nn.FCLayers)

  • kwargs – arguments for specific model class - e.g. number of genes, values of the prior distribution

property model
property guide
property list_obs_plate_vars

Create a dictionary with:

  1. “name” - the name of observation/minibatch plate;

  2. “input” - indexes of model args to provide to encoder network when using amortised inference;

  3. “sites” - dictionary with

  • keys - names of variables that belong to the observation plate (used to recognise and merge posterior samples for minibatch variables)

  • values - the dimensions in non-plate axis of each variable (used to construct output layer of encoder network when using amortised inference)
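
Schematically, the returned dictionary has the following structure (site names and sizes below are illustrative, not those of any particular model):

    obs_plate_vars = {
        "name": "obs_plate",     # name of the observation/minibatch plate
        "input": [0, 2],         # indices of model args fed to the encoder (amortised inference)
        "sites": {
            "w_sf": 5,           # per-observation variable -> size of its non-plate axis
            "detection_y_s": 1,  # (used to size the encoder output layer)
        },
    }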

property is_amortised
init_to_value(site)[source]

Reference signature estimation

class cell2location.models.reference._reference_model.RegressionModel(adata: anndata._core.anndata.AnnData, model_class=None, use_average_as_initial: bool = True, **model_kwargs)[source]

Bases: cell2location.models.base._pyro_mixin.QuantileMixin, scvi.model.base._pyromixin.PyroSampleMixin, scvi.model.base._pyromixin.PyroSviTrainMixin, cell2location.models.base._pyro_mixin.PltExportMixin, scvi.model.base._base_model.BaseModelClass

Model which estimates per-cluster average mRNA count, accounting for batch effects. User-end model class.

https://github.com/BayraktarLab/cell2location

Parameters
  • adata – single-cell AnnData object that has been registered via setup_anndata().

  • use_gpu – Use the GPU?

  • **model_kwargs – Keyword args for LocationModelLinearDependentWMultiExperimentModel

Examples

TODO add example >>>
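
A hedged usage sketch of the typical workflow (the adata.obs keys "sample" and "cell_type" are placeholders for your own column names):

>>> from cell2location.models import RegressionModel
>>> RegressionModel.setup_anndata(adata_ref, batch_key="sample", labels_key="cell_type")
>>> mod = RegressionModel(adata_ref)
>>> mod.train(max_epochs=250, batch_size=2500, train_size=1, lr=0.002)
>>> mod.plot_history(iter_start=20)
>>> adata_ref = mod.export_posterior(
...     adata_ref, sample_kwargs={"num_samples": 1000, "batch_size": 2500}
... )
>>> mod.plot_QC()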

classmethod setup_anndata(adata: anndata._core.anndata.AnnData, layer: Optional[str, None] = None, batch_key: Optional[str, None] = None, labels_key: Optional[str, None] = None, categorical_covariate_keys: Optional[List[str], None] = None, continuous_covariate_keys: Optional[List[str], None] = None, **kwargs)[source]

Sets up the AnnData object for this model.

A mapping will be created between data fields used by this model to their respective locations in adata.

None of the data in adata are modified. Only adds fields to adata.

Parameters
  • layer – if not None, uses this as the key in adata.layers for raw count data.

  • batch_key – key in adata.obs for batch information. Categories will automatically be converted into integer categories and saved to adata.obs[‘_scvi_batch’]. If None, assigns the same batch to all the data.

  • labels_key – key in adata.obs for label information. Categories will automatically be converted into integer categories and saved to adata.obs[‘_scvi_labels’]. If None, assigns the same label to all the data.

  • categorical_covariate_keys – keys in adata.obs that correspond to categorical data. These covariates can be added in addition to the batch covariate and are also treated as nuisance factors (i.e., the model tries to minimize their effects on the latent space). Thus, these should not be used for biologically-relevant factors that you do _not_ want to correct for.

  • continuous_covariate_keys – keys in adata.obs that correspond to continuous data. These covariates can be added in addition to the batch covariate and are also treated as nuisance factors (i.e., the model tries to minimize their effects on the latent space). Thus, these should not be used for biologically-relevant factors that you do _not_ want to correct for.

train(max_epochs: Optional[int, None] = None, batch_size: int = 2500, train_size: float = 1, lr: float = 0.002, **kwargs)[source]

Train the model with useful defaults

Parameters
  • max_epochs – Number of passes through the dataset. If None, defaults to np.min([round((20000 / n_cells) * 400), 400])

  • train_size – Size of training set in the range [0.0, 1.0].

  • batch_size – Minibatch size to use during training. If None, no minibatching occurs and all data is copied to device (e.g., GPU).

  • lr – Optimiser learning rate (default optimiser is ClippedAdam). Specifying optimiser via plan_kwargs overrides this choice of lr.

  • kwargs – Other arguments to scvi.model.base.PyroSviTrainMixin().train() method

export_posterior(adata, sample_kwargs: Optional[dict, None] = None, export_slot: str = 'mod', add_to_varm: list = ['means', 'stds', 'q05', 'q95'], scale_average_detection: bool = True, use_quantiles: bool = False)[source]

Summarise posterior distribution and export results (reference expression signatures) to the anndata object:

  1. adata.varm: Estimated reference expression signatures (average mRNA count in each cell type), as pd.DataFrames for each posterior distribution summary in add_to_varm: posterior mean, sd, 5% and 95% quantiles ([‘means’, ‘stds’, ‘q05’, ‘q95’]). If export to adata.varm fails with an error, results are saved to adata.var instead.

  2. adata.uns: Posterior of all parameters, model name, date, cell type names (‘factor_names’), obs and var names.

Parameters
  • adata – anndata object where results should be saved

  • sample_kwargs – arguments for self.sample_posterior (generating and summarising posterior samples), namely: num_samples – number of samples to use (default 1000); batch_size – data batch size (keep low enough to fit on GPU; default 2048); use_gpu – whether to use the GPU for generating samples.

  • export_slot – adata.uns slot where to export results

  • add_to_varm – posterior distribution summary to export in adata.varm ([‘means’, ‘stds’, ‘q05’, ‘q95’]).

  • use_quantiles – compute quantiles directly (True, more memory efficient) or use samples (False, default). If True, means and stds cannot be computed so are not exported and returned.

plot_QC(summary_name: str = 'means', use_n_obs: int = 1000, scale_average_detection: bool = True)[source]

Show quality control plots:

  1. Reconstruction accuracy, to assess whether there are any issues with model training. The plot should be roughly diagonal; strong deviations signal problems that need to be investigated. Plotting is slow because the expected mRNA count needs to be computed from model parameters; random observations are used to speed up computation.

  2. Estimated reference expression signatures (accounting for batch effects) compared to the average expression in each cluster. We expect the signatures to differ from the average when batch effects are present; however, if this plot deviates strongly from a perfect diagonal (e.g. very low values on the Y-axis, or non-zero density everywhere), it indicates problems with signature estimation.

Parameters

summary_name – posterior distribution summary to use (‘means’, ‘stds’, ‘q05’, ‘q95’)