# Source code for cell2location.models.reference._reference_module

from typing import Optional

import numpy as np
import pandas as pd
import pyro
import pyro.distributions as dist
import torch
from pyro.nn import PyroModule
from scvi import REGISTRY_KEYS
from scvi.nn import one_hot


class RegressionBackgroundDetectionTechPyroModel(PyroModule):
    r"""
    Given cell type annotation for each cell, the corresponding reference cell type signatures :math:`g_{f,g}`,
    which represent the average mRNA count of each gene `g` in each cell type `f={1, .., F}`,
    are estimated from sc/snRNA-seq data using Negative Binomial regression,
    which allows to robustly combine data across technologies and batches.

    This model combines batches, and treats data :math:`D` as Negative Binomial distributed,
    given mean :math:`\mu` and overdispersion :math:`\alpha`:

    .. math::
        D_{c,g} \sim \mathtt{NB}(alpha=\alpha_{g}, mu=\mu_{c,g})
    .. math::
        \mu_{c,g} = (\mu_{f,g} + s_{e,g}) * y_e * y_{t,g}

    Which is equivalent to:

    .. math::
        D_{c,g} \sim \mathtt{Poisson}(\mathtt{Gamma}(\alpha_{f,g}, \alpha_{f,g} / \mu_{c,g}))

    Here, :math:`\mu_{f,g}` denotes average mRNA count in each cell type :math:`f` for each gene :math:`g`;
    :math:`y_e` denotes normalisation for each experiment :math:`e` to account for sequencing depth;
    :math:`y_{t,g}` denotes per gene :math:`g` detection efficiency normalisation for each technology :math:`t`.
    """

    def __init__(
        self,
        n_obs,
        n_vars,
        n_factors,
        n_batch,
        n_extra_categoricals=None,
        alpha_g_phi_hyp_prior=None,
        gene_add_alpha_hyp_prior=None,
        gene_add_mean_hyp_prior=None,
        detection_hyp_prior=None,
        gene_tech_prior=None,
        init_vals: Optional[dict] = None,
    ):
        """
        Parameters
        ----------
        n_obs
            number of observations (cells)
        n_vars
            number of variables (genes)
        n_factors
            number of factors (cell types)
        n_batch
            number of batches (experiments)
        n_extra_categoricals
            numbers of categories for each extra categorical covariate,
            or None when there are no extra covariates
        alpha_g_phi_hyp_prior
            dict with "alpha"/"beta" of the Gamma hyperprior on per-gene overdispersion;
            defaults to {"alpha": 9.0, "beta": 3.0}
        gene_add_alpha_hyp_prior
            dict with "alpha"/"beta" of the Gamma hyperprior on the additive background term;
            defaults to {"alpha": 9.0, "beta": 3.0}
        gene_add_mean_hyp_prior
            dict with "alpha"/"beta" of the Gamma prior on the additive background mean;
            defaults to {"alpha": 1.0, "beta": 100.0}
        detection_hyp_prior
            dict with "mean_alpha"/"mean_beta" of the Gamma prior on per-batch detection efficiency;
            defaults to {"mean_alpha": 1.0, "mean_beta": 1.0}
        gene_tech_prior
            dict with "mean"/"alpha" of the Gamma prior on per-technology gene detection;
            defaults to {"mean": 1, "alpha": 200}
        init_vals
            optional dict of initial values, registered as ``init_val_{key}`` buffers
        """
        ############# Initialise parameters ################
        super().__init__()

        # Fill in default priors. NOTE: these were previously mutable dict
        # default arguments, which Python shares across calls — a later
        # mutation of e.g. model.gene_tech_prior would have leaked into
        # every subsequently constructed instance.
        if alpha_g_phi_hyp_prior is None:
            alpha_g_phi_hyp_prior = {"alpha": 9.0, "beta": 3.0}
        if gene_add_alpha_hyp_prior is None:
            gene_add_alpha_hyp_prior = {"alpha": 9.0, "beta": 3.0}
        if gene_add_mean_hyp_prior is None:
            gene_add_mean_hyp_prior = {"alpha": 1.0, "beta": 100.0}
        if detection_hyp_prior is None:
            detection_hyp_prior = {"mean_alpha": 1.0, "mean_beta": 1.0}
        if gene_tech_prior is None:
            gene_tech_prior = {"mean": 1, "alpha": 200}

        self.n_obs = n_obs
        self.n_vars = n_vars
        self.n_factors = n_factors
        self.n_batch = n_batch
        self.n_extra_categoricals = n_extra_categoricals

        self.alpha_g_phi_hyp_prior = alpha_g_phi_hyp_prior
        self.gene_add_alpha_hyp_prior = gene_add_alpha_hyp_prior
        self.gene_add_mean_hyp_prior = gene_add_mean_hyp_prior
        self.detection_hyp_prior = detection_hyp_prior
        self.gene_tech_prior = gene_tech_prior

        # Was `(init_vals is not None) & (type(init_vals) is dict)`:
        # bitwise `&` replaced with `and`, `type(...) is` with isinstance.
        if init_vals is not None and isinstance(init_vals, dict):
            self.np_init_vals = init_vals
            for k, v in init_vals.items():
                self.register_buffer(f"init_val_{k}", torch.tensor(v))

        # Register prior hyperparameters as buffers so they move with the
        # module across devices and are saved with the state dict.
        self.register_buffer(
            "detection_mean_hyp_prior_alpha",
            torch.tensor(self.detection_hyp_prior["mean_alpha"]),
        )
        self.register_buffer(
            "detection_mean_hyp_prior_beta",
            torch.tensor(self.detection_hyp_prior["mean_beta"]),
        )
        self.register_buffer(
            "gene_tech_prior_alpha",
            torch.tensor(self.gene_tech_prior["alpha"]),
        )
        # Gamma rate chosen so the prior mean equals gene_tech_prior["mean"].
        self.register_buffer(
            "gene_tech_prior_beta",
            torch.tensor(self.gene_tech_prior["alpha"] / self.gene_tech_prior["mean"]),
        )
        self.register_buffer(
            "alpha_g_phi_hyp_prior_alpha",
            torch.tensor(self.alpha_g_phi_hyp_prior["alpha"]),
        )
        self.register_buffer(
            "alpha_g_phi_hyp_prior_beta",
            torch.tensor(self.alpha_g_phi_hyp_prior["beta"]),
        )
        self.register_buffer(
            "gene_add_alpha_hyp_prior_alpha",
            torch.tensor(self.gene_add_alpha_hyp_prior["alpha"]),
        )
        self.register_buffer(
            "gene_add_alpha_hyp_prior_beta",
            torch.tensor(self.gene_add_alpha_hyp_prior["beta"]),
        )
        self.register_buffer(
            "gene_add_mean_hyp_prior_alpha",
            torch.tensor(self.gene_add_mean_hyp_prior["alpha"]),
        )
        self.register_buffer(
            "gene_add_mean_hyp_prior_beta",
            torch.tensor(self.gene_add_mean_hyp_prior["beta"]),
        )

        self.register_buffer("ones", torch.ones((1, 1)))
        self.register_buffer("eps", torch.tensor(1e-8))

    ############# Define the model ################
    @staticmethod
    def _get_fn_args_from_batch_no_cat(tensor_dict):
        # Unpack a minibatch without extra categorical covariates.
        # label_index is passed twice so the returned tuple has the same
        # arity as the categorical variant (forward ignores the last slot
        # when n_extra_categoricals is None).
        x_data = tensor_dict[REGISTRY_KEYS.X_KEY]
        ind_x = tensor_dict["ind_x"].long().squeeze()
        batch_index = tensor_dict[REGISTRY_KEYS.BATCH_KEY]
        label_index = tensor_dict[REGISTRY_KEYS.LABELS_KEY]
        return (x_data, ind_x, batch_index, label_index, label_index), {}

    @staticmethod
    def _get_fn_args_from_batch_cat(tensor_dict):
        # Unpack a minibatch that includes extra categorical covariates.
        x_data = tensor_dict[REGISTRY_KEYS.X_KEY]
        ind_x = tensor_dict["ind_x"].long().squeeze()
        batch_index = tensor_dict[REGISTRY_KEYS.BATCH_KEY]
        label_index = tensor_dict[REGISTRY_KEYS.LABELS_KEY]
        extra_categoricals = tensor_dict[REGISTRY_KEYS.CAT_COVS_KEY]
        return (x_data, ind_x, batch_index, label_index, extra_categoricals), {}

    @property
    def _get_fn_args_from_batch(self):
        """Select the minibatch-unpacking function matching whether extra covariates are registered."""
        if self.n_extra_categoricals is not None:
            return self._get_fn_args_from_batch_cat
        else:
            return self._get_fn_args_from_batch_no_cat
[docs] def create_plates(self, x_data, idx, batch_index, label_index, extra_categoricals): return pyro.plate("obs_plate", size=self.n_obs, dim=-2, subsample=idx)
[docs] def list_obs_plate_vars(self): """Create a dictionary with the name of observation/minibatch plate, indexes of model args to provide to encoder, variable names that belong to the observation plate and the number of dimensions in non-plate axis of each variable""" return { "name": "obs_plate", "input": [], # expression data + (optional) batch index "input_transform": [], # how to transform input data before passing to NN "sites": {}, }
def forward(self, x_data, idx, batch_index, label_index, extra_categoricals):
    """Pyro model: Negative Binomial regression of counts ``x_data``.

    Samples per-cell-type signatures, per-batch detection efficiency,
    per-batch additive background and per-gene overdispersion, then scores
    ``x_data`` under a GammaPoisson (Negative Binomial) likelihood inside
    the observation plate.

    Parameters
    ----------
    x_data
        count matrix for the minibatch (cells x genes)
    idx
        cell indices of the minibatch (used to subsample the obs plate)
    batch_index
        per-cell batch index (one-hot encoded against self.n_batch)
    label_index
        per-cell cell-type label index (one-hot encoded against self.n_factors)
    extra_categoricals
        per-cell extra categorical covariates; only used when
        self.n_extra_categoricals is not None
    """
    obs2sample = one_hot(batch_index, self.n_batch)
    obs2label = one_hot(label_index, self.n_factors)
    if self.n_extra_categoricals is not None:
        # One-hot encode each extra covariate separately, then concatenate
        # along the category axis.
        obs2extra_categoricals = torch.cat(
            [
                one_hot(
                    extra_categoricals[:, i].view((extra_categoricals.shape[0], 1)),
                    n_cat,
                )
                for i, n_cat in enumerate(self.n_extra_categoricals)
            ],
            dim=1,
        )
    obs_plate = self.create_plates(x_data, idx, batch_index, label_index, extra_categoricals)

    # =====================Per-cluster average mRNA count ======================= #
    # \mu_{f,g}
    per_cluster_mu_fg = pyro.sample(
        "per_cluster_mu_fg",
        dist.Gamma(self.ones, self.ones).expand([self.n_factors, self.n_vars]).to_event(2),
    )

    # =====================Gene-specific multiplicative component ======================= #
    # `y_{t, g}` per gene multiplicative effect that explains the difference
    # in sensitivity between genes in each technology or covariate effect
    if self.n_extra_categoricals is not None:
        detection_tech_gene_tg = pyro.sample(
            "detection_tech_gene_tg",
            dist.Gamma(
                self.ones * self.gene_tech_prior_alpha,
                self.ones * self.gene_tech_prior_beta,
            )
            .expand([np.sum(self.n_extra_categoricals), self.n_vars])
            .to_event(2),
        )

    # =====================Cell-specific detection efficiency ======================= #
    # y_c with hierarchical mean prior
    detection_mean_y_e = pyro.sample(
        "detection_mean_y_e",
        dist.Gamma(
            self.ones * self.detection_mean_hyp_prior_alpha,
            self.ones * self.detection_mean_hyp_prior_beta,
        )
        .expand([self.n_batch, 1])
        .to_event(2),
    )
    detection_y_c = obs2sample @ detection_mean_y_e  # (self.n_obs, 1)

    # =====================Gene-specific additive component ======================= #
    # s_{e,g} accounting for background, free-floating RNA
    s_g_gene_add_alpha_hyp = pyro.sample(
        "s_g_gene_add_alpha_hyp",
        dist.Gamma(self.ones * self.gene_add_alpha_hyp_prior_alpha, self.ones * self.gene_add_alpha_hyp_prior_beta),
    )
    s_g_gene_add_mean = pyro.sample(
        "s_g_gene_add_mean",
        dist.Gamma(
            self.gene_add_mean_hyp_prior_alpha,
            self.gene_add_mean_hyp_prior_beta,
        )
        .expand([self.n_batch, 1])
        .to_event(2),
    )  # (self.n_batch)
    s_g_gene_add_alpha_e_inv = pyro.sample(
        "s_g_gene_add_alpha_e_inv",
        dist.Exponential(s_g_gene_add_alpha_hyp).expand([self.n_batch, 1]).to_event(2),
    )  # (self.n_batch)
    # Square-then-invert reparameterisation of the Gamma shape.
    s_g_gene_add_alpha_e = self.ones / s_g_gene_add_alpha_e_inv.pow(2)

    s_g_gene_add = pyro.sample(
        "s_g_gene_add",
        dist.Gamma(s_g_gene_add_alpha_e, s_g_gene_add_alpha_e / s_g_gene_add_mean)
        .expand([self.n_batch, self.n_vars])
        .to_event(2),
    )  # (self.n_batch, n_vars)

    # =====================Gene-specific overdispersion ======================= #
    alpha_g_phi_hyp = pyro.sample(
        "alpha_g_phi_hyp",
        dist.Gamma(self.ones * self.alpha_g_phi_hyp_prior_alpha, self.ones * self.alpha_g_phi_hyp_prior_beta),
    )
    alpha_g_inverse = pyro.sample(
        "alpha_g_inverse",
        dist.Exponential(alpha_g_phi_hyp).expand([1, self.n_vars]).to_event(2),
    )  # (self.n_batch or 1, self.n_vars)

    # =====================Expected expression ======================= #
    # overdispersion
    alpha = self.ones / alpha_g_inverse.pow(2)
    # biological expression
    mu = (
        obs2label @ per_cluster_mu_fg  # contaminating RNA
        + obs2sample @ s_g_gene_add
    ) * detection_y_c  # cell-specific normalisation
    if self.n_extra_categoricals is not None:
        # gene-specific normalisation for covariates
        mu = mu * (obs2extra_categoricals @ detection_tech_gene_tg)
    # total_count, logits = _convert_mean_disp_to_counts_logits(
    #    mu, alpha, eps=self.eps
    # )

    # =====================DATA likelihood ======================= #
    # Likelihood (sampling distribution) of data_target & add overdispersion via NegativeBinomial
    with obs_plate:
        pyro.sample(
            "data_target",
            dist.GammaPoisson(concentration=alpha, rate=alpha / mu),
            # dist.NegativeBinomial(total_count=total_count, logits=logits),
            obs=x_data,
        )
# =====================Other functions======================= #
[docs] def compute_expected(self, samples, adata_manager, ind_x=None): r"""Compute expected expression of each gene in each cell. Useful for evaluating how well the model learned expression pattern of all genes in the data. Parameters ---------- samples dictionary with values of the posterior adata registered anndata ind_x indices of cells to use (to reduce data size) """ if ind_x is None: ind_x = np.arange(adata_manager.adata.n_obs).astype(int) else: ind_x = ind_x.astype(int) obs2sample = adata_manager.get_from_registry(REGISTRY_KEYS.BATCH_KEY) obs2sample = pd.get_dummies(obs2sample.flatten()).values[ind_x, :].astype("float32") obs2label = adata_manager.get_from_registry(REGISTRY_KEYS.LABELS_KEY) obs2label = pd.get_dummies(obs2label.flatten()).values[ind_x, :].astype("float32") if self.n_extra_categoricals is not None: extra_categoricals = adata_manager.get_from_registry(REGISTRY_KEYS.CAT_COVS_KEY) obs2extra_categoricals = np.concatenate( [ pd.get_dummies(extra_categoricals.iloc[ind_x, i]).astype("float32") for i, n_cat in enumerate(self.n_extra_categoricals) ], axis=1, ) alpha = 1 / np.power(samples["alpha_g_inverse"], 2) mu = (np.dot(obs2label, samples["per_cluster_mu_fg"]) + np.dot(obs2sample, samples["s_g_gene_add"])) * np.dot( obs2sample, samples["detection_mean_y_e"] ) # samples["detection_y_c"][ind_x, :] if self.n_extra_categoricals is not None: mu = mu * np.dot(obs2extra_categoricals, samples["detection_tech_gene_tg"]) return {"mu": mu, "alpha": alpha}
[docs] def compute_expected_subset(self, samples, adata_manager, fact_ind, cell_ind): r"""Compute expected expression of each gene in each cell that comes from a subset of factors (cell types) or cells. Useful for evaluating how well the model learned expression pattern of all genes in the data. Parameters ---------- samples dictionary with values of the posterior adata registered anndata fact_ind indices of factors/cell types to use cell_ind indices of cells to use """ obs2sample = adata_manager.get_from_registry(REGISTRY_KEYS.BATCH_KEY) obs2sample = pd.get_dummies(obs2sample.flatten()) obs2label = adata_manager.get_from_registry(REGISTRY_KEYS.LABELS_KEY) obs2label = pd.get_dummies(obs2label.flatten()) if self.n_extra_categoricals is not None: extra_categoricals = adata_manager.get_from_registry(REGISTRY_KEYS.CAT_COVS_KEY) obs2extra_categoricals = np.concatenate( [pd.get_dummies(extra_categoricals.iloc[:, i]) for i, n_cat in enumerate(self.n_extra_categoricals)], axis=1, ) alpha = 1 / np.power(samples["alpha_g_inverse"], 2) mu = ( np.dot(obs2label[cell_ind, fact_ind], samples["per_cluster_mu_fg"][fact_ind, :]) + np.dot(obs2sample[cell_ind, :], samples["s_g_gene_add"]) ) * np.dot( obs2sample, samples["detection_mean_y_e"] ) # samples["detection_y_c"] if self.n_extra_categoricals is not None: mu = mu * np.dot(obs2extra_categoricals[cell_ind, :], samples["detection_tech_gene_tg"]) return {"mu": mu, "alpha": alpha}
[docs] def normalise(self, samples, adata_manager, adata): r"""Normalise expression data by estimated technical variables. Parameters ---------- samples dictionary with values of the posterior adata registered anndata """ obs2sample = adata_manager.get_from_registry(REGISTRY_KEYS.BATCH_KEY) obs2sample = pd.get_dummies(obs2sample.flatten()) if self.n_extra_categoricals is not None: extra_categoricals = adata_manager.get_from_registry(REGISTRY_KEYS.CAT_COVS_KEY) obs2extra_categoricals = np.concatenate( [pd.get_dummies(extra_categoricals.iloc[:, i]) for i, n_cat in enumerate(self.n_extra_categoricals)], axis=1, ) # get counts matrix corrected = adata_manager.get_from_registry(REGISTRY_KEYS.X_KEY) # normalise per-sample scaling corrected = corrected / np.dot(obs2sample, samples["detection_mean_y_e"]) # normalise per gene effects if self.n_extra_categoricals is not None: corrected = corrected / np.dot(obs2extra_categoricals, samples["detection_tech_gene_tg"]) # remove additive sample effects corrected = corrected - np.dot(obs2sample, samples["s_g_gene_add"]) # set minimum value to 0 for each gene (a hack to avoid negative values) corrected = corrected - corrected.min() return corrected