Pymc3 implementation (advanced use)

Pipelines (wrappers for full workflow)

Run cell2location

Run regression

Main models: general and Nanostring WTA

Pymc3 (main cell2location model): LocationModelLinearDependentWMultiExperimentLocationBackgroundNormGeneAlpha

Nanostring WTA model: LocationModelWTA

Models with simplified architecture

No normalisation: LocationModelLinearDependentWMultiExperiment

No prior factorisation of w_sf (but hierarchical priors): LocationModelHierarchicalWMultiExperiment

No prior factorisation of w_sf: LocationModelMultiExperiment

No gene-specific platform effect m_g: LocationModelLinearDependentWMultiExperimentNoMg

No additive background RNA: LocationModelLinearDependentWMultiExperimentNoSegLs

Base model classes (infrastructure)

Pymc3: BaseModel

Base model class

class cell2location.models.base.base_model.BaseModel(X_data: numpy.ndarray, n_fact: int = 10, data_type: str = 'float32', n_iter: int = 200000, learning_rate=0.001, total_grad_norm_constraint=200, verbose=True, var_names=None, var_names_read=None, obs_names=None, fact_names=None, sample_id=None)

Bases: object

Base class for pymc3 and pyro models.

Parameters
  • X_data – NumPy array of gene expression, with spatial locations in rows and genes in columns

  • n_fact – Number of factors

  • n_iter – Number of training iterations

  • learning_rate – ADAM learning rate for optimising Variational inference objective

  • data_type – theano data type used to store parameters (‘float32’ for single, ‘float64’ for double precision)

  • total_grad_norm_constraint – maximum total gradient norm (gradient clipping) used during optimisation

  • verbose – print diagnostic messages?

  • var_names – Variable names (e.g. gene identifiers)

  • var_names_read – Readable variable names (e.g. gene symbol)

  • obs_names – Observation names (e.g. cell or spot id)

  • fact_names – Factor names

  • sample_id – Sample identifiers (e.g. different experiments)

plot_prior_vs_data(data_target_name='data_target', data_node='X_data', log_transform=True)

Plot data against a single sample from the prior as a 2D histogram. Uses self.X_data and self.prior_trace['data_target'].

Parameters

data_node – name of the object slot containing data
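The binning behind a data-vs-sample 2D histogram can be sketched with NumPy alone. This is a minimal sketch, not the cell2location implementation: the helper name `data_vs_sample_hist2d`, the use of `log1p` to stand in for `log_transform=True`, the bin count, and the Poisson draw standing in for a prior sample are all assumptions. The helper returns histogram counts instead of drawing a figure.

```python
import numpy as np

def data_vs_sample_hist2d(x_data, sample, log_transform=True, bins=50):
    """Hypothetical helper: bin each observed value against the sampled
    value for the same matrix entry and return the 2D histogram counts."""
    x = np.asarray(x_data, dtype=float).ravel()
    y = np.asarray(sample, dtype=float).ravel()
    if x.shape != y.shape:
        raise ValueError("data and sample must have the same shape")
    if log_transform:
        # log1p keeps zero counts finite; the exact transform is an assumption
        x, y = np.log1p(x), np.log1p(y)
    counts, x_edges, y_edges = np.histogram2d(x, y, bins=bins)
    return counts, x_edges, y_edges

rng = np.random.default_rng(0)
x_data = rng.poisson(5.0, size=(100, 20))        # observed counts: locations x genes
prior_sample = rng.poisson(5.0, size=(100, 20))  # stand-in for one prior draw
counts, _, _ = data_vs_sample_hist2d(x_data, prior_sample, bins=30)
print(counts.shape, int(counts.sum()))  # (30, 30) 2000
```

Every matrix entry falls into exactly one bin, so the counts sum to the number of entries (100 × 20).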

static align_plot_stability(fac1, fac2, name1, name2, align=True, return_aligned=False)

Align columns between two np.ndarrays using scipy.optimize.linear_sum_assignment, then plot correlations between the columns of fac1 and fac2, with fac2 ordered according to the alignment.

Parameters
  • fac1 – np.ndarray 1, factors in columns

  • fac2 – np.ndarray 2, factors in columns

  • name1 – axis x name

  • name2 – axis y name

  • align – boolean, match columns in fac1 and fac2 using linear_sum_assignment?
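The alignment step can be sketched as follows. This is a minimal sketch under assumptions: the helper name `align_factor_columns` is hypothetical, and using Pearson correlation between columns as the matching score is an assumption; only the use of `scipy.optimize.linear_sum_assignment` comes from the description above. Because `linear_sum_assignment` minimises total cost, the correlation matrix is negated so that the total matched correlation is maximised.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

def align_factor_columns(fac1, fac2):
    """Hypothetical helper: reorder the columns of fac2 to best match fac1."""
    n1 = fac1.shape[1]
    # Cross-correlation block: rows are fac1 columns, columns are fac2 columns
    corr = np.corrcoef(fac1, fac2, rowvar=False)[:n1, n1:]
    # Negate so that minimising cost maximises total matched correlation
    row_ind, col_ind = linear_sum_assignment(-corr)
    return fac2[:, col_ind], corr[row_ind, col_ind]

rng = np.random.default_rng(1)
fac1 = rng.gamma(2.0, 1.0, size=(200, 5))  # 200 observations x 5 factors
perm = np.array([3, 0, 4, 1, 2])
# fac2 is a column-shuffled, slightly noisy copy of fac1
fac2 = fac1[:, perm] + rng.normal(0.0, 0.1, size=(200, 5))
aligned, matched_corr = align_factor_columns(fac1, fac2)
print(matched_corr.round(3))
```

Here the alignment undoes the shuffle, so `aligned` equals `fac2[:, np.argsort(perm)]`.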

generate_cv_data(n: int = 2, discrete: bool = True, non_discrete_mean_var: float = 1)

Generate X_data for molecular cross-validation by sampling molecule counts with np.random.binomial

Parameters

n – number of equal-size cross-validation folds to generate; currently only n=2 is implemented
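For integer counts, a two-fold molecular split can be sketched by binomial thinning: each molecule goes to the first fold with probability 0.5, and the remainder forms the second fold, so the folds are disjoint and sum back to the original counts. This is a minimal sketch, not the cell2location code. The helper name `binomial_split` is hypothetical, and it uses NumPy's `Generator.binomial` in place of the legacy `np.random.binomial` named above (the sampling is the same). It covers only the `discrete=True` case.

```python
import numpy as np

def binomial_split(x_counts, p=0.5, seed=None):
    """Hypothetical helper: split integer counts into two disjoint folds."""
    rng = np.random.default_rng(seed)
    x_counts = np.asarray(x_counts, dtype=np.int64)
    # Each molecule lands in fold 1 independently with probability p
    fold1 = rng.binomial(x_counts, p)
    # The remaining molecules form fold 2
    fold2 = x_counts - fold1
    return fold1, fold2

x = np.random.default_rng(0).poisson(4.0, size=(50, 10))  # locations x genes
train, valid = binomial_split(x, p=0.5, seed=42)
print(int(train.sum()), int(valid.sum()), int(x.sum()))
```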

bootstrap_data(n=10, downsampling_p=0.8, discrete=True, non_discrete_mean_var=1)

Generate X_data for bootstrap analysis by sampling molecule counts with np.random.binomial

Parameters
  • n – number of bootstrap samples to generate

  • downsampling_p – sample this proportion of values

  • non_discrete_mean_var – controls the sampling variance for non-discrete data; lower values give lower variance
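Binomial downsampling for bootstrap samples can be sketched as below: every molecule is kept independently with probability `downsampling_p`, so each sample retains roughly that fraction of the total counts. This is a minimal sketch covering only the `discrete=True` case. The helper name `downsample_counts` is hypothetical, and `Generator.binomial` stands in for `np.random.binomial`.

```python
import numpy as np

def downsample_counts(x_counts, n=10, downsampling_p=0.8, seed=None):
    """Hypothetical helper: return n binomially downsampled copies of counts."""
    rng = np.random.default_rng(seed)
    x_counts = np.asarray(x_counts, dtype=np.int64)
    # Each molecule is kept independently with probability downsampling_p
    return [rng.binomial(x_counts, downsampling_p) for _ in range(n)]

x = np.random.default_rng(0).poisson(20.0, size=(100, 30))  # locations x genes
samples = downsample_counts(x, n=10, downsampling_p=0.8, seed=1)
# Fraction of molecules retained in the first sample (close to 0.8)
print(len(samples), samples[0].sum() / x.sum())
```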

plot_posterior_mu_vs_data(mu_node_name='mu', data_node='X_data')

Plot the expected value of the model (e.g. the mean of the Poisson distribution) against the data.

Parameters
  • mu_node_name – name of the object slot containing expected value

  • data_node – name of the object slot containing data

plot_history(iter_start=0, iter_end=-1, mean_field_slot=None, log_y=True, ax=None)

Plot training history

Parameters
  • iter_start – omit initial iterations from the plot

  • iter_end – omit last iterations from the plot

plot_validation_history(start_step=0, end_step=-1, mean_field_slot='init_1', log_y=True, ax=None)

Plot model loss (NB likelihood + penalty) using the model on training and validation data

plot_posterior_vs_data(gene_fact_name='gene_factors', cell_fact_name='cell_factors')

set_fact_filt(fact_filt)

Specify which factors are not relevant / not expressed. The filter is currently used to restrict the results shown by .print_gene_loadings() and .plot_gene_loadings().

Parameters

fact_filt – logical array specifying which factors are to be retained

apply_fact_filt(df)

Select DataFrame columns using the factor filter saved in the model object by set_fact_filt().

Parameters

df – pd.DataFrame
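The interplay between set_fact_filt() and apply_fact_filt() amounts to storing a boolean mask and using it to select DataFrame columns. The class below is a hypothetical stand-in that only mirrors the two method names; it assumes that factors are DataFrame columns and that `True` means "retain", as the fact_filt description states. It is not the BaseModel implementation.

```python
import numpy as np
import pandas as pd

class FactorFilterDemo:
    """Hypothetical stand-in mirroring set_fact_filt / apply_fact_filt."""

    def set_fact_filt(self, fact_filt):
        # Store a logical array: True marks factors to retain
        self.fact_filt = np.asarray(fact_filt, dtype=bool)

    def apply_fact_filt(self, df):
        # Keep only the columns (factors) flagged True
        return df.loc[:, self.fact_filt]

loadings = pd.DataFrame(
    np.arange(12).reshape(4, 3),
    index=["g1", "g2", "g3", "g4"],
    columns=["fact_0", "fact_1", "fact_2"],
)
demo = FactorFilterDemo()
demo.set_fact_filt([True, False, True])
print(demo.apply_fact_filt(loadings).columns.tolist())  # ['fact_0', 'fact_2']
```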

print_gene_loadings(gene_fact_name='gene_factors', loadings_attr='gene_loadings', top_n=10, gene_filt=None, fact_filt=None)

Print the top genes (top_n, 10 by default) for each factor in the gene loadings matrix.

Parameters
  • gene_fact_name – model parameter name to extract from samples if self.gene_loadings doesn’t exist

  • loadings_attr – model object attribute name that stores loadings

  • top_n – number of genes to print for each factor

  • gene_filt – boolean filter for genes (e.g. restrict printed markers to TFs)

  • fact_filt – boolean filter for factors
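Selecting the top genes per factor, with optional gene and factor filters, can be sketched with pandas. This is a minimal sketch under assumptions: loadings are taken to be a DataFrame with genes as rows and factors as columns, higher loading is treated as "top", and the helper name `top_genes_per_factor` and the toy gene names are hypothetical.

```python
import numpy as np
import pandas as pd

def top_genes_per_factor(loadings, top_n=10, gene_filt=None, fact_filt=None):
    """Hypothetical helper: top_n highest-loading genes for each factor."""
    if gene_filt is not None:
        # Boolean filter over genes (rows), e.g. restrict to TFs
        loadings = loadings.loc[np.asarray(gene_filt, dtype=bool), :]
    if fact_filt is not None:
        # Boolean filter over factors (columns)
        loadings = loadings.loc[:, np.asarray(fact_filt, dtype=bool)]
    return {fact: loadings[fact].nlargest(top_n).index.tolist()
            for fact in loadings.columns}

loadings = pd.DataFrame(
    [[0.9, 0.1], [0.2, 0.8], [0.5, 0.3], [0.1, 0.7], [0.4, 0.05]],
    index=["g1", "g2", "g3", "g4", "g5"],
    columns=["fact_0", "fact_1"],
)
print(top_genes_per_factor(loadings, top_n=2))
# {'fact_0': ['g1', 'g3'], 'fact_1': ['g2', 'g4']}
```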

plot_gene_loadings(sel_var_names, var_names, gene_fact_name='gene_factors', loadings_attr='gene_loadings', figsize=(15, 7), cluster_factors=False, cluster_genes=True, cmap='viridis', title='', fact_filt=None, fun_type='heatmap', return_linkage=False)

Plot gene loadings as a heatmap

Parameters
  • sel_var_names – list of variable names to select

  • var_names – names identifying each gene in the gene loadings; sel_var_names are matched against these names

  • gene_fact_name – model parameter name to extract from samples if self.gene_loadings doesn’t exist

  • figsize – figure size (width, height)

  • cluster_factors – hierarchically cluster factors?

  • cluster_genes – hierarchically cluster genes?

  • cmap – matplotlib colormap

  • title – plot title

  • fact_filt – boolean or character filter for factors
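The cluster_genes and cluster_factors options reorder heatmap rows and columns so that similar ones sit together. A common way to do this is a SciPy hierarchical-clustering linkage followed by its leaf order, sketched below. The helper name `cluster_order`, the average-linkage method, and the correlation distance are assumptions, not the documented behaviour of plot_gene_loadings().

```python
import numpy as np
from scipy.cluster.hierarchy import leaves_list, linkage

def cluster_order(mat, axis=0, method="average", metric="correlation"):
    """Hypothetical helper: index order placing similar rows (axis=0)
    or columns (axis=1) next to each other."""
    data = mat if axis == 0 else mat.T
    # Agglomerative clustering; linkage treats each row of `data` as one item
    z = linkage(data, method=method, metric=metric)
    # Leaf order of the dendrogram, read left to right
    return leaves_list(z)

rng = np.random.default_rng(0)
loadings = rng.gamma(1.0, 1.0, size=(8, 4))  # genes x factors
gene_order = cluster_order(loadings, axis=0)  # like cluster_genes=True
fact_order = cluster_order(loadings, axis=1)  # like cluster_factors=True
ordered = loadings[np.ix_(gene_order, fact_order)]
print(gene_order, fact_order)
```

The reordered matrix `ordered` is what would be passed to the heatmap.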

plot_loading_distribution(loadings_name='gene_factors', loadings=None)

Plot histogram for each loading (column-wise)

Parameters
  • loadings_name – character name to be extracted from self.samples['post_sample_means']

  • loadings – np.ndarray to be plotted column-wise. Supersedes loadings_name.

factor_expressed_plot(shape_cut=4, rate_cut=15, sample_type='post_sample_means', shape='cell_fact_shape_hyp', rate='cell_fact_rate_hyp', shape_lab='cell_factors, Gamma shape', rate_lab='cell_factors, Gamma rate', invert_selection=False)

Show which factors are expressed on a scatterplot of their regularising priors

Parameters
  • shape_cut – Gamma shape cutoff below which factors are expressed

  • rate_cut – Gamma rate cutoff below which factors are expressed

  • sample_type – which posterior summary to look at, default 'post_sample_means'

  • shape – parameter name for the Gamma shape of each factor, default 'cell_fact_shape_hyp'

  • rate – parameter name for the Gamma rate of each factor, default 'cell_fact_rate_hyp'

  • shape_lab – axis label for shape

  • rate_lab – axis label for rate

  • invert_selection – set to True if values below the cutoffs indicate factors that are not expressed

plot_reconstruction_history(n_type='cv', start_step=0, end_step=45)

Plot reconstruction error using the model on training and validation data

export2adata(adata, slot_name='mod')

Add the posterior mean and sd of all parameters to the unstructured data slot adata.uns[slot_name] (adata.uns['mod'] by default).

Parameters

adata – anndata object

Pymc3: Pymc3LocModel

Pymc3: Pymc3Model