Statistical distribution classes

AutoNormalEncoder class for automatic generation of an amortised variational approximation

class cell2location.distributions.AutoNormalEncoder.AutoNormalEncoder(model, amortised_plate_sites: dict, n_in: int, n_hidden: dict = None, init_param=0, init_param_scale: float = 0.02, scales_offset: float = -2, encoder_class=<class 'cell2location.distributions.AutoNormalEncoder.FCLayersPyro'>, encoder_kwargs=None, multi_encoder_kwargs=None, encoder_instance: torch.nn.modules.module.Module = None, create_plates=None, encoder_mode: Literal['single', 'multiple', 'single-multiple'] = 'single')

Bases: pyro.infer.autoguide.guides.AutoGuide

AutoNormal posterior approximation for amortised inference, where the mean and sd of each posterior distribution are predicted from the input data by a neural network: mean, sd = encoderNN(input data).

The class supports a single encoder shared by all parameters as well as one encoder per parameter (selected via encoder_mode). The output of the encoder network is treated as a hidden layer: the mean and sd are linear functions of the hidden-layer nodes, and the sd is transformed to the positive scale using softplus. Data is log-transformed on input.
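The computation described above can be sketched in NumPy for a single shared encoder and one site. This is a minimal illustration under stated assumptions, not the library's implementation: the single ReLU hidden layer, the weight shapes and the helper names `softplus` and `encode_normal_params` are hypothetical, and the `scales_offset` and `init_param_scale` arguments of the real class are not modelled.

```python
import numpy as np


def softplus(x):
    # Numerically stable log(1 + exp(x)); maps any real number to a positive one.
    return np.logaddexp(0.0, x)


def encode_normal_params(x, params):
    """Hypothetical sketch: map raw data x (n_obs, n_in) to per-observation mean and sd."""
    h = np.log1p(x)  # data is log-transformed on input
    # Assumed encoder: one fully connected layer with ReLU; its output is the hidden layer.
    hidden = np.maximum(h @ params["w_enc"] + params["b_enc"], 0.0)
    mean = hidden @ params["w_mean"] + params["b_mean"]  # mean: linear in hidden nodes
    sd = softplus(hidden @ params["w_sd"] + params["b_sd"])  # sd: linear, then softplus
    return mean, sd


# Usage: 4 observations, 10 input features, 8 hidden nodes, a site with 3 dimensions.
rng = np.random.default_rng(0)
n_in, n_hidden, n_out = 10, 8, 3
params = {
    "w_enc": rng.normal(0.0, 0.1, (n_in, n_hidden)), "b_enc": np.zeros(n_hidden),
    "w_mean": rng.normal(0.0, 0.1, (n_hidden, n_out)), "b_mean": np.zeros(n_out),
    "w_sd": rng.normal(0.0, 0.1, (n_hidden, n_out)), "b_sd": np.zeros(n_out),
}
x = rng.poisson(5.0, size=(4, n_in)).astype(float)
mean, sd = encode_normal_params(x, params)
print(mean.shape, sd.shape)  # (4, 3) (4, 3)
```

With one encoder per parameter, each site would instead get its own encoder weights, while the linear mean and softplus sd heads stay the same.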

This class requires an amortised_plate_sites dictionary describing the amortised variables.

The guide has the same call signature as the model, so any model argument can be used as encoder input, as specified in amortised_plate_sites; the encoder input does not have to be the observed data of the model.

encode(*args, **kwargs)

Apply the encoder network to the input data to obtain the hidden-layer encoding.

Parameters
  • args – Pyro model args

  • kwargs – Pyro model kwargs

forward(*args, **kwargs)

An automatic guide with the same *args, **kwargs as the base model.

Note

This method is used internally by Module. Users should instead use __call__().

Returns

A dict mapping sample site name to sampled value.

Return type

dict
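What one draw involves for a single site can be sketched in NumPy, assuming the encoder has already produced mean and sd and mirroring how Pyro's AutoNormal samples: a normal draw in unconstrained space via reparameterisation, then a transform to the site's support. The `exp` transform (suitable for a positive site), the helper `sample_site` and the site name `hypothetical_site` are assumptions for illustration; Pyro chooses the transform from each site's support.

```python
import numpy as np


def sample_site(mean, sd, rng, transform=np.exp):
    # Reparameterisation: z = mean + sd * eps with eps ~ N(0, 1). In an autodiff
    # framework such as PyTorch this keeps z differentiable w.r.t. mean and sd,
    # so gradients can flow back into the encoder network.
    eps = rng.standard_normal(np.shape(mean))
    z_unconstrained = mean + sd * eps
    # Map from unconstrained space to the site's support (assumed positive here).
    return transform(z_unconstrained)


rng = np.random.default_rng(1)
mean = np.zeros((4, 3))
sd = np.full((4, 3), 0.5)
# Like forward(), return a dict mapping sample site name to sampled value.
samples = {"hypothetical_site": sample_site(mean, sd, rng)}
print(samples["hypothetical_site"].shape)  # (4, 3)
```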

median(*args, **kwargs)

Returns the posterior median value of each latent variable.

Returns

A dict mapping sample site name to median tensor.

Return type

dict
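For a normal posterior in unconstrained space, the median equals the mean regardless of sd, and a monotone increasing transform maps the median of z to the median of transform(z). So the median of a constrained site is the transformed mean. A NumPy sketch, assuming a positive site with an `exp` transform (a log-normal posterior); `site_median` is a hypothetical helper, and the Monte Carlo check uses an arbitrary sd of 0.7:

```python
import numpy as np


def site_median(mean, transform=np.exp):
    # The median of N(mean, sd) is mean itself; a monotone increasing transform
    # preserves the median, so the constrained median is transform(mean).
    return transform(mean)


mean = np.array([0.0, np.log(2.0)])
print(site_median(mean))

# Monte Carlo sanity check: sample the transformed normal and take the median.
rng = np.random.default_rng(2)
draws = np.exp(mean + 0.7 * rng.standard_normal((200_000, 2)))
print(np.median(draws, axis=0))
```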

quantiles(quantiles, *args, **kwargs)

Returns posterior quantiles of each latent variable. Example:

print(guide.quantiles([0.05, 0.5, 0.95]))
Parameters

quantiles (torch.Tensor or list) – A list of requested quantiles between 0 and 1.

Returns

A dict mapping sample site name to a list of quantile values.

Return type

dict
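Quantiles of a normal posterior can be computed in closed form: the q-quantile of N(mean, sd) is mean + sd * Φ⁻¹(q), and an increasing transform to the site's support preserves quantiles. A standard-library sketch for one scalar site, assuming an `exp` transform (positive site); `site_quantiles` is a hypothetical helper, whereas the real method works on tensors and uses each site's own transform.

```python
import math
from statistics import NormalDist


def site_quantiles(quantiles, mean, sd, transform=math.exp):
    """Hypothetical helper: quantiles of a scalar site with posterior N(mean, sd)
    in unconstrained space, mapped through a monotone increasing transform."""
    normal = NormalDist(mu=mean, sigma=sd)
    # inv_cdf(q) = mean + sd * Phi^{-1}(q); transforming each unconstrained
    # quantile gives the corresponding quantile of the constrained site.
    return [transform(normal.inv_cdf(q)) for q in quantiles]


qs = site_quantiles([0.05, 0.5, 0.95], mean=0.0, sd=1.0)
print([round(v, 3) for v in qs])
```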

Other module contents