Statistical distribution classes
AutoNormalEncoder class for automatic generation of amortised variational approximation
class cell2location.distributions.AutoNormalEncoder.AutoNormalEncoder(model, amortised_plate_sites: dict, n_in: int, n_hidden: dict = None, init_param=0, init_param_scale: float = 0.02, scales_offset: float = -2, encoder_class=<class 'cell2location.distributions.AutoNormalEncoder.FCLayersPyro'>, encoder_kwargs=None, multi_encoder_kwargs=None, encoder_instance: torch.nn.modules.module.Module = None, create_plates=None, encoder_mode: Literal['single', 'multiple', 'single-multiple'] = 'single')
Bases: pyro.infer.autoguide.guides.AutoGuide
AutoNormal posterior approximation for amortised inference, where the mean and sd of the posterior distributions are approximated using a neural network: mean, sd = encoderNN(input data).
The class supports a single encoder shared by all parameters as well as one encoder per parameter. The output of the encoder network is treated as a hidden layer; the mean and sd are linear functions of the hidden-layer nodes, and the sd is transformed to the positive scale using softplus. Data is log-transformed on input.
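The mapping described above (log-transformed input, a shared hidden layer, linear mean and sd heads, softplus on the sd) can be sketched in NumPy. This is a minimal illustration under stated assumptions, not the class's implementation: NumPy stands in for PyTorch, the encoder is a single ReLU hidden layer, log1p is assumed as the log transform so that zero counts are valid, and the class name TinyAmortisedEncoder and the site names site_a and site_b are hypothetical.

```python
import numpy as np


def softplus(x):
    # Numerically stable softplus, log(1 + exp(x)); output is always positive
    return np.logaddexp(0.0, x)


class TinyAmortisedEncoder:
    """Hypothetical sketch: one shared encoder, linear mean/sd heads per site."""

    def __init__(self, n_in, n_hidden, site_sizes, seed=0, init_scale=0.02):
        rng = np.random.default_rng(seed)
        # Shared encoder: a single hidden layer (an assumption for brevity)
        self.W = rng.normal(0.0, init_scale, size=(n_in, n_hidden))
        self.b = np.zeros(n_hidden)
        # One pair of linear heads per site: (mean weights, sd weights)
        self.heads = {
            name: (
                rng.normal(0.0, init_scale, size=(n_hidden, k)),
                rng.normal(0.0, init_scale, size=(n_hidden, k)),
            )
            for name, k in site_sizes.items()
        }

    def encode(self, x):
        # Log-transform the input (log1p assumed so zero counts are allowed),
        # then apply the shared layer with a ReLU to get the hidden encoding
        return np.maximum(np.log1p(x) @ self.W + self.b, 0.0)

    def __call__(self, x):
        h = self.encode(x)
        out = {}
        for name, (w_mean, w_sd) in self.heads.items():
            mean = h @ w_mean          # mean is linear in the hidden nodes
            sd = softplus(h @ w_sd)    # softplus maps sd to the positive scale
            out[name] = (mean, sd)
        return out


# Usage: 2 observations x 4 features; hypothetical site names and sizes
counts = np.array([[0.0, 3.0, 10.0, 1.0], [5.0, 0.0, 2.0, 7.0]])
enc = TinyAmortisedEncoder(n_in=4, n_hidden=8, site_sizes={"site_a": 3, "site_b": 1})
params = enc(counts)
```

The sketch covers only the shared-encoder case. The one-encoder-per-parameter option (presumably what encoder_mode='multiple' selects) would give each site its own encoder weights instead of sharing self.W.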
This class requires an amortised_plate_sites dictionary describing the amortised variables (see below).
The guide has the same call signature as the model, so any model argument can be used for encoding, as annotated in amortised_plate_sites; the encoder input does not have to be the observed data in the model.
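To illustrate that the encoder input is picked from the model's arguments rather than fixed to the observed data, the sketch below selects positional arguments and applies a per-argument transform. The dictionary keys input and input_transform and the helper select_encoder_input are hypothetical names chosen for illustration; they are not taken from the documented amortised_plate_sites schema.

```python
import math

# Hypothetical plate description: the key names are assumptions for this sketch
plate_spec = {
    "input": [0, 1],                               # positions of model args to encode
    "input_transform": [math.log1p, lambda v: v],  # one transform per selected arg
}


def select_encoder_input(spec, *args):
    """Hypothetical helper: pick model args by position, transform element-wise."""
    selected = []
    for pos, fn in zip(spec["input"], spec["input_transform"]):
        selected.append([fn(v) for v in args[pos]])
    return selected


# The guide is called exactly like the model, e.g. model(x_data, batch_index)
x_data = [0.0, 1.0, 3.0]    # observed counts
batch_index = [0, 0, 1]     # a non-observed argument can be encoded too
encoder_input = select_encoder_input(plate_spec, x_data, batch_index)
```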
encode(*args, **kwargs)
Apply the encoder network to input data to obtain the hidden-layer encoding.
Parameters
args – Pyro model args
kwargs – Pyro model kwargs
forward(*args, **kwargs)
An automatic guide with the same *args, **kwargs as the base model.
Note
This method is used internally by torch.nn.Module. Users should instead use __call__().
Returns
A dict mapping sample site name to sampled value.
Return type
dict
median(*args, **kwargs)
Returns the posterior median value of each latent variable.
Returns
A dict mapping sample site name to median tensor.
Return type
dict
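A median can be returned without sampling under one assumption: as in Pyro's AutoNormal, each latent variable is a Normal(loc, scale) in unconstrained space, mapped to its support by a monotone transform. A monotone transform preserves the median, so the posterior median is just the transform applied to loc. The standard-library sketch below checks this against a Monte Carlo estimate; exp as the transform and the loc and scale values are illustrative assumptions.

```python
import math
from statistics import NormalDist

# Assumed unconstrained posterior for one site: Normal(loc, scale)
loc, scale = 0.5, 0.8
to_positive = math.exp  # illustrative monotone transform to a positive support

# Closed form: the median of transform(Z) is transform(median of Z) = transform(loc)
median = to_positive(loc)

# Monte Carlo check: empirical median of transformed samples
base = NormalDist(loc, scale)
samples = sorted(to_positive(z) for z in base.samples(20001, seed=1))
mc_median = samples[len(samples) // 2]
```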
quantiles(quantiles, *args, **kwargs)
Returns posterior quantiles of each latent variable. Example:
print(guide.quantiles([0.05, 0.5, 0.95]))
Parameters
quantiles (torch.Tensor or list) – A list of requested quantiles between 0 and 1.
Returns
A dict mapping sample site name to a list of quantile values.
Return type
dict
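Assuming, as in Pyro's AutoNormal, that each latent variable is a Normal(loc, scale) in unconstrained space pushed through an increasing transform, its quantiles follow from inverting the Normal CDF and transforming the result. The helper site_quantiles and the exp transform below are hypothetical, chosen only to illustrate the calculation behind a call such as guide.quantiles([0.05, 0.5, 0.95]).

```python
import math
from statistics import NormalDist


def site_quantiles(loc, scale, quantiles, transform=math.exp):
    """Hypothetical helper: quantiles of transform(Z), Z ~ Normal(loc, scale)."""
    base = NormalDist(loc, scale)
    # Invert the Normal CDF in unconstrained space, then map each value to the
    # support; an increasing transform keeps the quantiles in order
    return [transform(base.inv_cdf(q)) for q in quantiles]


qs = site_quantiles(0.5, 0.8, [0.05, 0.5, 0.95])
```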