class, n_hidden=256, n_latent=10, n_layers=1, dropout_rate=0.1, gamma_init_data=False, linear_decoder=False, **model_kwargs)

Velocity Variational Inference.

See [Gayoso et al., 2023] for details.
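For orientation, VeloVI builds on the transcriptional kinetic model used by RNA velocity methods such as scVelo. The state-dependent transcription rates and the variational treatment are described in [Gayoso et al., 2023].

In this model, for each gene the unspliced abundance u and spliced abundance s evolve under three rates:

- transcription rate α
- splicing rate β
- degradation rate γ

The velocity is ds/dt. Below is a minimal statement of the model and its closed-form solution, assuming a constant α and β ≠ γ:

```latex
\frac{du}{dt} = \alpha - \beta u, \qquad \frac{ds}{dt} = \beta u - \gamma s

u(t) = u_0 e^{-\beta t} + \frac{\alpha}{\beta}\left(1 - e^{-\beta t}\right)

s(t) = s_0 e^{-\gamma t} + \frac{\alpha}{\gamma}\left(1 - e^{-\gamma t}\right)
     + \frac{\alpha - \beta u_0}{\gamma - \beta}\left(e^{-\gamma t} - e^{-\beta t}\right)
```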

adata : AnnData

AnnData object that has been registered via setup_anndata().

n_hidden : int

Number of nodes per hidden layer.

n_latent : int

Dimensionality of the latent space.

n_layers : int

Number of hidden layers used for encoder and decoder NNs.

dropout_rate : float

Dropout rate for neural networks.

gamma_init_data : bool

Initialize gamma using the data-driven technique.

linear_decoder : bool

Use a linear decoder from latent space to time.


Keyword arguments passed to VELOVAE.
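The `gamma_init_data` parameter above requests a data-driven initialization of the degradation rate γ. The sketch below illustrates one common data-driven idea, a steady-state ratio:

- Take the cells with high spliced abundance for a gene. These are assumed to sit near the induction steady state, where βu ≈ γs.
- With β fixed at 1, γ is then estimated as mean unspliced over mean spliced abundance in those cells.

The function `init_gamma_from_data`, the 0.95 quantile, and the estimator itself are assumptions for illustration. They are not necessarily what VELOVI does internally.

```python
import numpy as np


def init_gamma_from_data(spliced: np.ndarray, unspliced: np.ndarray,
                         quantile: float = 0.95, eps: float = 1e-8) -> np.ndarray:
    """Hypothetical per-gene gamma estimate from cells-by-genes matrices.

    Uses cells whose spliced abundance is at or above the per-gene quantile
    (a stand-in for the induction steady state, where beta*u ~= gamma*s) and
    returns mean(u) / mean(s) over those cells, i.e. gamma with beta fixed at 1.
    """
    # Per-gene spliced threshold, shape (n_genes,)
    thresh = np.quantile(spliced, quantile, axis=0)
    # Boolean mask of "high spliced" cells, shape (n_cells, n_genes)
    mask = spliced >= thresh
    # At least one cell per gene passes, since the column maximum is >= the quantile
    n = mask.sum(axis=0)
    u_mean = (unspliced * mask).sum(axis=0) / n
    s_mean = (spliced * mask).sum(axis=0) / n
    # eps guards against division by zero for all-zero genes
    return u_mean / (s_mean + eps)
```

If every cell satisfies u = 0.5·s exactly, this estimate returns γ ≈ 0.5 for each gene.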



Data attached to model instance.


Manager instance associated with self.adata.


The current device that the module’s params are on.


Returns computed metrics during training.


Whether the model has been trained.


Observations that are in test set.


Observations that are in train set.


Observations that are in validation set.
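The last three attributes above identify the observations assigned to the test, training, and validation sets. The sketch below shows one simple way such a partition can be produced: a random permutation of cell indices split by fractions.

The function `split_indices` and its default fractions (0.8 training, 0.1 validation, remainder test) are hypothetical. They do not describe how VELOVI or its training routine chooses the split.

```python
import numpy as np


def split_indices(n_obs: int, train_size: float = 0.8,
                  validation_size: float = 0.1, seed: int = 0):
    """Hypothetical random partition of observation indices.

    Returns (train, validation, test) index arrays that are disjoint and
    together cover range(n_obs); the test set receives whatever remains.
    """
    if train_size < 0 or validation_size < 0 or train_size + validation_size > 1 + 1e-9:
        raise ValueError("sizes must be non-negative and sum to at most 1")
    rng = np.random.default_rng(seed)
    perm = rng.permutation(n_obs)  # shuffled cell indices
    n_train = int(round(train_size * n_obs))
    n_val = min(int(round(validation_size * n_obs)), n_obs - n_train)
    train = np.sort(perm[:n_train])
    validation = np.sort(perm[n_train:n_train + n_val])
    test = np.sort(perm[n_train + n_val:])
    return train, validation, test
```

For 100 observations with the defaults, this yields 80 training, 10 validation, and 10 test indices, with no index appearing in more than one set.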


convert_legacy_save(dir_path, output_dir_path)

Converts a legacy saved model (<v0.15.0) to the updated save format.


Deregisters the AnnDataManager instance associated with adata.

get_anndata_manager(adata[, required])

Retrieves the AnnDataManager for a given AnnData object specific to this model instance.

get_directional_uncertainty([adata, …])


get_elbo([adata, indices, batch_size])

Return the ELBO for the data.

get_expression_fit([adata, indices, …])

Returns the fitted spliced and unspliced abundances (s(t) and u(t)).

get_from_registry(adata, registry_key)

Returns the object in AnnData associated with the key in the data registry.

get_gene_likelihood([adata, indices, …])

Returns the likelihood per gene.

get_latent_representation([adata, indices, …])

Return the latent representation for each cell.

get_latent_time([adata, indices, gene_list, …])

Returns the cells-by-genes latent time.

get_marginal_ll([adata, indices, …])

Return the marginal log-likelihood for the data.

get_permutation_scores(labels_key[, adata])

Compute permutation scores.



get_reconstruction_error([adata, indices, …])

Return the reconstruction error for the data.

get_state_assignment([adata, indices, …])

Returns cells-by-genes-by-states probabilities.

get_velocity([adata, indices, gene_list, …])

Returns cells-by-genes velocity estimates.

load(dir_path[, adata, use_gpu, …])

Instantiate a model from the saved output.

load_registry(dir_path[, prefix])

Return the full registry saved with the model.


Registers an AnnDataManager instance with this model class.

save(dir_path[, prefix, overwrite, save_anndata])

Save the state of the model.

setup_anndata(adata, spliced_layer, …)

Sets up the AnnData object for this model.


Move model to device.

train([max_epochs, lr, weight_decay, …])

Train the model.

view_anndata_setup([adata, …])

Print summary of the setup for the initial AnnData or a given AnnData object.

view_setup_args(dir_path[, prefix])

Print args used to setup a saved model.
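Several methods above (`get_expression_fit`, `get_velocity`) report per-cell, per-gene quantities derived from the transcriptional kinetic model. The standalone NumPy sketch below does not call the VELOVI API. It evaluates the closed-form solution u(t), s(t) of du/dt = α − βu, ds/dt = βu − γs for given rates and times, then computes the resulting velocity βu − γs.

Several choices here are simplifications for illustration:

- the helper names `kinetics` and `velocity`
- the single-phase (constant α) assumption
- the requirement β ≠ γ

VELOVI itself infers the rates, latent times, and state assignments from data.

```python
import numpy as np


def kinetics(t, alpha, beta, gamma, u0=0.0, s0=0.0):
    """Closed-form u(t), s(t) for du/dt = alpha - beta*u, ds/dt = beta*u - gamma*s.

    Assumes constant alpha and beta != gamma. Inputs broadcast: t of shape
    (n_cells, 1) with per-gene rates of shape (n_genes,) gives cells-by-genes arrays.
    """
    eb = np.exp(-beta * t)
    eg = np.exp(-gamma * t)
    u = u0 * eb + alpha / beta * (1.0 - eb)
    s = (s0 * eg + alpha / gamma * (1.0 - eg)
         + (alpha - beta * u0) / (gamma - beta) * (eg - eb))
    return u, s


def velocity(u, s, beta, gamma):
    """Spliced RNA velocity, ds/dt = beta*u - gamma*s."""
    return beta * u - gamma * s
```

For example, suppose `t = np.linspace(0, 10, 50)[:, None]` and `alpha`, `beta`, `gamma` are arrays of shape `(n_genes,)`. Then `u`, `s`, and the velocity all have shape `(50, n_genes)`, matching the cells-by-genes layout described for `get_velocity`. As t grows, u approaches α/β, s approaches α/γ, and the velocity approaches zero.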