# Callbacks
manyLatents has two callback systems: embedding callbacks for post-embedding processing, and trainer callbacks (Lightning) for training-time hooks.
## Callback Hierarchy

```text
BaseCallback (ABC)
├── on_experiment_start(cfg)
├── on_experiment_end()
├── on_latent_end(dataset, embeddings)
├── on_training_start()
└── on_training_end()

EmbeddingCallback(BaseCallback, ABC)
├── on_latent_end(dataset, embeddings)   ← abstract, must implement
├── register_output(key, output)         ← store results for downstream
└── callback_outputs: dict               ← accumulated outputs

lightning.Callback                       ← PyTorch Lightning's callback
├── on_fit_start()
├── on_train_batch_end()
├── on_train_epoch_end()
├── on_validation_end()
└── ...
```
`EmbeddingCallback` runs after embeddings are computed (for both `LatentModule` and `LightningModule`). Lightning `Callback` runs during training (`LightningModule` only).
## Instantiation
Callbacks are instantiated from two config groups and routed by type:
```python
def instantiate_callbacks(trainer_cb_cfg, embedding_cb_cfg):
    lightning_cbs, embedding_cbs = [], []
    for name, cfg in trainer_cb_cfg.items():
        cb = hydra.utils.instantiate(cfg)
        if isinstance(cb, Callback):
            lightning_cbs.append(cb)
    for name, cfg in embedding_cb_cfg.items():
        cb = hydra.utils.instantiate(cfg)
        if isinstance(cb, EmbeddingCallback):
            embedding_cbs.append(cb)
    return lightning_cbs, embedding_cbs
```
## Config Structure
```yaml
# configs/callbacks/default.yaml
defaults:
  - trainer: null    # Lightning callbacks (probing, etc.)
  - embedding: null  # Embedding callbacks (save, plot, wandb)
  - _self_
```

```yaml
# configs/callbacks/embedding/default.yaml
defaults:
  - save_embeddings
  - plot_embeddings
  - wandb_log_scores
  - _self_
```
## Execution Flow
In `run_algorithm()`:

1. Callbacks are instantiated from config
2. Lightning callbacks are passed to `Trainer(callbacks=[...])`
3. The algorithm executes (`fit`/`transform` or `trainer.fit`)
4. Embeddings are wrapped as a `LatentOutputs` dict
5. Metrics are evaluated
6. Each embedding callback's `on_latent_end()` is called with the dataset and embeddings
7. Callback outputs are merged into the embeddings dict:
```python
for cb in embedding_cbs:
    cb_result = cb.on_latent_end(dataset=datamodule.test_dataset, embeddings=embeddings)
    if isinstance(cb_result, dict):
        callback_outputs.update(cb_result)
```
## LatentOutputs
The standard interchange format passed to all embedding callbacks:
```python
LatentOutputs = dict[str, Any]
# Required: "embeddings" (np.ndarray) — shape (n, d)
# Optional: "label", "metadata", "scores", "callback_outputs"
# Custom keys: "trajectories", cluster assignments, velocity fields, etc.
```
`validate_latent_outputs()` checks that the required `"embeddings"` key exists.
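The check can be sketched as follows. This is a hypothetical re-implementation (the real function lives in manylatents); the shape assertion is an extra assumption beyond the documented required-key check:

```python
import numpy as np

def validate_latent_outputs(outputs: dict) -> None:
    # Documented behaviour: the "embeddings" key must be present.
    if "embeddings" not in outputs:
        raise KeyError("LatentOutputs must contain an 'embeddings' key")
    # Extra assumption for this sketch: embeddings should be a 2-D (n, d) array.
    emb = np.asarray(outputs["embeddings"])
    if emb.ndim != 2:
        raise ValueError(f"expected embeddings of shape (n, d), got {emb.shape}")

validate_latent_outputs({"embeddings": np.zeros((10, 2))})  # passes silently
```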
### Output Types
Algorithms populate different keys depending on what they produce:
| Key | Shape | Description |
|---|---|---|
| `"embeddings"` | `(n, d)` | Point cloud in latent space (default, chainable in pipelines) |
| `"trajectories"` | `(n_bins, n_traj, d)` | Flow paths from trajectory inference methods |
| `"label"` | `(n,)` | Ground truth labels |
| `"scores"` | `dict` | Evaluation metrics |
| `"metadata"` | `dict` | Algorithm info and runtime metadata |
The "embeddings" key is the standard primary output used by metrics, plotting, and pipeline chaining. Trajectory methods populate both "embeddings" (e.g., endpoint positions) and "trajectories" for the full flow data.
`SaveEmbeddings` automatically persists any additional keys as separate `.npy` or `.json` files when `save_additional_outputs: true`.
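As a concrete illustration, a `LatentOutputs` dict from a trajectory method might look like the following. The key names follow the table above; all values (including the algorithm name and score) are dummies:

```python
import numpy as np

# Illustrative LatentOutputs dict with dummy values.
outputs = {
    "embeddings": np.random.rand(100, 2),        # required primary output, (n, d)
    "trajectories": np.random.rand(30, 100, 2),  # (n_bins, n_traj, d) flow paths
    "label": np.arange(100) % 3,                 # (n,) ground-truth labels
    "scores": {"trustworthiness": 0.93},         # evaluation metrics
    "metadata": {"algorithm": "dummy"},          # algorithm/runtime info
}
```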
## SaveEmbeddings
Saves embeddings to disk in CSV or NPY format. Optionally saves metric tables (scalar summary and per-sample).
```yaml
# configs/callbacks/embedding/save_embeddings.yaml
save_embeddings:
  _target_: manylatents.callbacks.embedding.save_embeddings.SaveEmbeddings
  save_dir: ${hydra:runtime.output_dir}
  save_format: "csv"
  experiment_name: ${name}
  save_metric_tables: false
```
| Parameter | Default | Description |
|---|---|---|
| `save_dir` | Hydra output dir | Base directory for saved files |
| `save_format` | `"csv"` | Format: `"csv"` or `"npy"` |
| `save_metric_tables` | `false` | Save separate scalar and per-sample metric CSVs |
| `save_additional_outputs` | `false` | Save non-embedding keys as separate files |
When running under Geomancer orchestration, `SaveEmbeddings` also writes to the shared metrics directory via `atomic_writer`.
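The atomic-write pattern behind that helper can be sketched like this (a hypothetical stand-in, not manylatents' actual `atomic_writer`): write to a temporary file in the target directory, then `os.replace` it into place, so concurrent readers never observe a partially written file.

```python
import json
import os
import tempfile

def atomic_write_json(path: str, payload: dict) -> None:
    """Hypothetical sketch of atomic JSON writing via write-then-rename."""
    directory = os.path.dirname(path) or "."
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".tmp")
    with os.fdopen(fd, "w") as f:
        json.dump(payload, f)
    # os.replace is atomic on POSIX when source and target share a filesystem.
    os.replace(tmp_path, path)
```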
## PlotEmbeddings
Creates 2D scatter plots of embeddings with customizable colormaps and optional WandB upload.
```yaml
# configs/callbacks/embedding/plot_embeddings.yaml
plot_embeddings:
  _target_: manylatents.callbacks.embedding.plot_embeddings.PlotEmbeddings
  save_dir: ${hydra:runtime.output_dir}
  experiment_name: "${name}.png"
  figsize: [8, 6]
  label_col: Population
  legend: false
  color_by_score: null
```
### Colormap Resolution
`PlotEmbeddings` resolves colormaps from multiple sources (highest priority first):

1. User overrides — `cmap_override`, `is_categorical_override` in config
2. Metric-declared — `scores["<metric>__viz"]` containing a `ColormapInfo`
3. Dataset-provided — via the `ColormapProvider` protocol
4. Defaults — `"viridis"`
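The priority chain amounts to a first-non-None lookup. A minimal sketch (the function name and argument shapes are assumptions for illustration):

```python
def resolve_colormap(cmap_override=None, metric_viz=None, dataset_info=None):
    """Hypothetical sketch: user override beats metric-declared viz info,
    which beats dataset-provided info, which beats the default."""
    for candidate in (cmap_override, metric_viz, dataset_info):
        if candidate is not None:
            return candidate
    return "viridis"
```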
Datasets can implement `ColormapProvider` to declare their preferred visualization:
```python
class MyDataset(ColormapProvider):
    def get_colormap_info(self) -> ColormapInfo:
        return ColormapInfo(
            cmap={"A": "#ff0000", "B": "#00ff00"},
            label_names={0: "Class A", 1: "Class B"},
            is_categorical=True,
        )
```
### Coloring by Score
Color points by a metric value instead of labels:
```yaml
plot_embeddings:
  color_by_score: "embedding.local_intrinsic_dimensionality"
  legend: false  # Uses a colorbar instead
```
## WandbLogScores
Logs metric scores to WandB in three formats:
```yaml
# configs/callbacks/embedding/wandb_log_scores.yaml
wandb_log_scores:
  _target_: manylatents.callbacks.embedding.wandb_log_scores.WandbLogScores
```
| Log Type | WandB Key | Content |
|---|---|---|
| Summary scalars | `{tag}/metric_name` | 0-D metrics via `wandb.log()` |
| Per-sample table | `{tag}/per_sample_metrics` | 1-D arrays as a `wandb.Table` |
| k-curve tables | `{tag}/metric__k_curve_table` | Swept `n_neighbors` values grouped into tables |
k-curve tables automatically detect metrics swept over `n_neighbors` (e.g., `trustworthiness__n_neighbors_5`, `_10`, `_20`) and group them into a single curve.
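That grouping can be sketched with a regex over score keys. This is a hypothetical re-implementation; the actual detection logic lives in `WandbLogScores`:

```python
import re
from collections import defaultdict

def group_k_curves(scores: dict) -> dict:
    """Group keys like "trustworthiness__n_neighbors_5" into
    {"trustworthiness": {5: value, 10: value, ...}}."""
    curves = defaultdict(dict)
    pattern = re.compile(r"^(.*)__n_neighbors_(\d+)$")
    for key, value in scores.items():
        match = pattern.match(key)
        if match:  # keys without the sweep suffix are left out of the curves
            curves[match.group(1)][int(match.group(2))] = value
    return dict(curves)
```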
## LoadingsAnalysisCallback
Analyzes shared vs modality-specific components in multi-modal loadings (e.g., DNA + RNA + Protein fusion).
```yaml
callbacks:
  embedding:
    loadings:
      _target_: manylatents.callbacks.embedding.loadings_analysis.LoadingsAnalysisCallback
      modality_dims: [1920, 256, 1536]
      modality_names: [dna, rna, protein]
      threshold: 0.1
```
Requires the algorithm module to have a `get_loadings()` method (e.g., `MergingModule` with `concat_pca`).
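Given `modality_dims`, the stacked loadings matrix can be sliced back into per-modality blocks. A minimal sketch of that first step (the actual analysis in `LoadingsAnalysisCallback` does more, e.g. thresholding shared components):

```python
import numpy as np

def split_loadings(loadings: np.ndarray, modality_dims: list) -> list:
    """Slice a (sum(dims), k) loadings matrix into one block per modality."""
    split_points = np.cumsum(modality_dims)[:-1]
    return np.split(loadings, split_points, axis=0)

# With the dims from the config above: DNA, RNA, and protein blocks.
blocks = split_loadings(np.zeros((1920 + 256 + 1536, 10)), [1920, 256, 1536])
```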
## Lightning Callbacks
Trainer callbacks extend `lightning.Callback` and run during the training loop. They are passed to `Trainer(callbacks=[...])`.
### ActivationTrajectoryCallback
The primary trainer callback. Extracts activations from model layers at configurable triggers and computes diffusion operators to track representation geometry.
```yaml
# configs/callbacks/trainer/probe.yaml
probe:
  _target_: manylatents.lightning.callbacks.activation_tracker.ActivationTrajectoryCallback
  layer_specs:
    - _target_: manylatents.lightning.hooks.LayerSpec
      path: "transformer.h[-1]"
      extraction_point: "output"
      reduce: "mean"
  trigger:
    _target_: manylatents.lightning.callbacks.activation_tracker.ProbeTrigger
    every_n_steps: 500
    on_checkpoint: true
    on_validation_end: true
  gauge:
    _target_: manylatents.callbacks.diffusion_operator.DiffusionGauge
    knn: 15
    alpha: 1.0
    symmetric: false
  log_to_wandb: true
```
See Probing for full documentation of layer specs, triggers, gauge configuration, and programmatic access.
## Adding a Trainer Callback
1. Create a class extending `lightning.Callback`
2. Implement the relevant hooks (`on_train_batch_end`, `on_train_epoch_end`, etc.)
3. Create a config in `configs/callbacks/trainer/your_callback.yaml`
4. Add it to your experiment config:
```yaml
callbacks:
  trainer:
    your_callback:
      _target_: manylatents.your_module.YourCallback
      param: value
```
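Putting steps 1 and 2 together, a minimal trainer callback might look like this. The class name, hook body, and `every_n_steps` knob are illustrative; only the `lightning.Callback` base class and the hook signature come from Lightning. A lightweight stand-in base is used so the sketch runs without Lightning installed:

```python
class Callback:  # minimal stand-in for lightning.Callback in this sketch
    pass

class BatchCounterCallback(Callback):
    """Hypothetical callback: counts training batches and reports periodically."""

    def __init__(self, every_n_steps: int = 100):
        self.every_n_steps = every_n_steps
        self.batches_seen = 0

    # Same signature as lightning.Callback.on_train_batch_end.
    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        self.batches_seen += 1
        if self.batches_seen % self.every_n_steps == 0:
            print(f"processed {self.batches_seen} batches")
```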
## Adding an Embedding Callback
1. Create a class extending `EmbeddingCallback`
2. Implement `on_latent_end(self, dataset, embeddings) -> Any`
3. Use `self.register_output(key, value)` to store results
4. Create a config in `configs/callbacks/embedding/your_callback.yaml`
5. Add it to the embedding defaults or your experiment config
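The steps above can be sketched end to end. The `EmbeddingNorm` class is illustrative, and a minimal stand-in base is defined so the sketch runs standalone; in the real project you would subclass manylatents' `EmbeddingCallback` directly:

```python
import numpy as np

class EmbeddingCallback:  # minimal stand-in for manylatents' base class
    def __init__(self):
        self.callback_outputs = {}

    def register_output(self, key, value):
        self.callback_outputs[key] = value

class EmbeddingNorm(EmbeddingCallback):
    """Hypothetical callback: records the L2 norm of each embedded point."""

    def on_latent_end(self, dataset, embeddings):
        norms = np.linalg.norm(embeddings["embeddings"], axis=1)
        self.register_output("embedding_norms", norms)
        return {"embedding_norms": norms}  # returned dicts get merged downstream
```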