Source code for cell2location.models._cell2location_model

from typing import List, Optional, Union

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy
from anndata import AnnData
from pyro import clear_param_store
from pyro.infer import Trace_ELBO, TraceEnum_ELBO
from pyro.nn import PyroModule
from scvi import REGISTRY_KEYS
from scvi.data import AnnDataManager
from scvi.data.fields import (
    CategoricalJointObsField,
    CategoricalObsField,
    LayerField,
    NumericalJointObsField,
    NumericalObsField,
)
from scvi.dataloaders import DataSplitter, DeviceBackedDataSplitter
from scvi.model.base import BaseModelClass, PyroSampleMixin, PyroSviTrainMixin
from scvi.model.base._pyromixin import PyroJitGuideWarmup
from scvi.train import TrainRunner
from scvi.utils import setup_anndata_dsp

from cell2location.models._cell2location_module import (
    LocationModelLinearDependentWMultiExperimentLocationBackgroundNormLevelGeneAlphaPyroModel,
)
from cell2location.models.base._pyro_base_loc_module import Cell2locationBaseModule
from cell2location.models.base._pyro_mixin import (
    PltExportMixin,
    PyroAggressiveConvergence,
    PyroAggressiveTrainingPlan,
    QuantileMixin,
)
from cell2location.utils import select_slide


class Cell2location(QuantileMixin, PyroSampleMixin, PyroSviTrainMixin, PltExportMixin, BaseModelClass):
    r"""
    Cell2location model. User-end model class.

    See Module class for description of the model (incl. math).

    Parameters
    ----------
    adata
        spatial AnnData object that has been registered via :func:`~scvi.data.setup_anndata`.
    cell_state_df
        pd.DataFrame with reference expression signatures for each gene (rows)
        in each cell type/population (columns).
    model_class
        Pyro model class; defaults to
        :class:`~cell2location.models.LocationModelLinearDependentWMultiExperimentLocationBackgroundNormLevelGeneAlphaPyroModel`.
    detection_mean_per_sample
        compute the expected detection sensitivity separately for each sample (batch)
        rather than averaged across all samples.
    detection_mean_correction
        multiplicative correction applied to the computed expected detection sensitivity.
    **model_kwargs
        Keyword args for
        :class:`~cell2location.models.LocationModelLinearDependentWMultiExperimentLocationBackgroundNormLevelGeneAlphaPyroModel`.

    Examples
    --------
    See the usage sketch after ``__init__`` below.
    """

    def __init__(
        self,
        adata: AnnData,
        cell_state_df: pd.DataFrame,
        model_class: Optional[PyroModule] = None,
        detection_mean_per_sample: bool = False,
        detection_mean_correction: float = 1.0,
        **model_kwargs,
    ):
        # in case any other model was created before that shares the same parameter names.
        clear_param_store()

        if not np.all(adata.var_names == cell_state_df.index):
            raise ValueError(
                "adata.var_names should match cell_state_df.index, find intersecting variables/genes first"
            )

        super().__init__(adata)
        self.mi_ = []

        if model_class is None:
            model_class = LocationModelLinearDependentWMultiExperimentLocationBackgroundNormLevelGeneAlphaPyroModel

        self.cell_state_df_ = cell_state_df
        self.n_factors_ = cell_state_df.shape[1]
        self.factor_names_ = cell_state_df.columns.values

        if not detection_mean_per_sample:
            # compute expected change in sensitivity (m_g in V1 or y_s in V2)
            sc_total = cell_state_df.sum(0).mean()
            sp_total = self.adata_manager.get_from_registry(REGISTRY_KEYS.X_KEY).sum(1).mean()
            self.detection_mean_ = (sp_total / model_kwargs.get("N_cells_per_location", 1)) / sc_total
            self.detection_mean_ = self.detection_mean_ * detection_mean_correction
            model_kwargs["detection_mean"] = self.detection_mean_
        else:
            # compute expected change in sensitivity (m_g in V1 and y_s in V2)
            sc_total = cell_state_df.sum(0).mean()
            sp_total = self.adata_manager.get_from_registry(REGISTRY_KEYS.X_KEY).sum(1)
            batch = self.adata_manager.get_from_registry(REGISTRY_KEYS.BATCH_KEY).flatten()
            sp_total = np.array([sp_total[batch == b].mean() for b in range(self.summary_stats["n_batch"])])
            self.detection_mean_ = (sp_total / model_kwargs.get("N_cells_per_location", 1)) / sc_total
            self.detection_mean_ = self.detection_mean_ * detection_mean_correction
            model_kwargs["detection_mean"] = self.detection_mean_.reshape((self.summary_stats["n_batch"], 1)).astype(
                "float32"
            )

        detection_alpha = model_kwargs.get("detection_alpha", None)
        if detection_alpha is not None:
            if type(detection_alpha) is dict:
                batch_mapping = self.adata_manager.get_state_registry(REGISTRY_KEYS.BATCH_KEY).categorical_mapping
                self.detection_alpha_ = pd.Series(detection_alpha)[batch_mapping]
                model_kwargs["detection_alpha"] = self.detection_alpha_.values.reshape(
                    (self.summary_stats["n_batch"], 1)
                ).astype("float32")

        self.module = Cell2locationBaseModule(
            model=model_class,
            n_obs=self.summary_stats["n_cells"],
            n_vars=self.summary_stats["n_vars"],
            n_factors=self.n_factors_,
            n_batch=self.summary_stats["n_batch"],
            cell_state_mat=self.cell_state_df_.values.astype("float32"),
            **model_kwargs,
        )
        self._model_summary_string = (
            f'cell2location model with the following params: \nn_factors: {self.n_factors_} '
            f'\nn_batch: {self.summary_stats["n_batch"]} '
        )
        self.init_params_ = self._get_init_params(locals())
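    # Usage sketch (illustrative, not part of the module): `adata_vis` is a
    # user-loaded spatial AnnData with raw counts and `inf_aver` a genes x cell
    # types pd.DataFrame of reference signatures; both names are hypothetical.
    #
    # >>> Cell2location.setup_anndata(adata=adata_vis, batch_key="sample")
    # >>> mod = Cell2location(
    # ...     adata_vis,
    # ...     cell_state_df=inf_aver,
    # ...     N_cells_per_location=30,
    # ...     detection_alpha=20,
    # ... )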
    @classmethod
    @setup_anndata_dsp.dedent
    def setup_anndata(
        cls,
        adata: AnnData,
        layer: Optional[str] = None,
        batch_key: Optional[str] = None,
        labels_key: Optional[str] = None,
        categorical_covariate_keys: Optional[List[str]] = None,
        continuous_covariate_keys: Optional[List[str]] = None,
        **kwargs,
    ):
        """
        %(summary)s.

        Parameters
        ----------
        %(param_layer)s
        %(param_batch_key)s
        %(param_labels_key)s
        %(param_cat_cov_keys)s
        %(param_cont_cov_keys)s
        """
        setup_method_args = cls._get_setup_method_args(**locals())

        adata.obs["_indices"] = np.arange(adata.n_obs).astype("int64")

        anndata_fields = [
            LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True),
            CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key),
            CategoricalObsField(REGISTRY_KEYS.LABELS_KEY, labels_key),
            CategoricalJointObsField(REGISTRY_KEYS.CAT_COVS_KEY, categorical_covariate_keys),
            NumericalJointObsField(REGISTRY_KEYS.CONT_COVS_KEY, continuous_covariate_keys),
            NumericalObsField(REGISTRY_KEYS.INDICES_KEY, "_indices"),
        ]
        adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args)
        adata_manager.register_fields(adata, **kwargs)
        cls.register_manager(adata_manager)
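    # For example, registering a raw-count layer and per-sample batch labels
    # (the layer and column names are illustrative):
    #
    # >>> Cell2location.setup_anndata(adata=adata_vis, layer="counts", batch_key="sample")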
    def train(
        self,
        max_epochs: int = 30000,
        batch_size: int = None,
        train_size: float = 1,
        lr: float = 0.002,
        num_particles: int = 1,
        scale_elbo: float = 1.0,
        **kwargs,
    ):
        """Train the model with useful defaults.

        Parameters
        ----------
        max_epochs
            Number of passes through the dataset. If `None`, defaults to
            `np.min([round((20000 / n_cells) * 400), 400])`.
        train_size
            Size of training set in the range [0.0, 1.0]. Use all data points in training
            because we need to estimate cell abundance at all locations.
        batch_size
            Minibatch size to use during training. If `None`, no minibatching occurs and all
            data is copied to device (e.g., GPU).
        lr
            Optimiser learning rate (default optimiser is :class:`~pyro.optim.ClippedAdam`).
            Specifying optimiser via `plan_kwargs` overrides this choice of lr.
        num_particles
            Number of Monte Carlo particles used by the ELBO loss
            (:class:`~pyro.infer.Trace_ELBO` or :class:`~pyro.infer.TraceEnum_ELBO`).
        scale_elbo
            Multiplicative scale applied to the ELBO loss. If `None`, the ELBO is scaled
            by `1 / (n_cells * n_vars)`.
        kwargs
            Other arguments to the `scvi.model.base.PyroSviTrainMixin().train()` method.
        """
        kwargs["max_epochs"] = max_epochs
        kwargs["batch_size"] = batch_size
        kwargs["train_size"] = train_size
        kwargs["lr"] = lr
        if "plan_kwargs" not in kwargs.keys():
            kwargs["plan_kwargs"] = dict()
        if getattr(self.module.model, "discrete_variables", None) and (len(self.module.model.discrete_variables) > 0):
            kwargs["plan_kwargs"]["loss_fn"] = TraceEnum_ELBO(num_particles=num_particles)
        else:
            kwargs["plan_kwargs"]["loss_fn"] = Trace_ELBO(num_particles=num_particles)
        if scale_elbo != 1.0:
            if scale_elbo is None:
                scale_elbo = 1.0 / (self.summary_stats["n_cells"] * self.summary_stats["n_vars"])
            kwargs["plan_kwargs"]["scale_elbo"] = scale_elbo

        super().train(**kwargs)
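    # For example, training with full-data batches so that cell abundance is
    # estimated at every location (a sketch; `mod` is the Cell2location instance
    # from the usage sketch above):
    #
    # >>> mod.train(max_epochs=30000, batch_size=None, train_size=1, lr=0.002)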
    def train_aggressive(
        self,
        max_epochs: Optional[int] = 1000,
        use_gpu: Optional[Union[str, int, bool]] = None,
        train_size: float = 1,
        validation_size: Optional[float] = None,
        batch_size: int = None,
        early_stopping: bool = False,
        lr: Optional[float] = None,
        plan_kwargs: Optional[dict] = None,
        **trainer_kwargs,
    ):
        """
        Train the model.

        Parameters
        ----------
        max_epochs
            Number of passes through the dataset. If `None`, defaults to
            `np.min([round((20000 / n_cells) * 400), 400])`.
        use_gpu
            Use default GPU if available (if None or True), or index of GPU to use (if int),
            or name of GPU (if str, e.g., `'cuda:0'`), or use CPU (if False).
        train_size
            Size of training set in the range [0.0, 1.0].
        validation_size
            Size of the test set. If `None`, defaults to 1 - `train_size`. If
            `train_size + validation_size < 1`, the remaining cells belong to a test set.
        batch_size
            Minibatch size to use during training. If `None`, no minibatching occurs and all
            data is copied to device (e.g., GPU).
        early_stopping
            Perform early stopping. Additional arguments can be passed in `**kwargs`.
            See :class:`~scvi.train.Trainer` for further options.
        lr
            Optimiser learning rate (default optimiser is :class:`~pyro.optim.ClippedAdam`).
            Specifying optimiser via `plan_kwargs` overrides this choice of lr.
        plan_kwargs
            Keyword args for :class:`~scvi.train.TrainingPlan`. Keyword arguments passed to
            `train()` will overwrite values present in `plan_kwargs`, when appropriate.
        **trainer_kwargs
            Other keyword args for :class:`~scvi.train.Trainer`.
        """
        if max_epochs is None:
            n_obs = self.adata_manager.adata.n_obs
            max_epochs = np.min([round((20000 / n_obs) * 1000), 1000])

        plan_kwargs = plan_kwargs if isinstance(plan_kwargs, dict) else dict()
        if lr is not None and "optim" not in plan_kwargs.keys():
            plan_kwargs.update({"optim_kwargs": {"lr": lr}})

        if batch_size is None:
            # use data splitter which moves data to GPU once
            data_splitter = DeviceBackedDataSplitter(
                self.adata_manager,
                train_size=train_size,
                validation_size=validation_size,
                batch_size=batch_size,
                use_gpu=use_gpu,
            )
        else:
            data_splitter = DataSplitter(
                self.adata_manager,
                train_size=train_size,
                validation_size=validation_size,
                batch_size=batch_size,
                use_gpu=use_gpu,
            )
        training_plan = PyroAggressiveTrainingPlan(pyro_module=self.module, **plan_kwargs)

        es = "early_stopping"
        trainer_kwargs[es] = early_stopping if es not in trainer_kwargs.keys() else trainer_kwargs[es]

        if "callbacks" not in trainer_kwargs.keys():
            trainer_kwargs["callbacks"] = []
        trainer_kwargs["callbacks"].append(PyroJitGuideWarmup())
        trainer_kwargs["callbacks"].append(PyroAggressiveConvergence())

        runner = TrainRunner(
            self,
            training_plan=training_plan,
            data_splitter=data_splitter,
            max_epochs=max_epochs,
            use_gpu=use_gpu,
            **trainer_kwargs,
        )
        res = runner()
        self.mi_ = self.mi_ + training_plan.mi
        return res
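    # A sketch of the aggressive training mode (same hypothetical `mod` as above);
    # minibatching is enabled here by passing a `batch_size`:
    #
    # >>> mod.train_aggressive(max_epochs=1000, batch_size=1024, lr=0.002)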
    def export_posterior(
        self,
        adata,
        sample_kwargs: Optional[dict] = None,
        export_slot: str = "mod",
        add_to_obsm: list = ["means", "stds", "q05", "q95"],
        use_quantiles: bool = False,
    ):
        """
        Summarise posterior distribution and export results (cell abundance) to anndata object:

        1. adata.obsm: Estimated cell abundance as pd.DataFrames for each posterior distribution
           summary `add_to_obsm`: posterior mean, sd, 5% and 95% quantiles
           (['means', 'stds', 'q05', 'q95']).
           If export to adata.obsm fails with an error, results are saved to adata.obs instead.
        2. adata.uns: Posterior of all parameters, model name, date,
           cell type names ('factor_names'), obs and var names.

        Parameters
        ----------
        adata
            anndata object where results should be saved
        sample_kwargs
            arguments for self.sample_posterior (generating and summarising posterior samples), namely:
            num_samples - number of samples to use (Default = 1000).
            batch_size - data batch size (keep low enough to fit on GPU, default 2048).
            use_gpu - use gpu for generating samples?
        export_slot
            adata.uns slot where to export results
        add_to_obsm
            posterior distribution summary to export in adata.obsm (['means', 'stds', 'q05', 'q95']).
        use_quantiles
            compute quantiles directly (True, more memory efficient) or use samples (False, default).
            If True, means and stds cannot be computed so are not exported and returned.

        Returns
        -------
        anndata object with posterior summaries added to `adata.obsm` and `adata.uns[export_slot]`.
        """
        sample_kwargs = sample_kwargs if isinstance(sample_kwargs, dict) else dict()

        # get posterior distribution summary
        if use_quantiles:
            add_to_obsm = [i for i in add_to_obsm if (i not in ["means", "stds"]) and ("q" in i)]
            if len(add_to_obsm) == 0:
                raise ValueError("No quantiles to export - please add add_to_obsm=['q05', 'q50', 'q95'].")
            self.samples = dict()
            for i in add_to_obsm:
                q = float(f"0.{i[1:]}")
                self.samples[f"post_sample_{i}"] = self.posterior_quantile(q=q, **sample_kwargs)
        else:
            # generate samples from posterior distributions for all parameters
            # and compute mean, 5%/95% quantiles and standard deviation
            self.samples = self.sample_posterior(**sample_kwargs)

        # export posterior distribution summary for all parameters and
        # annotation (model, date, var, obs and cell type names) to anndata object
        adata.uns[export_slot] = self._export2adata(self.samples)

        # add estimated cell abundance as dataframe to obsm in anndata
        # first convert np.arrays to pd.DataFrames with cell type and observation names
        # data frames contain mean, 5%/95% quantiles and standard deviation, denoted by a prefix
        for k in add_to_obsm:
            sample_df = self.sample2df_obs(
                self.samples,
                site_name="w_sf",
                summary_name=k,
                name_prefix="cell_abundance",
            )
            try:
                adata.obsm[f"{k}_cell_abundance_w_sf"] = sample_df.loc[adata.obs.index, :]
            except ValueError:
                # Catching weird error with obsm: `ValueError: value.index does not match parent's axis 1 names`
                adata.obs[sample_df.columns] = sample_df.loc[adata.obs.index, :]

        return adata
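    # For example, exporting memory-efficient quantile summaries instead of
    # full posterior samples (a sketch; `mod` and `adata_vis` as above):
    #
    # >>> adata_vis = mod.export_posterior(
    # ...     adata_vis,
    # ...     use_quantiles=True,
    # ...     add_to_obsm=["q05", "q50", "q95"],
    # ...     sample_kwargs={"batch_size": 2048},
    # ... )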
    def plot_spatial_QC_across_batches(self):
        """QC plot: compare total RNA count with estimated total cell abundance and detection sensitivity."""
        adata = self.adata

        # get batch key and the list of samples
        batch_key = self.adata_manager.get_state_registry(REGISTRY_KEYS.BATCH_KEY).original_key
        samples = adata.obs[batch_key].unique()

        # figure out plot shape
        ncol = len(samples)
        nrow = 3
        fig, axs = plt.subplots(nrow, ncol, figsize=(1 + 4 * ncol, 1 + 4 * nrow))
        if ncol == 1:
            axs = axs.reshape((nrow, 1))

        # compute total counts
        # find data slot
        x_dict = self.adata_manager.data_registry[REGISTRY_KEYS.X_KEY]
        if x_dict["attr_name"] == "X":
            use_raw = False
        else:
            use_raw = True
        if x_dict["attr_name"] == "layers":
            layer = x_dict["attr_key"]
        else:
            layer = None

        # get data
        if layer is not None:
            x = adata.layers[layer]
        else:
            if not use_raw:
                x = adata.X
            else:
                x = adata.raw.X

        # compute total counts per location
        cell_type = "total RNA counts"
        adata.obs[cell_type] = np.array(x.sum(1)).flatten()

        # figure out colour map scaling
        vmax = np.quantile(adata.obs[cell_type].values, 0.992)

        # plot, iterating across samples
        for i, s in enumerate(samples):
            sp_data_s = select_slide(adata, s, batch_key=batch_key)
            scanpy.pl.spatial(
                sp_data_s,
                cmap="magma",
                color=cell_type,
                size=1.3,
                img_key="hires",
                alpha_img=1,
                vmin=0,
                vmax=vmax,
                ax=axs[0, i],
                show=False,
            )
            axs[0, i].title.set_text(cell_type + "\n" + s)

        cell_type = "Total cell abundance (sum_f w_sf)"
        adata.obs[cell_type] = adata.uns["mod"]["post_sample_means"]["w_sf"].sum(1).flatten()

        # figure out colour map scaling
        vmax = np.quantile(adata.obs[cell_type].values, 0.992)

        # plot, iterating across samples
        for i, s in enumerate(samples):
            sp_data_s = select_slide(adata, s, batch_key=batch_key)
            scanpy.pl.spatial(
                sp_data_s,
                cmap="magma",
                color=cell_type,
                size=1.3,
                img_key="hires",
                alpha_img=1,
                vmin=0,
                vmax=vmax,
                ax=axs[1, i],
                show=False,
            )
            axs[1, i].title.set_text(cell_type + "\n" + s)

        cell_type = "RNA detection sensitivity (y_s)"
        adata.obs[cell_type] = adata.uns["mod"]["post_sample_q05"]["detection_y_s"]

        # figure out colour map scaling
        vmax = np.quantile(adata.obs[cell_type].values, 0.992)

        # plot, iterating across samples
        for i, s in enumerate(samples):
            sp_data_s = select_slide(adata, s, batch_key=batch_key)
            scanpy.pl.spatial(
                sp_data_s,
                cmap="magma",
                color=cell_type,
                size=1.3,
                img_key="hires",
                alpha_img=1,
                vmin=0,
                vmax=vmax,
                ax=axs[2, i],
                show=False,
            )
            axs[2, i].title.set_text(cell_type + "\n" + s)

        fig.tight_layout(pad=0.5)

        return fig
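    # For example (a sketch; assumes `mod.export_posterior(...)` was run first with
    # the default `export_slot="mod"` so that `adata.uns["mod"]` is populated):
    #
    # >>> fig = mod.plot_spatial_QC_across_batches()
    # >>> fig.savefig("spatial_QC_across_batches.png")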