Source code for cell2location.models.base._pyro_base_loc_module

from typing import Literal, Optional

from pyro.infer.autoguide import init_to_mean
from scvi.module.base import PyroBaseModuleClass

from ._pyro_mixin import AutoGuideMixinModule, init_to_value


class Cell2locationBaseModule(PyroBaseModuleClass, AutoGuideMixinModule):
    r"""
    Module that instantiates a model and builds the matching AutoGuide.

    Several guide architectures are supported; see the parameters below.

    Parameters
    ----------
    model
        Callable (typically a model class) that is invoked as ``model(**kwargs)``
        to create the model instance held by this module.
    amortised
        boolean, use a Neural Network to approximate posterior distribution
        of location-specific (local) parameters?
    encoder_mode
        Use single encoder for all variables ("single"), one encoder per variable ("multiple")
        or a single encoder in the first step and multiple encoders in the second step
        ("single-multiple").
    encoder_kwargs
        arguments for Neural Network construction (scvi.nn.FCLayers)
    data_transform
        Forwarded unchanged to ``_create_autoguide`` (default ``"log1p"``).
        NOTE(review): presumably the transformation applied to data before it
        enters the encoder network - confirm against ``AutoGuideMixinModule``.
    create_autoguide_kwargs
        Optional extra keyword arguments forwarded to ``_create_autoguide``.
        ``None`` is treated as "no extra arguments".
    kwargs
        arguments for specific model class - e.g. number of genes, values of the prior
        distribution. Must contain ``"n_batch"``, which is also passed to
        ``_create_autoguide`` as the single entry of ``n_cat_list``.
    """

    def __init__(
        self,
        model,
        amortised: bool = False,
        encoder_mode: Literal["single", "multiple", "single-multiple"] = "single",
        encoder_kwargs: Optional[dict] = None,
        data_transform="log1p",
        create_autoguide_kwargs: Optional[dict] = None,
        **kwargs,
    ):
        super().__init__()

        self.hist = []
        self._model = model(**kwargs)
        self._amortised = amortised

        # A missing "n_batch" surfaces here as a KeyError, after the model has been built.
        n_batch = kwargs["n_batch"]
        extra_guide_options = {} if create_autoguide_kwargs is None else create_autoguide_kwargs

        self._guide = self._create_autoguide(
            model=self.model,
            amortised=self.is_amortised,
            encoder_kwargs=encoder_kwargs,
            data_transform=data_transform,
            encoder_mode=encoder_mode,
            init_loc_fn=self.init_to_value,
            n_cat_list=[n_batch],
            **extra_guide_options,
        )

        # Re-export the model's batch-unpacking function on the module itself.
        self._get_fn_args_from_batch = self._model._get_fn_args_from_batch

    @property
    def model(self):
        """The model instance created in ``__init__``."""
        return self._model

    @property
    def guide(self):
        """The guide returned by ``_create_autoguide`` in ``__init__``."""
        return self._guide

    @property
    def list_obs_plate_vars(self):
        """
        Description of the observation/minibatch plate, as reported by the model.

        The value is whatever ``self.model.list_obs_plate_vars()`` returns; it is
        expected to be a dictionary with:

        1. "name" - the name of observation/minibatch plate;
        2. "input" - indexes of model args to provide to encoder network when using
           amortised inference;
        3. "sites" - dictionary with

           * keys - names of variables that belong to the observation plate
             (used to recognise and merge posterior samples for minibatch variables)
           * values - the dimensions in non-plate axis of each variable
             (used to construct output layer of encoder network when using amortised inference)
        """
        return self.model.list_obs_plate_vars()

    @property
    def is_amortised(self):
        """Whether amortised inference was requested at construction time."""
        return self._amortised

    def init_to_value(self, site):
        """
        Initialisation function handed to the guide as ``init_loc_fn``.

        If the model has a non-``None`` ``np_init_vals`` attribute, every key ``k``
        in it is mapped to the model attribute ``init_val_<k>``; otherwise no
        explicit initial values are supplied. The resulting mapping is passed to
        the imported ``init_to_value`` helper together with ``init_to_mean`` as
        ``init_fn``.

        Parameters
        ----------
        site
            The site to initialise, passed through to the ``init_to_value`` helper.

        Returns
        -------
        Whatever the ``init_to_value`` helper returns for this site.
        """
        declared_inits = getattr(self.model, "np_init_vals", None)
        if declared_inits is None:
            return init_to_value(site=site, values={}, init_fn=init_to_mean)

        # Read the values from the model's ``init_val_<name>`` attributes rather
        # than from the ``np_init_vals`` entries themselves.
        site_values = {
            name: getattr(self.model, f"init_val_{name}")
            for name in declared_inits.keys()
        }
        return init_to_value(site=site, values=site_values, init_fn=init_to_mean)