Preview notice. This site includes method notes, datasets, metrics, and code; results and weights are not included.

Reference implementation

Code

The PyTorch implementation is released through the GAHIB repository with a Scanpy-style workflow for single-cell AnnData inputs.

Install

Editable package

# run from the root of the cloned repository
pip install -e .
# optional extras
pip install -e ".[all]"
pip install -e ".[dev]"

Requires Python 3.9 or later. The full dependency list is maintained in the repository's package metadata.

Outputs

Model accessors

  • model.get_latent() returns the learned embedding.
  • model.get_time() returns the trajectory time estimate when trajectory modules are enabled.
  • model.get_centroid() returns per-cluster latent centroids.
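To make "per-cluster latent centroids" concrete, the following is a minimal NumPy sketch that averages embedding rows within each cluster. It is an illustration only: the helper name cluster_centroids, the toy arrays, and the assumption that a centroid is the plain mean of a cluster's latent vectors are not taken from the GAHIB code, which may compute centroids differently.

```python
import numpy as np


def cluster_centroids(latent, labels):
    """Return {cluster label: mean latent vector} (hypothetical helper)."""
    latent = np.asarray(latent, dtype=float)
    labels = np.asarray(labels)
    if latent.ndim != 2 or latent.shape[0] != labels.shape[0]:
        raise ValueError("latent must be 2-D with one row per label")
    # One mean vector per distinct label, keyed by a plain Python value.
    return {
        label: latent[labels == label].mean(axis=0)
        for label in np.unique(labels).tolist()
    }


# Toy stand-in for an embedding: 4 cells, 2 latent dimensions, 2 clusters.
Z_toy = np.array([[0.0, 0.0], [2.0, 2.0], [10.0, 0.0], [12.0, 4.0]])
clusters = np.array([0, 0, 1, 1])

centroids = cluster_centroids(Z_toy, clusters)
for label, center in centroids.items():
    print(label, center)
```

With a fitted model, the array returned by model.get_latent() and a vector of cluster assignments would take the place of the toy inputs, assuming the embedding has one row per cell.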

Quick start

Fit GAHIB on an AnnData object

The graph encoder is selected explicitly; set encoder_type="mlp" for a plain VAE without graph message passing.

import scanpy as sc
from gahib import GAHIB

# Load the single-cell dataset as an AnnData object.
adata = sc.read_h5ad("data/human_cd34_bone_marrow.h5ad")

model = GAHIB(
    adata,
    layer="counts",
    encoder_type="graph",  # "mlp" selects a plain VAE without message passing
    graph_type="GAT",
    latent_dim=10,
    i_dim=2,
    irecon=1.0,
    lorentz=5.0,
)

# Train, then retrieve the learned embedding.
model.fit(epochs=200, patience=30)
Z = model.get_latent()
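A common next step in a Scanpy-style workflow is to keep the embedding on the AnnData object so downstream tools can find it. The sketch below shows one way to do that, under assumptions: that get_latent() returns a 2-D array with one row per cell, and that "X_gahib" is an acceptable key (the key and the helper name store_latent are made up for illustration, not part of GAHIB). A small stand-in object replaces AnnData so the snippet runs without a dataset.

```python
from types import SimpleNamespace

import numpy as np


def store_latent(adata, latent, key="X_gahib"):
    """Validate an embedding and store it in adata.obsm[key] (hypothetical helper)."""
    latent = np.asarray(latent)
    # The embedding needs exactly one row per observation (cell).
    if latent.ndim != 2 or latent.shape[0] != adata.n_obs:
        raise ValueError(f"expected shape ({adata.n_obs}, d), got {latent.shape}")
    adata.obsm[key] = latent
    return adata


# Stand-in exposing only the two AnnData attributes used above (n_obs, obsm);
# pass a real AnnData object and the model's embedding in practice.
demo = SimpleNamespace(n_obs=5, obsm={})
store_latent(demo, np.zeros((5, 10)))
print(demo.obsm["X_gahib"].shape)  # (5, 10)
```

On a real AnnData object, sc.pp.neighbors(adata, use_rep="X_gahib") followed by sc.tl.umap(adata) would then build the neighbor graph and UMAP from the stored embedding rather than from PCA.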

Notebooks

A Colab demo notebook with a frozen checkpoint will be released alongside the paper. Watch the GitHub repository for the announcement.

License

The reference implementation is distributed under the MIT License.
