GPLVM Analysis
SpikeLab can fit a Gaussian Process Latent Variable Model (GPLVM) to spike train data, uncovering discrete latent states that describe the network’s activity over time (Zheng et al. 2025). This is useful for identifying recurring population dynamics such as UP/DOWN states or multi-state transitions.
Zheng, Z., Zutshi, I., Huszar, R. et al. From labels to latents: revealing state-dependent hippocampal computations with Jump Latent Variable Model. bioRxiv (2025).
Note: GPLVM fitting requires the optional `poor_man_gplvm` and `jax` packages. Install them separately before using these features.
Fitting a GPLVM
`fit_gplvm()` bins the spike trains and fits a GPLVM via expectation-maximisation:
```python
result = sd.fit_gplvm(
    bin_size_ms=50.0,         # temporal bin size for spike counts
    n_latent_bin=100,         # number of discrete latent states
    n_iter=20,                # EM iterations
    movement_variance=1.0,    # state transition variance
    tuning_lengthscale=10.0,  # GP tuning curve lengthscale
    random_seed=3,
)
```
The returned dict contains:
- `decode_res` — the full decode result from the GPLVM model, including posterior marginals over latent states and dynamics.
- `reorder_indices` — unit ordering derived from the model’s tuning curves.
- `model` — the fitted model object.
- `binned_spike_counts` — the `(T, N)` binned spike matrix used for fitting.
- `bin_size_ms` — the bin size used.
- `log_marginal_l` — log marginal likelihood per EM iteration.
All arrays are returned as NumPy ndarrays (not JAX arrays).
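Because each EM iteration should not decrease the log marginal likelihood, the `log_marginal_l` trace gives a quick convergence check. A minimal sketch with a hypothetical trace (in practice, substitute `result["log_marginal_l"]` from an actual fit):

```python
import numpy as np

# Hypothetical log-marginal-likelihood trace, one value per EM iteration;
# replace with result["log_marginal_l"] from an actual fit.
log_ml = np.array([-1523.4, -1320.1, -1288.7, -1280.2, -1279.9])

increments = np.diff(log_ml)
# EM should never decrease the objective (up to numerical noise)...
print("monotone:", bool(np.all(increments >= -1e-6)))
# ...and a final increment that is tiny relative to the first suggests
# n_iter was large enough.
print("converged:", bool(increments[-1] < 0.01 * increments[0]))
```

If the trace is still climbing steeply at the last iteration, increase `n_iter` and refit.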
Post-fit Analysis
SpikeLab provides utility functions that extract summary statistics from the GPLVM decode result.
State entropy
Compute the Shannon entropy of the latent state posterior at each time bin. Higher entropy indicates greater uncertainty about the current state:
```python
from spikelab.spikedata.utils import gplvm_state_entropy

entropy = gplvm_state_entropy(
    result["decode_res"].posterior_latent_marg,
)
# entropy.shape == (T,)
```
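For reference, this is the standard Shannon entropy of each row of the posterior. A minimal numpy sketch, assuming `posterior_latent_marg` is a `(T, K)` array whose rows sum to 1 (an illustration of the formula, not SpikeLab's implementation):

```python
import numpy as np

def shannon_entropy(posterior, eps=1e-12):
    # H_t = -sum_k p[t, k] * log(p[t, k]); eps guards against log(0).
    p = np.asarray(posterior)
    return -np.sum(p * np.log(p + eps), axis=-1)

# A uniform posterior over K states has maximal entropy log(K).
uniform = np.full((1, 4), 0.25)
print(shannon_entropy(uniform))  # approximately log(4) ~ 1.386
```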
Continuity probability
Extract the probability that the network remains in the same state from one time bin to the next (i.e. no state transition):
```python
from spikelab.spikedata.utils import gplvm_continuity_prob

cont_prob = gplvm_continuity_prob(result["decode_res"])
# cont_prob.shape == (T,)
```
Average state probability
Compute the mean probability of each latent state across all time bins. This reveals which states dominate the recording:
```python
from spikelab.spikedata.utils import gplvm_average_state_probability

avg_prob = gplvm_average_state_probability(
    result["decode_res"].posterior_latent_marg,
)
# avg_prob.shape == (K,), where K = n_latent_bin
```
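The quantity above is a mean over the time axis of the posterior matrix. A minimal numpy sketch with toy `(T, K)` values (illustrative, not SpikeLab's code):

```python
import numpy as np

# Toy (T, K) posterior marginals: 3 time bins, 3 latent states.
posterior = np.array([
    [0.9, 0.1, 0.0],
    [0.2, 0.7, 0.1],
    [0.1, 0.1, 0.8],
])

# Average probability of each state across time bins.
avg = posterior.mean(axis=0)
print(avg)        # one value per latent state
print(avg.sum())  # still sums to 1 (up to float error)
```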
Consecutive state durations
Measure how long the network stays in a given condition. For example, compute the durations of high-continuity epochs:
```python
from spikelab.spikedata.utils import consecutive_durations

durations = consecutive_durations(
    cont_prob,
    threshold=0.8,
    mode="above",  # runs where cont_prob >= 0.8
    min_dur=1,     # minimum run length to include
)
# durations is a 1-D array of run lengths (in time bins)
```
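The underlying idea is a run-length computation on a boolean mask. A minimal numpy sketch of that idea (an illustration of the concept, not SpikeLab's implementation):

```python
import numpy as np

def run_lengths(mask):
    # Lengths of maximal runs of True values in a boolean mask.
    mask = np.asarray(mask, dtype=bool)
    # Pad with False so runs touching the edges become transitions,
    # then diff to locate run starts (+1) and ends (-1).
    padded = np.concatenate(([False], mask, [False]))
    edges = np.flatnonzero(np.diff(padded.astype(int)))
    starts, ends = edges[::2], edges[1::2]
    return ends - starts

cont_prob = np.array([0.9, 0.95, 0.3, 0.85, 0.9, 0.9, 0.1])
print(run_lengths(cont_prob >= 0.8))  # [2 3]
```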
Visualisation
The GPLVM results integrate with SpikeLab’s plotting utilities. Use `plot()` with `show_model_states=True` to overlay the decoded states on a raster plot:
```python
fig = sd.plot(
    show_raster=True,
    show_pop_rate=True,
    show_model_states=True,
    gplvm_result=result,
)
```
For dimensionality reduction on the latent posteriors, use `PCA_reduction()` or `UMAP_reduction()` on the posterior marginal matrix, and visualise with `plot_manifold()`.
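If you prefer to stay in plain numpy, PCA of the posterior-marginal matrix is a few lines of SVD. A hedged sketch with synthetic `(T, K)` marginals (`PCA_reduction()` is the documented route; this only illustrates the projection):

```python
import numpy as np

# Synthetic (T, K) posterior marginals: 500 bins, 20 latent states.
rng = np.random.default_rng(0)
posterior = rng.dirichlet(np.ones(20), size=500)

# PCA via SVD of the mean-centred matrix.
centred = posterior - posterior.mean(axis=0)
U, S, Vt = np.linalg.svd(centred, full_matrices=False)
embedding = centred @ Vt[:2].T  # project onto the first two PCs
print(embedding.shape)  # prints (500, 2)
```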