Reference implementation
Code
The PyTorch implementation is released through the GAHIB repository with a Scanpy-style workflow for single-cell AnnData inputs.
Install
Editable package
pip install -e .
# optional extras
pip install -e ".[all]"
pip install -e ".[dev]"
Requires Python 3.9 or later. The full dependency list is maintained in the repository package metadata.
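A quick preflight check can confirm the interpreter version and that the package is importable after installation. The following is a minimal sketch, not part of the GAHIB repository: the helper name check_environment is hypothetical, and the import name gahib is assumed from the quick-start example.

```python
import importlib.util
import sys


def check_environment(package="gahib", min_version=(3, 9), version_info=None):
    """Return a list of human-readable problems; an empty list means ready.

    `check_environment` is a hypothetical helper, not part of GAHIB.
    `version_info` can be injected for testing; it defaults to the
    running interpreter's version.
    """
    major_minor = tuple(version_info or sys.version_info)[:2]
    problems = []

    # The install notes state Python 3.9 or later.
    if major_minor < tuple(min_version):
        problems.append(
            "Python %d.%d found, but %d.%d or later is required"
            % (major_minor + tuple(min_version))
        )

    # find_spec returns None when a top-level package cannot be located.
    if importlib.util.find_spec(package) is None:
        problems.append("package '%s' is not importable" % package)

    return problems


# Report any problems for the current environment.
for problem in check_environment():
    print(problem)
```

No output means both checks passed.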
Outputs
Model accessors
model.get_latent() returns the learned embedding.
model.get_time() returns the trajectory time estimate when trajectory modules are enabled.
model.get_centroid() returns per-cluster latent centroids.
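In a Scanpy-style workflow, these outputs would typically be written back onto the AnnData object. The sketch below shows one way to do so. It assumes that get_latent() returns one row per cell and get_time() one value per cell. The helper store_outputs and the keys X_gahib and gahib_centroids are hypothetical names chosen for illustration, not part of the GAHIB API.

```python
def store_outputs(adata, model, latent_key="X_gahib", time_key=None,
                  centroid_key="gahib_centroids"):
    """Copy GAHIB accessor outputs into AnnData-style slots.

    Hypothetical helper: the key names are illustrative, and the
    per-cell shapes of the accessor outputs are assumptions.
    """
    # Per-cell embedding goes in .obsm, next to X_pca or X_umap.
    adata.obsm[latent_key] = model.get_latent()

    # Only request the time estimate when the caller opts in, since it
    # is available only when trajectory modules are enabled.
    if time_key is not None:
        adata.obs[time_key] = model.get_time()

    # Per-cluster centroids are not per-cell, so they go in .uns.
    adata.uns[centroid_key] = model.get_centroid()
    return adata
```

Once the embedding is stored, standard Scanpy calls such as sc.pp.neighbors(adata, use_rep="X_gahib") followed by sc.tl.umap(adata) can build a neighbor graph and a 2D layout from it.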
Quick start
Fit GAHIB on an AnnData object
The graph encoder is selected explicitly; set encoder_type to mlp for a plain VAE without graph message passing.
import scanpy as sc
from gahib import GAHIB

# Load the single-cell dataset as an AnnData object.
adata = sc.read_h5ad("data/human_cd34_bone_marrow.h5ad")

model = GAHIB(
    adata,
    layer="counts",
    encoder_type="graph",  # "mlp" selects a plain VAE without graph message passing
    graph_type="GAT",
    latent_dim=10,
    i_dim=2,
    irecon=1.0,
    lorentz=5.0,
)
model.fit(epochs=200, patience=30)
Z = model.get_latent()  # learned embedding

Notebooks
A Colab demo notebook with a frozen checkpoint will be released alongside the paper. Subscribe to the GitHub repository for the announcement.
License
The reference implementation is distributed under the MIT License.