Source code for cell2location.distributions.AutoNormalEncoder

from contextlib import ExitStack  # python 3
from copy import deepcopy
from typing import Literal

import numpy as np
import pyro
import pyro.distributions as dist
import torch
from pyro.distributions.transforms import SoftplusTransform
from pyro.distributions.util import sum_rightmost
from pyro.infer.autoguide import AutoGuide
from pyro.infer.autoguide import AutoGuideList as PyroAutoGuideList
from pyro.infer.autoguide.guides import deep_getattr, deep_setattr
from pyro.infer.autoguide.utils import helpful_support_errors
from pyro.nn import PyroModule, PyroParam
from pyro.nn.module import to_pyro_module_
from scvi.nn import FCLayers
from torch.distributions import biject_to


class FCLayersPyro(FCLayers, PyroModule):
    pass


class AutoGuideList(PyroAutoGuideList):
    def quantiles(self, quantiles, *args, **kwargs):
        """
        Returns the posterior quantile values of each latent variable.

        Parameters
        ----------
        quantiles
            A list of requested quantiles between 0 and 1.

        Returns
        -------
        A dict mapping sample site name to quantiles tensor.
        """

        result = {}
        for part in self:
            result.update(part.quantiles(quantiles, *args, **kwargs))
        return result
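
# Hedged usage sketch (not part of the original module): ``AutoGuideList`` lets the amortised
# ``AutoNormalEncoder`` defined below handle the observation-plate sites while a standard
# ``AutoNormal`` handles the remaining sites; ``quantiles`` then merges the per-part results into
# a single dict. ``model``, ``amortised_plate_sites``, ``n_in`` and ``model_args`` are
# placeholders, not symbols defined in this module.
#
#     from pyro import poutine
#     from pyro.infer.autoguide import AutoNormal
#
#     amortised_sites = list(amortised_plate_sites["sites"].keys())
#     guide = AutoGuideList(model, create_plates=getattr(model, "create_plates", None))
#     guide.append(
#         AutoNormalEncoder(
#             poutine.block(model, expose=amortised_sites),
#             amortised_plate_sites=amortised_plate_sites,
#             n_in=n_in,
#         )
#     )
#     guide.append(AutoNormal(poutine.block(model, hide=amortised_sites)))
#     # after training with pyro.infer.SVI:
#     posterior_quantiles = guide.quantiles([0.05, 0.5, 0.95], *model_args)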


class AutoNormalEncoder(AutoGuide):
    """
    AutoNormal posterior approximation for amortised inference, where the mean and sd of the posterior
    distributions are approximated using a neural network: mean, sd = encoderNN(input data).

    The class supports a single encoder for all parameters as well as one encoder per parameter.
    The output of the encoder network is treated as a hidden layer: mean and sd are a linear function
    of the hidden layer nodes, and sd is transformed to the positive scale using softplus.
    Data is log-transformed on input.

    This class requires an `amortised_plate_sites` dictionary with details about the amortised variables
    (see below).

    The guide has the same call signature as the model, so any argument to the model can be used for
    encoding as annotated in `amortised_plate_sites`, but it does not have to be the same as the observed
    data in the model.
    """

    def __init__(
        self,
        model,
        amortised_plate_sites: dict,
        n_in: int,
        n_hidden: dict = None,
        init_param=0,
        init_param_scale: float = 1 / 50,
        scales_offset: float = -2,
        encoder_class=FCLayersPyro,
        encoder_kwargs=None,
        multi_encoder_kwargs=None,
        encoder_instance: torch.nn.Module = None,
        create_plates=None,
        encoder_mode: Literal["single", "multiple", "single-multiple"] = "single",
    ):
        """
        Parameters
        ----------
        model
            Pyro model
        amortised_plate_sites
            Dictionary with amortised plate details: the name of the observation/minibatch plate,
            indexes of model args to provide to the encoder, variable names that belong to the
            observation plate, and the number of dimensions in the non-plate axis of each variable,
            such as::

                {
                    "name": "obs_plate",
                    "input": [0],  # expression data + (optional) batch index ([0, 2])
                    "input_transform": [torch.log1p],  # how to transform input data before passing to NN
                    "sites": {
                        "n_s_cells_per_location": 1,
                        "y_s_groups_per_location": 1,
                        "z_sr_groups_factors": self.n_groups,
                        "w_sf": self.n_factors,
                        "l_s_add": 1,
                    },
                }

        n_in
            Number of input dimensions (for encoder_class).
        n_hidden
            Number of hidden nodes in each layer, one of 3 options:

            1. Integer denoting the number of hidden nodes
            2. Dictionary such as {"single": 200, "multiple": 200} denoting the number of hidden nodes
               for each `encoder_mode` (see below)
            3. Dictionary with the number of hidden nodes for single encoder mode and for each model site,
               allowing a different number of hidden nodes per site::

                {
                    "single": 200,
                    "n_s_cells_per_location": 5,
                    "y_s_groups_per_location": 5,
                    "z_sr_groups_factors": 128,
                    "w_sf": 128,
                    "l_s_add": 5,
                }

        init_param
            Not implemented yet - initial values for amortised variables.
        init_param_scale
            How to scale/normalise initial values for the weights converting hidden layers to mean and sd.
        encoder_class
            Class for defining the encoder network.
        encoder_kwargs
            Keyword arguments for encoder_class.
        multi_encoder_kwargs
            Optional separate keyword arguments for encoder_class, useful when encoder_mode == "single-multiple".
        encoder_instance
            Encoder network instance; overrides the class input, and the input instance is copied with deepcopy.
        create_plates
            Function for creating plates.
        encoder_mode
            Use a single encoder for all variables ("single"), one encoder per variable ("multiple"),
            or a single encoder in the first step and multiple encoders in the second step ("single-multiple").
""" super().__init__(model, create_plates=create_plates) self.amortised_plate_sites = amortised_plate_sites self.encoder_mode = encoder_mode self.scales_offset = scales_offset self.softplus = SoftplusTransform() if n_hidden is None: n_hidden = {"single": 200, "multiple": 200} else: if isinstance(n_hidden, int): n_hidden = {"single": n_hidden, "multiple": n_hidden} elif not isinstance(n_hidden, dict): raise ValueError("n_hidden must be either in or dict") encoder_kwargs = deepcopy(encoder_kwargs) if isinstance(encoder_kwargs, dict) else dict() encoder_kwargs["n_hidden"] = n_hidden["single"] self.encoder_kwargs = encoder_kwargs if multi_encoder_kwargs is None: multi_encoder_kwargs = deepcopy(encoder_kwargs) self.multi_encoder_kwargs = multi_encoder_kwargs if "multiple" in n_hidden.keys(): self.multi_encoder_kwargs["n_hidden"] = n_hidden["multiple"] self.single_n_in = n_in self.multiple_n_in = n_in self.n_out = ( np.sum([np.sum(amortised_plate_sites["sites"][k]) for k in amortised_plate_sites["sites"].keys()]) * 2 ) self.n_hidden = n_hidden self.encoder_class = encoder_class self.encoder_instance = encoder_instance if "single" in self.encoder_mode: # create a single encoder NN if encoder_instance is not None: self.one_encoder = deepcopy(encoder_instance) # convert to pyro module to_pyro_module_(self.one_encoder) else: self.one_encoder = encoder_class( n_in=self.single_n_in, n_out=self.n_hidden["single"], **self.encoder_kwargs ) if "multiple" in self.encoder_mode: self.multiple_n_in = self.n_hidden["single"] self.init_param_scale = init_param_scale def _setup_prototype(self, *args, **kwargs): super()._setup_prototype(*args, **kwargs) self._event_dims = {} self._cond_indep_stacks = {} self.hidden2locs = PyroModule() self.hidden2scales = PyroModule() if "multiple" in self.encoder_mode: # create module for collecting multiple encoder NN self.multiple_encoders = PyroModule() # Initialize guide params for name, site in self.prototype_trace.iter_stochastic_nodes(): # Collect unconstrained event_dims, which may differ from constrained event_dims. with helpful_support_errors(site): init_loc = biject_to(site["fn"].support).inv(site["value"].detach()).detach() event_dim = site["fn"].event_dim + init_loc.dim() - site["value"].dim() self._event_dims[name] = event_dim # Collect independence contexts. 
            self._cond_indep_stacks[name] = site["cond_indep_stack"]

            # determine the number of hidden nodes feeding the site-specific linear layers
            if "multiple" in self.encoder_mode:
                if "multiple" in self.n_hidden.keys():
                    n_hidden = self.n_hidden["multiple"]
                else:
                    n_hidden = self.n_hidden[name]
            elif "single" in self.encoder_mode:
                n_hidden = self.n_hidden["single"]
            # add linear layer for locs and scales
            param_dim = (n_hidden, self.amortised_plate_sites["sites"][name])
            init_param = np.random.normal(
                np.zeros(param_dim),
                (np.ones(param_dim) * self.init_param_scale) / np.sqrt(n_hidden),
            ).astype("float32")
            deep_setattr(
                self.hidden2locs,
                name,
                PyroParam(torch.tensor(init_param, device=site["value"].device, requires_grad=True)),
            )

            init_param = np.random.normal(
                np.zeros(param_dim),
                (np.ones(param_dim) * self.init_param_scale) / np.sqrt(n_hidden),
            ).astype("float32")
            deep_setattr(
                self.hidden2scales,
                name,
                PyroParam(torch.tensor(init_param, device=site["value"].device, requires_grad=True)),
            )

            if "multiple" in self.encoder_mode:
                # create multiple encoders
                if self.encoder_instance is not None:
                    # copy instances
                    encoder_ = deepcopy(self.encoder_instance).to(site["value"].device)
                    # convert to pyro module
                    to_pyro_module_(encoder_)
                    deep_setattr(
                        self.multiple_encoders,
                        name,
                        encoder_,
                    )
                else:
                    # create instances
                    deep_setattr(
                        self.multiple_encoders,
                        name,
                        self.encoder_class(n_in=self.multiple_n_in, n_out=n_hidden, **self.multi_encoder_kwargs).to(
                            site["value"].device
                        ),
                    )

    def _get_loc_and_scale(self, name, encoded_hidden):
        """
        Get the mean (loc) and sd (scale) of the posterior distribution as a linear function of the
        encoder hidden layer.

        Parameters
        ----------
        name
            variable name
        encoded_hidden
            tensor when `encoder_mode == "single"` and dictionary of tensors for each site when
            `encoder_mode == "multiple"`
        """
        linear_locs = deep_getattr(self.hidden2locs, name)
        linear_scales = deep_getattr(self.hidden2scales, name)

        if "multiple" in self.encoder_mode:
            # when using multiple encoders, extract the hidden layer for this parameter
            encoded_hidden = encoded_hidden[name]

        locs = encoded_hidden @ linear_locs
        scales = self.softplus((encoded_hidden @ linear_scales) - self.scales_offset)

        return locs, scales
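    # Hedged illustration (not part of the original source): shapes of the mapping above for one
    # site. With ``amortised_plate_sites["sites"]["w_sf"] = n_factors`` and a hidden layer of size
    # ``n_hidden``, the learned weights ``linear_locs``/``linear_scales`` have shape
    # (n_hidden, n_factors), so for a minibatch of n_obs observations:
    #
    #     encoded_hidden                       # (n_obs, n_hidden)
    #     locs = encoded_hidden @ linear_locs  # (n_obs, n_factors)
    #     scales = softplus(encoded_hidden @ linear_scales - scales_offset)  # (n_obs, n_factors), > 0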
    def encode(self, *args, **kwargs):
        """
        Apply the encoder network to the input data to obtain the hidden layer encoding.

        Parameters
        ----------
        args
            Pyro model args
        kwargs
            Pyro model kwargs
        """
        in_names = self.amortised_plate_sites["input"]
        x_in = [kwargs[i] if i in kwargs.keys() else args[i] for i in in_names]
        # apply data transform before passing to NN
        in_transforms = self.amortised_plate_sites["input_transform"]
        x_in = [in_transforms[i](x) for i, x in enumerate(x_in)]
        if "single" in self.encoder_mode:
            # encode with a single encoder
            res = self.one_encoder(*x_in)
            if "multiple" in self.encoder_mode:
                # when there is a second layer of multiple encoders, fetch encoders and encode the data
                x_in[0] = res
                res = {
                    name: deep_getattr(self.multiple_encoders, name)(*x_in)
                    for name, site in self.prototype_trace.iter_stochastic_nodes()
                }
        else:
            # when there are only multiple encoders, fetch encoders and encode the data
            res = {
                name: deep_getattr(self.multiple_encoders, name)(*x_in)
                for name, site in self.prototype_trace.iter_stochastic_nodes()
            }
        return res
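    # Hedged illustration (not part of the original source): how inputs reach the encoder(s).
    # With ``amortised_plate_sites["input"] = [0, 2]`` and
    # ``amortised_plate_sites["input_transform"] = [torch.log1p, lambda x: x]``,
    # a model call ``model(x_data, idx, batch_index)`` makes ``encode`` do roughly:
    #
    #     x_in = [args[0], args[2]]                 # pick model args by index
    #     x_in = [torch.log1p(x_in[0]), x_in[1]]    # per-input transforms
    #     hidden = self.one_encoder(*x_in)          # "single" mode: (n_obs, n_hidden["single"])
    #     # "single-multiple" mode: each per-site encoder then receives ``hidden`` in place of
    #     # the first input, producing one hidden tensor per site.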
    def forward(self, *args, **kwargs):
        """
        An automatic guide with the same ``*args, **kwargs`` as the base ``model``.

        .. note:: This method is used internally by :class:`~torch.nn.Module`.
            Users should instead use :meth:`~torch.nn.Module.__call__`.

        :return: A dict mapping sample site name to sampled value.
        :rtype: dict
        """
        # if we've never run the model before, do so now so we can inspect the model structure
        if self.prototype_trace is None:
            self._setup_prototype(*args, **kwargs)

        encoded_hidden = self.encode(*args, **kwargs)

        plates = self._create_plates(*args, **kwargs)
        result = {}
        for name, site in self.prototype_trace.iter_stochastic_nodes():
            transform = biject_to(site["fn"].support)

            with ExitStack() as stack:
                for frame in site["cond_indep_stack"]:
                    if frame.vectorized:
                        stack.enter_context(plates[frame.name])

                site_loc, site_scale = self._get_loc_and_scale(name, encoded_hidden)
                unconstrained_latent = pyro.sample(
                    name + "_unconstrained",
                    dist.Normal(
                        site_loc,
                        site_scale,
                    ).to_event(self._event_dims[name]),
                    infer={"is_auxiliary": True},
                )

                value = transform(unconstrained_latent)
                if pyro.poutine.get_mask() is False:
                    log_density = 0.0
                else:
                    log_density = transform.inv.log_abs_det_jacobian(
                        value,
                        unconstrained_latent,
                    )
                    log_density = sum_rightmost(
                        log_density,
                        log_density.dim() - value.dim() + site["fn"].event_dim,
                    )
                delta_dist = dist.Delta(
                    value,
                    log_density=log_density,
                    event_dim=site["fn"].event_dim,
                )

                result[name] = pyro.sample(name, delta_dist)

        return result
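    # Hedged illustration (not part of the original source): the reparameterisation used in
    # ``forward`` in isolation. A Normal sample drawn in unconstrained space is mapped to the
    # site's support with ``biject_to``, and the log-abs-det-Jacobian of the inverse transform
    # is attached to a Delta distribution so the constrained value is scored with the correct
    # density. For a hypothetical positive-valued site:
    #
    #     import torch
    #     from torch.distributions import biject_to, constraints
    #
    #     transform = biject_to(constraints.positive)   # bijection from R to (0, inf)
    #     unconstrained = torch.randn(5)
    #     value = transform(unconstrained)               # constrained sample, all > 0
    #     log_density = transform.inv.log_abs_det_jacobian(value, unconstrained)
    #     # dist.Delta(value, log_density=log_density) then scores the site at this value.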
    @torch.no_grad()
    def median(self, *args, **kwargs):
        """
        Returns the posterior median value of each latent variable.

        :return: A dict mapping sample site name to median tensor.
        :rtype: dict
        """
        encoded_latent = self.encode(*args, **kwargs)
        medians = {}
        for name, site in self.prototype_trace.iter_stochastic_nodes():
            site_loc, _ = self._get_loc_and_scale(name, encoded_latent)
            median = biject_to(site["fn"].support)(site_loc)
            if median is site_loc:
                median = median.clone()
            medians[name] = median

        return medians
    @torch.no_grad()
    def quantiles(self, quantiles, *args, **kwargs):
        """
        Returns posterior quantiles for each latent variable. Example::

            print(guide.quantiles([0.05, 0.5, 0.95]))

        :param quantiles: A list of requested quantiles between 0 and 1.
        :type quantiles: torch.Tensor or list
        :return: A dict mapping sample site name to a list of quantile values.
        :rtype: dict
        """
        encoded_latent = self.encode(*args, **kwargs)
        results = {}
        for name, site in self.prototype_trace.iter_stochastic_nodes():
            site_loc, site_scale = self._get_loc_and_scale(name, encoded_latent)

            site_quantiles = torch.tensor(quantiles, dtype=site_loc.dtype, device=site_loc.device)
            site_quantiles_values = dist.Normal(site_loc, site_scale).icdf(site_quantiles)
            constrained_site_quantiles = biject_to(site["fn"].support)(site_quantiles_values)
            results[name] = constrained_site_quantiles

        return results
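
# Hedged end-to-end sketch (not part of the original module), assuming a Pyro model whose first
# positional argument is an (n_obs, n_genes) expression tensor and which declares the sites below
# inside an "obs_plate"; ``model``, ``x_data``, ``n_genes``, ``n_factors`` and ``n_steps`` are
# placeholders for illustration only:
#
#     import pyro
#     import torch
#
#     amortised_plate_sites = {
#         "name": "obs_plate",
#         "input": [0],                      # encode the expression tensor (model arg 0)
#         "input_transform": [torch.log1p],  # log-transform counts before the NN
#         "sites": {"w_sf": n_factors, "l_s_add": 1},
#     }
#     guide = AutoNormalEncoder(
#         model,
#         amortised_plate_sites=amortised_plate_sites,
#         n_in=n_genes,
#         n_hidden={"single": 200, "multiple": 200},
#         encoder_mode="single-multiple",
#     )
#     svi = pyro.infer.SVI(model, guide, pyro.optim.Adam({"lr": 0.002}), loss=pyro.infer.Trace_ELBO())
#     for _ in range(n_steps):
#         svi.step(x_data)
#     # amortised posterior summaries for a minibatch (no retraining needed):
#     w_sf_median = guide.median(x_data)["w_sf"]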