# Source code for cell2location.models.base._pyro_mixin

import gc
import logging
from datetime import date
from functools import partial
from typing import Optional, Union

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pyro
import pytorch_lightning as pl
import torch
from pyro import poutine
from pyro.infer.autoguide import AutoNormal, init_to_feasible, init_to_mean
from pytorch_lightning.callbacks import Callback
from scipy.sparse import issparse
from scvi import REGISTRY_KEYS
from scvi.dataloaders import AnnDataLoader
from scvi.model._utils import parse_use_gpu_arg
from scvi.module.base import PyroBaseModuleClass
from scvi.train import PyroTrainingPlan as PyroTrainingPlan_scvi

from ...distributions.AutoAmortisedNormalMessenger import (
    AutoAmortisedHierarchicalNormalMessenger,
)

logger = logging.getLogger(__name__)


def init_to_value(site=None, values={}, init_fn=init_to_mean):
    if site is None:
        return partial(init_to_value, values=values)
    if site["name"] in values:
        return values[site["name"]]
    else:
        return init_fn(site)

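# Illustrative sketch (not part of the original module): `init_to_value` can be
# partially applied and passed to an autoguide as `init_loc_fn`, so that named
# sites start from user-provided values while all other sites fall back to
# `init_to_mean`. The site name "detection_y_c" below is a hypothetical placeholder.
def _example_init_to_value_usage(model):
    guide = AutoNormal(
        model,
        init_loc_fn=init_to_value(values={"detection_y_c": torch.ones((1, 1))}),
    )
    return guide
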
class AutoGuideMixinModule:
    """
    This mixin class provides methods for:

    - initialising standard AutoNormal guides
    - initialising amortised guides (AutoNormalEncoder)
    - initialising amortised guides with special additional inputs
    """

    def _create_autoguide(
        self,
        model,
        amortised,
        encoder_kwargs,
        data_transform,
        encoder_mode,
        init_loc_fn=init_to_mean(fallback=init_to_feasible),
        n_cat_list: list = [],
        encoder_instance=None,
        guide_class=AutoNormal,
        guide_kwargs: Optional[dict] = None,
    ):
        if guide_kwargs is None:
            guide_kwargs = dict()
        if not amortised:
            if getattr(model, "discrete_variables", None) is not None:
                model = poutine.block(model, hide=model.discrete_variables)
            if issubclass(guide_class, poutine.messenger.Messenger):
                # messenger guides don't need create_plates function
                _guide = guide_class(
                    model,
                    init_loc_fn=init_loc_fn,
                    **guide_kwargs,
                )
            else:
                _guide = guide_class(
                    model,
                    init_loc_fn=init_loc_fn,
                    **guide_kwargs,
                    create_plates=self.model.create_plates,
                )
        else:
            encoder_kwargs = encoder_kwargs if isinstance(encoder_kwargs, dict) else dict()
            n_hidden = encoder_kwargs["n_hidden"] if "n_hidden" in encoder_kwargs.keys() else 200
            if data_transform is None:
                pass
            elif isinstance(data_transform, np.ndarray):
                # add extra info about gene clusters as input to NN
                self.register_buffer("gene_clusters", torch.tensor(data_transform.astype("float32")))
                n_in = model.n_vars + data_transform.shape[1]
                data_transform = self._data_transform_clusters()
            elif data_transform == "log1p":
                # use simple log1p transform
                data_transform = torch.log1p
                n_in = self.model.n_vars
            elif (
                isinstance(data_transform, dict)
                and "var_std" in list(data_transform.keys())
                and "var_mean" in list(data_transform.keys())
            ):
                # use data transform by scaling
                n_in = model.n_vars
                self.register_buffer(
                    "var_mean",
                    torch.tensor(data_transform["var_mean"].astype("float32").reshape((1, n_in))),
                )
                self.register_buffer(
                    "var_std",
                    torch.tensor(data_transform["var_std"].astype("float32").reshape((1, n_in))),
                )
                data_transform = self._data_transform_scale()
            else:
                # use custom data transform
                data_transform = data_transform
                n_in = model.n_vars

            amortised_vars = model.list_obs_plate_vars()
            if len(amortised_vars["input"]) >= 2:
                encoder_kwargs["n_cat_list"] = n_cat_list
            if data_transform is not None:
                amortised_vars["input_transform"][0] = data_transform
            if "n_in" in amortised_vars.keys():
                n_in = amortised_vars["n_in"]
            if getattr(model, "discrete_variables", None) is not None:
                model = poutine.block(model, hide=model.discrete_variables)
            _guide = AutoAmortisedHierarchicalNormalMessenger(
                model,
                amortised_plate_sites=amortised_vars,
                n_in=n_in,
                n_hidden=n_hidden,
                encoder_kwargs=encoder_kwargs,
                encoder_mode=encoder_mode,
                encoder_instance=encoder_instance,
                init_loc_fn=init_loc_fn,
                **guide_kwargs,
            )
        return _guide

    def _data_transform_clusters(self):
        def _data_transform(x):
            return torch.log1p(torch.cat([x, x @ self.gene_clusters], dim=1))

        return _data_transform

    def _data_transform_scale(self):
        def _data_transform(x):
            # return (x - self.var_mean) / self.var_std
            return x / self.var_std

        return _data_transform

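# Illustrative sketch (not part of the original module): the amortised branch of
# `_create_autoguide` accepts a dict-valued `data_transform` with per-gene
# statistics under the keys "var_mean" and "var_std". A caller could assemble such
# a dict from an AnnData object roughly as below; `adata` is a hypothetical
# placeholder and the epsilon guard is an assumption, not part of the original code.
def _example_scaling_data_transform(adata):
    x = adata.X.toarray() if issparse(adata.X) else np.asarray(adata.X)
    return {
        "var_mean": x.mean(axis=0).reshape(1, -1).astype("float32"),
        "var_std": (x.std(axis=0).reshape(1, -1) + 1e-8).astype("float32"),
    }
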
class QuantileMixin:
    """
    This mixin class provides methods for:

    - computing median and quantiles of the posterior distribution using both direct and amortised inference
    """

    def _optim_param(
        self,
        lr: float = 0.01,
        autoencoding_lr: float = None,
        clip_norm: float = 200,
        module_names: list = ["encoder", "hidden2locs", "hidden2scales"],
    ):
        # TODO implement custom training method that can use this function.
        # create function which fetches different lr for autoencoding guide
        def optim_param(module_name, param_name):
            # detect variables in autoencoding guide
            if autoencoding_lr is not None and np.any([n in module_name + "." + param_name for n in module_names]):
                return {
                    "lr": autoencoding_lr,
                    # limit the gradient step from becoming too large
                    "clip_norm": clip_norm,
                }
            else:
                return {
                    "lr": lr,
                    # limit the gradient step from becoming too large
                    "clip_norm": clip_norm,
                }

        return optim_param

    @torch.no_grad()
    def _posterior_quantile_minibatch(
        self,
        q: float = 0.5,
        batch_size: int = 2048,
        use_gpu: bool = None,
        use_median: bool = True,
        exclude_vars: list = None,
        data_loader_indices=None,
    ):
        """
        Compute median (or quantile ``q``) of the posterior distribution of each parameter,
        separating local (minibatch) variables and global variables, which is necessary
        when performing amortised inference.

        Note for developers: requires a model class method that lists observation/minibatch
        plate variables (self.module.model.list_obs_plate_vars()).

        Parameters
        ----------
        q
            quantile to compute
        batch_size
            number of observations per batch
        use_gpu
            Bool, use gpu?
        use_median
            Bool, when q=0.5 use median rather than quantile method of the guide

        Returns
        -------
        dictionary {variable_name: posterior quantile}

        """
        # guard against the default None so membership tests below do not fail
        if exclude_vars is None:
            exclude_vars = []

        _, _, device = parse_use_gpu_arg(use_gpu)

        self.module.eval()

        train_dl = AnnDataLoader(self.adata_manager, shuffle=False, batch_size=batch_size, indices=data_loader_indices)

        # sample local parameters
        # collected local-variable quantiles (stays empty if there are no local variables)
        means = {}
        i = 0
        for tensor_dict in train_dl:
            args, kwargs = self.module._get_fn_args_from_batch(tensor_dict)
            args = [a.to(device) for a in args]
            kwargs = {k: v.to(device) for k, v in kwargs.items()}
            self.to_device(device)

            if i == 0:
                # find plate sites
                obs_plate_sites = self._get_obs_plate_sites(args, kwargs, return_observed=True)
                if len(obs_plate_sites) == 0:
                    # if no local variables - don't sample
                    break
                # find plate dimension
                obs_plate_dim = list(obs_plate_sites.values())[0]
                if use_median and q == 0.5:
                    means = self.module.guide.median(*args, **kwargs)
                else:
                    means = self.module.guide.quantiles([q], *args, **kwargs)
                means = {
                    k: means[k].cpu().numpy()
                    for k in means.keys()
                    if (k in obs_plate_sites) and (k not in exclude_vars)
                }
            else:
                if use_median and q == 0.5:
                    means_ = self.module.guide.median(*args, **kwargs)
                else:
                    means_ = self.module.guide.quantiles([q], *args, **kwargs)
                means_ = {
                    k: means_[k].cpu().numpy()
                    for k in means_.keys()
                    if (k in obs_plate_sites) and (k not in exclude_vars)
                }
                means = {k: np.concatenate([means[k], means_[k]], axis=obs_plate_dim) for k in means.keys()}
            i += 1

        # sample global parameters
        tensor_dict = next(iter(train_dl))
        args, kwargs = self.module._get_fn_args_from_batch(tensor_dict)
        args = [a.to(device) for a in args]
        kwargs = {k: v.to(device) for k, v in kwargs.items()}
        self.to_device(device)

        if use_median and q == 0.5:
            global_means = self.module.guide.median(*args, **kwargs)
        else:
            global_means = self.module.guide.quantiles([q], *args, **kwargs)
        global_means = {
            k: global_means[k].cpu().numpy()
            for k in global_means.keys()
            if (k not in obs_plate_sites) and (k not in exclude_vars)
        }

        for k in global_means.keys():
            means[k] = global_means[k]

        # quantile returns tensors with 0th dimension = 1
        if not (use_median and q == 0.5) and (
            not isinstance(self.module.guide, AutoAmortisedHierarchicalNormalMessenger)
        ):
            means = {k: means[k].squeeze(0) for k in means.keys()}

        self.module.to(device)

        return means

    @torch.no_grad()
    def _posterior_quantile(
        self,
        q: float = 0.5,
        batch_size: int = None,
        use_gpu: bool = None,
        use_median: bool = True,
        exclude_vars: list = None,
        data_loader_indices=None,
    ):
        """
        Compute median (or quantile ``q``) of the posterior distribution of each parameter
        for Pyro models trained without amortised inference.

        Parameters
        ----------
        q
            quantile to compute
        use_gpu
            Bool, use gpu?
        use_median
            Bool, when q=0.5 use median rather than quantile method of the guide

        Returns
        -------
        dictionary {variable_name: posterior quantile}

        """
        # guard against the default None so membership tests below do not fail
        if exclude_vars is None:
            exclude_vars = []

        self.module.eval()
        _, _, device = parse_use_gpu_arg(use_gpu)

        if batch_size is None:
            batch_size = self.adata_manager.adata.n_obs
        train_dl = AnnDataLoader(self.adata_manager, shuffle=False, batch_size=batch_size, indices=data_loader_indices)

        # sample global parameters
        tensor_dict = next(iter(train_dl))
        args, kwargs = self.module._get_fn_args_from_batch(tensor_dict)
        args = [a.to(device) for a in args]
        kwargs = {k: v.to(device) for k, v in kwargs.items()}
        self.to_device(device)

        if use_median and q == 0.5:
            means = self.module.guide.median(*args, **kwargs)
        else:
            means = self.module.guide.quantiles([q], *args, **kwargs)
        means = {k: means[k].cpu().detach().numpy() for k in means.keys() if k not in exclude_vars}

        # quantile returns tensors with 0th dimension = 1
        if not (use_median and q == 0.5) and (
            not isinstance(self.module.guide, AutoAmortisedHierarchicalNormalMessenger)
        ):
            means = {k: means[k].squeeze(0) for k in means.keys()}

        return means

    def posterior_quantile(self, exclude_vars: list = None, batch_size: int = None, **kwargs):
        """
        Compute median (or quantile) of the posterior distribution of each parameter.

        Parameters
        ----------
        exclude_vars
            list of variable names to exclude from the result
        batch_size
            number of observations per batch; minibatch computation is only used when this is not None
        **kwargs
            keyword arguments passed on to ``_posterior_quantile`` or ``_posterior_quantile_minibatch``
            (e.g. ``q``, ``use_gpu``, ``use_median``)

        Returns
        -------
        dictionary {variable_name: posterior quantile}

        """
        if exclude_vars is None:
            exclude_vars = []
        if isinstance(self.module.guide, AutoNormal):
            # median/quantiles in AutoNormal do not require minibatches
            batch_size = None
        if batch_size is not None:
            return self._posterior_quantile_minibatch(exclude_vars=exclude_vars, batch_size=batch_size, **kwargs)
        else:
            return self._posterior_quantile(exclude_vars=exclude_vars, batch_size=batch_size, **kwargs)

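# Illustrative sketch (not part of the original module): retrieving posterior
# quantiles from a trained model that mixes in `QuantileMixin`. `trained_model`
# is a hypothetical placeholder (e.g. a trained cell2location-style model); with
# `use_median=False` the guide's quantile method is used even for q=0.5.
def _example_posterior_quantile_usage(trained_model):
    q05 = trained_model.posterior_quantile(q=0.05, use_median=False, batch_size=2048)
    q95 = trained_model.posterior_quantile(q=0.95, use_median=False, batch_size=2048)
    return q05, q95
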
class PltExportMixin:
    r"""
    This mixin class provides methods for common plotting tasks and data export.
    """
    @staticmethod
    def plot_posterior_mu_vs_data(mu, data):
        r"""Plot expected value of the model (e.g. mean of NB distribution) vs observed data

        :param mu: expected value
        :param data: data value
        """
        plt.hist2d(
            np.log10(data.flatten() + 1),
            np.log10(mu.flatten() + 1),
            bins=50,
            norm=matplotlib.colors.LogNorm(),
        )
        plt.gca().set_aspect("equal", adjustable="box")
        plt.xlabel("Data, log10")
        plt.ylabel("Posterior expected value, log10")
        plt.title("Reconstruction accuracy")
        plt.tight_layout()
    def plot_history(self, iter_start=0, iter_end=-1, ax=None):
        r"""Plot training history

        Parameters
        ----------
        iter_start
            omit initial iterations from the plot
        iter_end
            omit last iterations from the plot
        ax
            matplotlib axis

        """
        if ax is None:
            ax = plt.gca()
        if iter_end == -1:
            iter_end = len(self.history_["elbo_train"])

        ax.plot(
            np.array(self.history_["elbo_train"].index[iter_start:iter_end]),
            np.array(self.history_["elbo_train"].values.flatten())[iter_start:iter_end],
            label="train",
        )
        ax.legend()
        ax.set_xlim(0, len(self.history_["elbo_train"]))
        ax.set_xlabel("Training epochs")
        ax.set_ylabel("-ELBO loss")
        plt.tight_layout()
    def _export2adata(self, samples):
        r"""
        Export key model variables and samples

        Parameters
        ----------
        samples
            dictionary with posterior mean, 5%/95% quantiles, SD, samples, generated by ``.sample_posterior()``

        Returns
        -------
        Updated dictionary with additional details that is saved to ``adata.uns['mod']``.

        """
        # add factor filter and samples of all parameters to unstructured data
        results = {
            "model_name": str(self.module.__class__.__name__),
            "date": str(date.today()),
            "factor_filter": list(getattr(self, "factor_filter", [])),
            "factor_names": list(self.factor_names_),
            "var_names": self.adata.var_names.tolist(),
            "obs_names": self.adata.obs_names.tolist(),
            "post_sample_means": samples["post_sample_means"] if "post_sample_means" in samples else None,
            "post_sample_stds": samples["post_sample_stds"] if "post_sample_stds" in samples else None,
        }
        # add posterior quantiles
        for k, v in samples.items():
            if k.startswith("post_sample_"):
                results[k] = v
        if type(self.factor_names_) is dict:
            results["factor_names"] = self.factor_names_

        return results
    def sample2df_obs(
        self,
        samples: dict,
        site_name: str = "w_sf",
        summary_name: str = "means",
        name_prefix: str = "cell_abundance",
        factor_names_key: str = "",
    ):
        r"""Export posterior distribution summary for observation-specific parameters
        (e.g. spatial cell abundance) as a Pandas data frame
        (means, 5%/95% quantiles or sd of posterior distribution).

        Parameters
        ----------
        samples
            dictionary with posterior mean, 5%/95% quantiles, SD, samples, generated by ``.sample_posterior()``
        site_name
            name of the model parameter to be exported
        summary_name
            posterior distribution summary to return ('means', 'stds', 'q05', 'q95')
        name_prefix
            prefix to add to column names (f'{summary_name}{name_prefix}_{site_name}_{self\.factor_names_}')

        Returns
        -------
        Pandas data frame corresponding to either means, 5%/95% quantiles or sd of the posterior distribution

        """
        if type(self.factor_names_) is dict:
            factor_names_ = self.factor_names_[factor_names_key]
        else:
            factor_names_ = self.factor_names_

        return pd.DataFrame(
            samples[f"post_sample_{summary_name}"].get(site_name, None),
            index=self.adata.obs_names,
            columns=[f"{summary_name}{name_prefix}_{site_name}_{i}" for i in factor_names_],
        )
    def sample2df_vars(
        self,
        samples: dict,
        site_name: str = "gene_factors",
        summary_name: str = "means",
        name_prefix: str = "",
        factor_names_key: str = "",
    ):
        r"""Export posterior distribution summary for variable-specific parameters as a Pandas data frame
        (means, 5%/95% quantiles or sd of posterior distribution).

        Parameters
        ----------
        samples
            dictionary with posterior mean, 5%/95% quantiles, SD, samples, generated by ``.sample_posterior()``
        site_name
            name of the model parameter to be exported
        summary_name
            posterior distribution summary to return ('means', 'stds', 'q05', 'q95')
        name_prefix
            prefix to add to column names (f'{summary_name}{name_prefix}_{site_name}_{self\.factor_names_}')

        Returns
        -------
        Pandas data frame corresponding to either means, 5%/95% quantiles or sd of the posterior distribution

        """
        if type(self.factor_names_) is dict:
            factor_names_ = self.factor_names_[factor_names_key]
        else:
            factor_names_ = self.factor_names_

        site = samples[f"post_sample_{summary_name}"].get(site_name, None)

        return pd.DataFrame(
            site,
            columns=self.adata.var_names,
            index=[f"{summary_name}{name_prefix}_{site_name}_{i}" for i in factor_names_],
        ).T
    def plot_QC(self, summary_name: str = "means", use_n_obs: int = 1000):
        """
        Show quality control plots:

        1. Reconstruction accuracy to assess if there are any issues with model training.
           The plot should be roughly diagonal, strong deviations signal problems that need to be investigated.
           Plotting is slow because the expected value of mRNA count needs to be computed from model parameters.
           Random observations are used to speed up computation.

        Parameters
        ----------
        summary_name
            posterior distribution summary to use ('means', 'stds', 'q05', 'q95')

        """
        if getattr(self, "samples", False) is False:
            raise RuntimeError("self.samples is missing, please run self.export_posterior() first")
        if use_n_obs is not None:
            ind_x = np.random.choice(
                self.adata_manager.adata.n_obs, np.min((use_n_obs, self.adata.n_obs)), replace=False
            )
        else:
            ind_x = None

        self.expected_nb_param = self.module.model.compute_expected(
            self.samples[f"post_sample_{summary_name}"], self.adata_manager, ind_x=ind_x
        )
        x_data = self.adata_manager.get_from_registry(REGISTRY_KEYS.X_KEY)[ind_x, :]
        if issparse(x_data):
            x_data = np.asarray(x_data.toarray())
        self.plot_posterior_mu_vs_data(self.expected_nb_param["mu"], x_data)

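# Illustrative usage sketches (not part of the original module) for the plotting and
# export helpers defined in PltExportMixin above. `trained_model` and `adata` are
# hypothetical placeholders for a trained cell2location-style model (after
# `.export_posterior()` has populated `trained_model.samples`) and its AnnData object.
def _example_plot_history_usage(trained_model):
    # plot the ELBO training history, skipping the first 100 noisy epochs
    fig, ax = plt.subplots(figsize=(5, 4))
    trained_model.plot_history(iter_start=100, iter_end=-1, ax=ax)
    return fig


def _example_sample2df_obs_usage(trained_model, adata):
    # export the posterior 5% quantile of the (hypothetical) "w_sf" cell-abundance
    # site to a data frame and attach it to adata.obs
    df = trained_model.sample2df_obs(
        trained_model.samples,
        site_name="w_sf",
        summary_name="q05",
        name_prefix="cell_abundance",
    )
    adata.obs[df.columns] = df.loc[adata.obs_names, :]
    return adata
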
class PyroAggressiveConvergence(Callback):
    """
    A callback to compute/apply aggressive training convergence criteria for amortised inference.

    Motivated by this paper: https://arxiv.org/pdf/1901.05534.pdf
    """

    def __init__(self, dataloader: AnnDataLoader = None, patience: int = 10, tolerance: float = 1e-4) -> None:
        super().__init__()
        self.dataloader = dataloader
        self.patience = patience
        self.tolerance = tolerance
    def on_train_epoch_end(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", unused: Optional = None
    ) -> None:
        """
        Compute aggressive training convergence criteria for amortised inference.
        """
        pyro_guide = pl_module.module.guide
        if hasattr(pyro_guide, "mutual_information"):
            if self.dataloader is None:
                dl = trainer.datamodule.train_dataloader()
            else:
                dl = self.dataloader
            for tensors in dl:
                tens = {k: t.to(pl_module.device) for k, t in tensors.items()}
                args, kwargs = pl_module.module._get_fn_args_from_batch(tens)
                break
            mi_ = pyro_guide.mutual_information(*args, **kwargs)
            mi_ = np.array([v for v in mi_.values()]).sum()
            pl_module.log("MI", mi_, prog_bar=True)
            if len(pl_module.mi) > 1:
                if pl_module.mi[-1] >= (mi_ - self.tolerance):
                    pl_module.n_epochs_patience += 1
                else:
                    pl_module.n_epochs_patience = 0
            if pl_module.n_epochs_patience > self.patience:
                # stop aggressive training by setting epoch counter to max epochs
                # pl_module.aggressive_epochs_counter = pl_module.n_aggressive_epochs + 1
                logger.info(
                    'Stopped aggressive training after "{}" epochs'.format(pl_module.aggressive_epochs_counter)
                )
            pl_module.mi.append(mi_)

class PyroTrainingPlan(PyroTrainingPlan_scvi):
    def training_epoch_end(self, outputs):
        """Training epoch end for Pyro training."""
        elbo = 0
        n = 0
        for out in outputs:
            elbo += out["loss"]
            n += 1
        elbo /= n
        self.log("elbo_train", elbo, prog_bar=True)
        gc.collect()
        torch.cuda.empty_cache()

class PyroAggressiveTrainingPlan1(PyroTrainingPlan_scvi):
    """
    Lightning module task to train Pyro scvi-tools modules.

    Parameters
    ----------
    pyro_module
        An instance of :class:`~scvi.module.base.PyroBaseModuleClass`. This object
        should have callable `model` and `guide` attributes or methods.
    loss_fn
        A Pyro loss. Should be a subclass of :class:`~pyro.infer.ELBO`.
        If `None`, defaults to :class:`~pyro.infer.Trace_ELBO`.
    optim
        A Pyro optimizer instance, e.g., :class:`~pyro.optim.Adam`. If `None`,
        defaults to :class:`pyro.optim.Adam` optimizer with a learning rate of `1e-3`.
    optim_kwargs
        Keyword arguments for **default** optimiser :class:`pyro.optim.Adam`.
    n_aggressive_epochs
        Number of epochs in aggressive optimisation of amortised variables.
    n_aggressive_steps
        Number of steps to spend optimising amortised variables before one step optimising global variables.
    n_steps_kl_warmup
        Number of training steps (minibatches) to scale weight on KL divergences from 0 to 1.
        Only activated when `n_epochs_kl_warmup` is set to None.
    n_epochs_kl_warmup
        Number of epochs to scale weight on KL divergences from 0 to 1.
        Overrides `n_steps_kl_warmup` when both are not `None`.
    """

    def __init__(
        self,
        pyro_module: PyroBaseModuleClass,
        loss_fn: Optional[pyro.infer.ELBO] = None,
        optim: Optional[pyro.optim.PyroOptim] = None,
        optim_kwargs: Optional[dict] = None,
        n_aggressive_epochs: int = 1000,
        n_aggressive_steps: int = 20,
        n_steps_kl_warmup: Union[int, None] = None,
        n_epochs_kl_warmup: Union[int, None] = 400,
        aggressive_vars: Union[list, None] = None,
        invert_aggressive_selection: bool = False,
    ):
        super().__init__(
            pyro_module=pyro_module,
            loss_fn=loss_fn,
            optim=optim,
            optim_kwargs=optim_kwargs,
            n_steps_kl_warmup=n_steps_kl_warmup,
            n_epochs_kl_warmup=n_epochs_kl_warmup,
        )

        self.n_aggressive_epochs = n_aggressive_epochs
        self.n_aggressive_steps = n_aggressive_steps
        self.aggressive_steps_counter = 0
        self.aggressive_epochs_counter = 0
        self.mi = []
        self.n_epochs_patience = 0

        # if a list is not provided, use amortised variables for aggressive training
        if aggressive_vars is None:
            aggressive_vars = list(self.module.list_obs_plate_vars["sites"].keys())
            aggressive_vars = aggressive_vars + [f"{i}_initial" for i in aggressive_vars]
            aggressive_vars = aggressive_vars + [f"{i}_unconstrained" for i in aggressive_vars]

        self.aggressive_vars = aggressive_vars
        self.invert_aggressive_selection = invert_aggressive_selection

        # keep frozen variables as frozen
        self.requires_grad_false_vars = [k for k, v in self.module.guide.named_parameters() if not v.requires_grad] + [
            k for k, v in self.module.model.named_parameters() if not v.requires_grad
        ]

        self.svi = pyro.infer.SVI(
            model=pyro_module.model,
            guide=pyro_module.guide,
            optim=self.optim,
            loss=self.loss_fn,
        )
    def change_requires_grad(self, aggressive_vars_status, non_aggressive_vars_status):
        for k, v in self.module.guide.named_parameters():
            if not np.any([i in k for i in self.requires_grad_false_vars]):
                k_in_vars = np.any([i in k for i in self.aggressive_vars])
                # hide variables on the list if they are not hidden
                if k_in_vars and v.requires_grad and (aggressive_vars_status == "hide"):
                    v.requires_grad = False
                # expose variables on the list if they are hidden
                if k_in_vars and (not v.requires_grad) and (aggressive_vars_status == "expose"):
                    v.requires_grad = True
                # hide variables not on the list if they are not hidden
                if (not k_in_vars) and v.requires_grad and (non_aggressive_vars_status == "hide"):
                    v.requires_grad = False
                # expose variables not on the list if they are hidden
                if (not k_in_vars) and (not v.requires_grad) and (non_aggressive_vars_status == "expose"):
                    v.requires_grad = True

        for k, v in self.module.model.named_parameters():
            if not np.any([i in k for i in self.requires_grad_false_vars]):
                k_in_vars = np.any([i in k for i in self.aggressive_vars])
                # hide variables on the list if they are not hidden
                if k_in_vars and v.requires_grad and (aggressive_vars_status == "hide"):
                    v.requires_grad = False
                # expose variables on the list if they are hidden
                if k_in_vars and (not v.requires_grad) and (aggressive_vars_status == "expose"):
                    v.requires_grad = True
                # hide variables not on the list if they are not hidden
                if (not k_in_vars) and v.requires_grad and (non_aggressive_vars_status == "hide"):
                    v.requires_grad = False
                # expose variables not on the list if they are hidden
                if (not k_in_vars) and (not v.requires_grad) and (non_aggressive_vars_status == "expose"):
                    v.requires_grad = True
    def training_epoch_end(self, outputs):
        self.aggressive_epochs_counter += 1

        self.change_requires_grad(
            aggressive_vars_status="expose",
            non_aggressive_vars_status="expose",
        )

        elbo = 0
        n = 0
        for out in outputs:
            elbo += out["loss"]
            n += 1
        elbo /= n
        self.log("elbo_train", elbo, prog_bar=True)
        gc.collect()
        torch.cuda.empty_cache()
    def training_step(self, batch, batch_idx):
        args, kwargs = self.module._get_fn_args_from_batch(batch)
        # Set KL weight if necessary.
        # Note: if applied, ELBO loss in progress bar is the effective KL annealed loss, not the true ELBO.
        if self.use_kl_weight:
            kwargs.update({"kl_weight": self.kl_weight})

        if self.aggressive_epochs_counter < self.n_aggressive_epochs:
            if self.aggressive_steps_counter < self.n_aggressive_steps:
                self.aggressive_steps_counter += 1
                # Do parameter update exclusively for amortised variables
                if self.invert_aggressive_selection:
                    self.change_requires_grad(
                        aggressive_vars_status="hide",
                        non_aggressive_vars_status="expose",
                    )
                else:
                    self.change_requires_grad(
                        aggressive_vars_status="expose",
                        non_aggressive_vars_status="hide",
                    )
                loss = torch.Tensor([self.svi.step(*args, **kwargs)])
            else:
                self.aggressive_steps_counter = 0
                # Do parameter update exclusively for non-amortised variables
                if self.invert_aggressive_selection:
                    self.change_requires_grad(
                        aggressive_vars_status="expose",
                        non_aggressive_vars_status="hide",
                    )
                else:
                    self.change_requires_grad(
                        aggressive_vars_status="hide",
                        non_aggressive_vars_status="expose",
                    )
                loss = torch.Tensor([self.svi.step(*args, **kwargs)])
        else:
            # Do parameter update for both types of variables
            self.change_requires_grad(
                aggressive_vars_status="expose",
                non_aggressive_vars_status="expose",
            )
            loss = torch.Tensor([self.svi.step(*args, **kwargs)])

        return {"loss": loss}

class PyroAggressiveTrainingPlan(PyroAggressiveTrainingPlan1):
    """
    Lightning module task to train Pyro scvi-tools modules.

    Parameters
    ----------
    pyro_module
        An instance of :class:`~scvi.module.base.PyroBaseModuleClass`. This object
        should have callable `model` and `guide` attributes or methods.
    loss_fn
        A Pyro loss. Should be a subclass of :class:`~pyro.infer.ELBO`.
        If `None`, defaults to :class:`~pyro.infer.Trace_ELBO`.
    optim
        A Pyro optimizer instance, e.g., :class:`~pyro.optim.Adam`. If `None`,
        defaults to :class:`pyro.optim.Adam` optimizer with a learning rate of `1e-3`.
    optim_kwargs
        Keyword arguments for **default** optimiser :class:`pyro.optim.Adam`.
    n_steps_kl_warmup
        Number of training steps (minibatches) to scale weight on KL divergences from 0 to 1.
        Only activated when `n_epochs_kl_warmup` is set to None.
    n_epochs_kl_warmup
        Number of epochs to scale weight on KL divergences from 0 to 1.
        Overrides `n_steps_kl_warmup` when both are not `None`.
    """

    def __init__(
        self,
        scale_elbo: Union[float, None] = 1.0,
        **kwargs,
    ):
        super().__init__(**kwargs)

        if scale_elbo != 1.0:
            self.svi = pyro.infer.SVI(
                model=poutine.scale(self.module.model, scale_elbo),
                guide=poutine.scale(self.module.guide, scale_elbo),
                optim=self.optim,
                loss=self.loss_fn,
            )
        else:
            self.svi = pyro.infer.SVI(
                model=self.module.model,
                guide=self.module.guide,
                optim=self.optim,
                loss=self.loss_fn,
            )

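# Illustrative sketch (not part of the original module): wiring up the aggressive
# training plan together with the convergence callback. `pyro_module` is a
# hypothetical PyroBaseModuleClass instance; how the plan and callback are passed
# to a pytorch_lightning Trainer depends on the calling model class.
def _example_aggressive_training_setup(pyro_module):
    plan = PyroAggressiveTrainingPlan(
        pyro_module=pyro_module,
        n_aggressive_epochs=1000,
        n_aggressive_steps=20,
        scale_elbo=1.0,
    )
    callback = PyroAggressiveConvergence(patience=10, tolerance=1e-4)
    return plan, callback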