Pymc3 implementation (advanced use)
Main models: general and Nanostring WTA
Pymc3 (main cell2location model): LocationModelLinearDependentWMultiExperimentLocationBackgroundNormGeneAlpha
Nanostring WTA model: LocationModelWTA
Models with simplified architecture
No normalisation: LocationModelLinearDependentWMultiExperiment
No prior factorisation of w_sf (but hierarchical priors): LocationModelHierarchicalWMultiExperiment
No prior factorisation of w_sf: LocationModelMultiExperiment
No gene-specific platform effect m_g: LocationModelLinearDependentWMultiExperimentNoMg
No additive background RNA: LocationModelLinearDependentWMultiExperimentNoSegLs
Base model classes (infrastructure)
Pymc3: BaseModel
Base model class
class cell2location.models.base.base_model.BaseModel(X_data: numpy.ndarray, n_fact: int = 10, data_type: str = 'float32', n_iter: int = 200000, learning_rate=0.001, total_grad_norm_constraint=200, verbose=True, var_names=None, var_names_read=None, obs_names=None, fact_names=None, sample_id=None)
Bases: object
Base class for pymc3 and pyro models.
Parameters
X_data – NumPy array of gene expression, with genes in columns and spatial locations in rows
n_fact – Number of factors
n_iter – Number of training iterations
learning_rate – Adam learning rate for optimising the variational inference objective
data_type – Theano data type used to store parameters ('float32' for single, 'float64' for double precision)
total_grad_norm_constraint – upper bound on the total gradient norm during optimisation (guards against exploding gradients)
verbose – whether to print diagnostic messages
var_names – Variable names (e.g. gene identifiers)
var_names_read – Readable variable names (e.g. gene symbol)
obs_names – Observation names (e.g. cell or spot id)
fact_names – Factor names
sample_id – Sample identifiers (e.g. different experiments)
plot_prior_vs_data(data_target_name='data_target', data_node='X_data', log_transform=True)
Plot the data against a single sample from the prior as a 2D histogram. Uses self.X_data and self.prior_trace['data_target'].
Parameters
data_node – name of the object slot containing data
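The comparison behind plot_prior_vs_data can be sketched as a 2D histogram of observed counts against one prior draw. This is a minimal sketch with several assumptions. The helper prior_vs_data_hist2d is hypothetical. The log transform is assumed to be log10(1 + x). The "prior draw" here is simulated Poisson data rather than a real self.prior_trace entry.

```python
import numpy as np


def prior_vs_data_hist2d(x_data, prior_sample, log_transform=True, bins=50):
    """Hypothetical helper: 2D histogram of observed counts vs one prior draw."""
    x = np.asarray(x_data, dtype=float).ravel()
    y = np.asarray(prior_sample, dtype=float).ravel()
    if x.shape != y.shape:
        raise ValueError("data and prior sample must have the same shape")
    if log_transform:
        # Assumption: log10(1 + x), which keeps zero counts finite.
        x, y = np.log10(x + 1), np.log10(y + 1)
    # counts[i, j] = number of (data, prior) value pairs falling in bin (i, j)
    counts, x_edges, y_edges = np.histogram2d(x, y, bins=bins)
    return counts, x_edges, y_edges


rng = np.random.default_rng(0)
X_data = rng.poisson(5.0, size=(100, 20))      # toy counts: locations x genes
prior_draw = rng.poisson(5.0, size=(100, 20))  # stand-in for one prior sample
counts, x_edges, y_edges = prior_vs_data_hist2d(X_data, prior_draw)
```

To display the result, counts.T can be passed to matplotlib (e.g. pcolormesh with the returned edges). If the prior is sensible, the mass should lie near the diagonal.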
static align_plot_stability(fac1, fac2, name1, name2, align=True, return_aligned=False)
Align columns between two np.ndarrays using scipy.optimize.linear_sum_assignment, then plot correlations between the columns of fac1 and fac2, with fac2 ordered according to the alignment.
Parameters
fac1 – first np.ndarray, with factors in columns
fac2 – second np.ndarray, with factors in columns
name1 – x-axis name
name2 – y-axis name
align – boolean; whether to match columns in fac1 and fac2 using linear_sum_assignment
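The alignment step can be sketched as follows, under stated assumptions. The helper align_columns is hypothetical. It assumes the matching cost is the negated Pearson correlation between columns, so linear_sum_assignment (which minimises cost) returns the one-to-one matching with maximal total correlation. The method's exact cost may differ.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment


def align_columns(fac1, fac2):
    """Hypothetical helper: reorder fac2's columns to best match fac1's columns.

    Returns the reordered fac2 and the fac1-vs-fac2 correlation block in aligned order.
    """
    n1 = fac1.shape[1]
    # rowvar=False treats columns as variables; keep only the fac1 x fac2 block.
    corr = np.corrcoef(fac1, fac2, rowvar=False)[:n1, n1:]
    # Negate so that minimising cost maximises total correlation.
    row_ind, col_ind = linear_sum_assignment(-corr)
    return fac2[:, col_ind], corr[np.ix_(row_ind, col_ind)]


rng = np.random.default_rng(1)
fac1 = rng.gamma(2.0, size=(200, 5))
perm = np.array([3, 0, 4, 1, 2])
# fac2: the same factors in shuffled column order, plus a little noise
fac2 = fac1[:, perm] + rng.normal(0.0, 0.01, size=(200, 5))
fac2_aligned, corr_aligned = align_columns(fac1, fac2)
```

After alignment, the diagonal of corr_aligned holds the matched-factor correlations. This is the quantity a stability plot would show along its diagonal.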
generate_cv_data(n: int = 2, discrete: bool = True, non_discrete_mean_var: float = 1)
Generate X_data for molecular cross-validation by sampling molecule counts with np.random.binomial.
Parameters
n – number of cross-validation folds of equal size to generate; currently only n=2 is implemented
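For the discrete n=2 case, binomial splitting divides each count into two disjoint parts that sum back to the original. This is a minimal sketch under two assumptions. First, each molecule is assumed to go to the training half with probability 0.5, which is our reading of "folds of equal size". Second, split_counts_binomial is a hypothetical helper. It uses NumPy's Generator.binomial, the Generator-API counterpart of np.random.binomial.

```python
import numpy as np


def split_counts_binomial(X, p=0.5, seed=None):
    """Hypothetical helper: split integer counts into two disjoint halves."""
    rng = np.random.default_rng(seed)
    X = np.asarray(X, dtype=np.int64)
    # Each of the X[i, j] molecules is assigned to the training half with probability p.
    X_train = rng.binomial(X, p)
    # The remaining molecules form the validation half, so X_train + X_val == X.
    X_val = X - X_train
    return X_train, X_val


rng = np.random.default_rng(0)
X = rng.poisson(10.0, size=(50, 30))  # toy counts: locations x genes
X_train, X_val = split_counts_binomial(X, seed=42)
```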
bootstrap_data(n=10, downsampling_p=0.8, discrete=True, non_discrete_mean_var=1)
Generate X_data for bootstrap analysis by sampling molecule counts with np.random.binomial.
Parameters
n – number of bootstrap samples to generate
downsampling_p – proportion of values to sample
non_discrete_mean_var – mean-variance setting for non-discrete sampling; lower values give lower variance
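For discrete counts, binomial downsampling keeps each molecule independently with probability downsampling_p and repeats this n times. This is a minimal sketch of that idea only. The helper downsample_counts is hypothetical, and the non-discrete branch (non_discrete_mean_var) is not covered.

```python
import numpy as np


def downsample_counts(X, n=10, downsampling_p=0.8, seed=None):
    """Hypothetical helper: n independent binomially downsampled copies of X."""
    rng = np.random.default_rng(seed)
    X = np.asarray(X, dtype=np.int64)
    # Each copy keeps every molecule with probability downsampling_p.
    return [rng.binomial(X, downsampling_p) for _ in range(n)]


rng = np.random.default_rng(0)
X = rng.poisson(20.0, size=(100, 50))  # toy counts: locations x genes
samples = downsample_counts(X, n=10, downsampling_p=0.8, seed=1)
```

Refitting the model on each copy and comparing the results (for example with align_plot_stability) gives a rough measure of how stable the factors are.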
plot_posterior_mu_vs_data(mu_node_name='mu', data_node='X_data')
Plot the expected value of the model (e.g. the mean of the Poisson distribution) against the data.
Parameters
mu_node_name – name of the object slot containing expected value
data_node – name of the object slot containing data
plot_history(iter_start=0, iter_end=-1, mean_field_slot=None, log_y=True, ax=None)
Plot the training history.
Parameters
iter_start – omit initial iterations from the plot
iter_end – omit last iterations from the plot
plot_validation_history(start_step=0, end_step=-1, mean_field_slot='init_1', log_y=True, ax=None)
Plot the model loss (NB likelihood + penalty) on training and validation data.
set_fact_filt(fact_filt)
Specify which factors are not relevant or not expressed. This filter is currently used to restrict the results shown by .print_gene_loadings() and .plot_gene_loadings().
Parameters
fact_filt – logical array specifying which factors to retain
apply_fact_filt(df)
Select DataFrame columns using the factor filter saved in the model object.
Parameters
df – pd.DataFrame
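Applying a stored factor filter amounts to selecting columns with a mask. This is a minimal pandas sketch with a few assumptions. The filter is assumed to be a boolean array aligned with the DataFrame columns; a list of factor names would work the same way with .loc. The helper select_factor_columns is a hypothetical stand-in, not the method itself.

```python
import numpy as np
import pandas as pd


def select_factor_columns(df, fact_filt):
    """Hypothetical helper: keep only the columns (factors) selected by fact_filt."""
    # .loc accepts a boolean mask or a list of column labels.
    return df.loc[:, fact_filt]


fact_names = ["fact_0", "fact_1", "fact_2", "fact_3"]
loadings = pd.DataFrame(
    np.arange(12).reshape(3, 4),
    index=["gene_a", "gene_b", "gene_c"],
    columns=fact_names,
)
fact_filt = np.array([True, False, True, True])  # fact_1 treated as not expressed
filtered = select_factor_columns(loadings, fact_filt)
```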
print_gene_loadings(gene_fact_name='gene_factors', loadings_attr='gene_loadings', top_n=10, gene_filt=None, fact_filt=None)
Print the top_n (default 10) top genes for each factor in the gene loadings matrix.
Parameters
gene_fact_name – model parameter name to extract from samples if self.gene_loadings doesn't exist
loadings_attr – model object attribute name that stores loadings
top_n – number of genes to print for each factor
gene_filt – boolean filter for genes (e.g. to restrict printed markers to TFs)
fact_filt – boolean filter for factors
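Listing the top genes per factor can be sketched as a column-wise sort of a genes x factors loadings table. This is a minimal sketch under stated assumptions. The helper top_genes_per_factor is hypothetical. "Top" is taken to mean the largest loading values, and the example uses toy values.

```python
import pandas as pd


def top_genes_per_factor(loadings, top_n=10, gene_filt=None, fact_filt=None):
    """Hypothetical helper: top_n gene names per factor of a genes x factors DataFrame."""
    if gene_filt is not None:
        loadings = loadings.loc[gene_filt, :]  # e.g. keep only transcription factors
    if fact_filt is not None:
        loadings = loadings.loc[:, fact_filt]  # keep only retained factors
    # nlargest keeps the top_n rows of each column, in descending order of value.
    return {fact: loadings[fact].nlargest(top_n).index.tolist() for fact in loadings.columns}


loadings = pd.DataFrame(
    {"fact_0": [0.9, 0.1, 0.5, 0.3], "fact_1": [0.2, 0.8, 0.1, 0.7]},
    index=["gene_a", "gene_b", "gene_c", "gene_d"],
)
top = top_genes_per_factor(loadings, top_n=2)
for fact, genes in top.items():
    print(fact, ", ".join(genes))
# fact_0 gene_a, gene_c
# fact_1 gene_b, gene_d
```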
plot_gene_loadings(sel_var_names, var_names, gene_fact_name='gene_factors', loadings_attr='gene_loadings', figsize=(15, 7), cluster_factors=False, cluster_genes=True, cmap='viridis', title='', fact_filt=None, fun_type='heatmap', return_linkage=False)
Plot gene loadings as a heatmap.
Parameters
sel_var_names – list of variable names to select
var_names – names identifying each gene in the gene loadings; sel_var_names is matched against these
gene_fact_name – model parameter name to extract from samples if self.gene_loadings doesn't exist
figsize – figure size
cluster_factors – whether to hierarchically cluster factors
cluster_genes – whether to hierarchically cluster genes
cmap – matplotlib colormap
title – plot title
fact_filt – boolean or character filter for factors
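The cluster_genes and cluster_factors options reorder heatmap rows and columns by hierarchical clustering. This is a minimal sketch of that reordering. The helper cluster_order is hypothetical. It assumes Euclidean distances with average linkage, and the method's actual distance and linkage choices may differ.

```python
import numpy as np
from scipy.cluster.hierarchy import leaves_list, linkage


def cluster_order(loadings, cluster_genes=True, cluster_factors=False, method="average"):
    """Hypothetical helper: reorder a genes x factors array by hierarchical clustering."""
    row_order = np.arange(loadings.shape[0])
    col_order = np.arange(loadings.shape[1])
    if cluster_genes:
        # Cluster genes (rows); leaves_list gives the dendrogram leaf order.
        row_order = leaves_list(linkage(loadings, method=method))
    if cluster_factors:
        # Cluster factors by treating columns as observations.
        col_order = leaves_list(linkage(loadings.T, method=method))
    return loadings[np.ix_(row_order, col_order)], row_order, col_order


rng = np.random.default_rng(2)
loadings = rng.gamma(1.0, size=(6, 3))  # toy loadings: 6 genes x 3 factors
ordered, row_order, col_order = cluster_order(loadings)
```

The reordered array can then be passed to a heatmap function (for example matplotlib's imshow with cmap='viridis').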
plot_loading_distribution(loadings_name='gene_factors', loadings=None)
Plot a histogram of each loading, column-wise.
Parameters
loadings_name – name of the loadings to extract from self.samples['post_sample_means']
loadings – np.ndarray to plot column-wise; supersedes loadings_name
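Computing one histogram per loading column is a short loop over columns. This is a minimal sketch. The helper column_histograms is hypothetical, and the bin count of 20 is an arbitrary choice.

```python
import numpy as np


def column_histograms(loadings, bins=20):
    """Hypothetical helper: (counts, bin_edges) for each column of a 2D array."""
    loadings = np.asarray(loadings, dtype=float)
    return [np.histogram(loadings[:, j], bins=bins) for j in range(loadings.shape[1])]


rng = np.random.default_rng(3)
loadings = rng.gamma(2.0, size=(500, 4))  # toy loadings: 500 genes x 4 factors
hists = column_histograms(loadings)
```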
factor_expressed_plot(shape_cut=4, rate_cut=15, sample_type='post_sample_means', shape='cell_fact_shape_hyp', rate='cell_fact_rate_hyp', shape_lab='cell_factors, Gamma shape', rate_lab='cell_factors, Gamma rate', invert_selection=False)
Show which factors are expressed on a scatterplot of their regularising priors.
Parameters
shape_cut – Gamma shape cutoff below which factors are expressed
rate_cut – Gamma rate cutoff below which factors are expressed
sample_type – which posterior summary to look at, default 'post_sample_means'
shape – parameter name for the Gamma shape of each factor, default 'cell_fact_shape_hyp'
rate – parameter name for the Gamma rate of each factor, default 'cell_fact_rate_hyp'
shape_lab – axis label for shape
rate_lab – axis label for rate
invert_selection – set to True if values below the cutoffs indicate factors that are not expressed
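The cutoff logic can be sketched as a boolean mask over factors. This is a minimal sketch with several assumptions. A factor is assumed to count as expressed only when both its Gamma shape and rate fall below their cutoffs, and invert_selection is assumed to simply negate the mask; the method may combine the conditions differently. The helper expressed_factor_mask and the example values are hypothetical.

```python
import numpy as np


def expressed_factor_mask(shape, rate, shape_cut=4, rate_cut=15, invert_selection=False):
    """Hypothetical helper: boolean mask of factors deemed expressed."""
    shape = np.asarray(shape, dtype=float)
    rate = np.asarray(rate, dtype=float)
    # Assumption: both values must fall below their cutoffs.
    below = (shape < shape_cut) & (rate < rate_cut)
    # Assumption: inverting means "below the cutoffs" marks NOT expressed.
    return ~below if invert_selection else below


shape = [1.5, 6.0, 3.0, 2.0]   # toy posterior means of the Gamma shape per factor
rate = [10.0, 5.0, 20.0, 12.0]  # toy posterior means of the Gamma rate per factor
mask = expressed_factor_mask(shape, rate)
inverted = expressed_factor_mask(shape, rate, invert_selection=True)
```

A mask like this has the form expected by set_fact_filt, which records which factors to retain.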