04b: Neural SBI for Mutation-Selection Inference

Why Neural Methods?

In notebooks 02 and 03, we used ABC-SMC and Bayesian Synthetic Likelihood to infer the parameters of the Basener-Sanford mutation-selection model. Both methods work by running new simulations during inference — every time they evaluate a candidate parameter combination, they must simulate an entire population forward in time. ABC-SMC required hundreds of thousands of simulations; BSL required thousands per likelihood evaluation. Worse, this computational cost must be paid again from scratch for every new dataset.

This notebook takes a fundamentally different approach: Neural Posterior Estimation (NPE) and Flow-Matching Posterior Estimation (FMPE). These methods train a neural network once on many simulated datasets, then use the trained network to produce posterior distributions instantly for any new observation. This is called amortized inference — the heavy cost is paid once during training, and every subsequent inference is nearly free.

How Neural SBI Works — A Biological Analogy

ABC-SMC and BSL are like identifying an unknown species by catching specimens one at a time, comparing each to a field guide, and gradually narrowing down the possibilities. Every new specimen requires the same laborious process.

NPE and FMPE are like building an AI identification app by first showing it thousands of labeled photographs. After that one-time training effort, identification of any new specimen is instant.

The training process has three steps:

  1. Simulate a training set. Draw 10,000 random parameter combinations (μ, γshape, γscale, pbeneficial, σenv,ind) from the prior. For each, run the Basener-Sanford simulator and compute summary statistics of the resulting fitness trajectory. Each simulation represents a different hypothetical population — some thriving, some undergoing meltdown.
  2. Train the network. The neural network learns to recognize the "fingerprint" in the summary statistics that distinguishes, say, a high-mutation-rate population from a low-mutation-rate one, or a population with many beneficial mutations from one with few.
  3. Instant inference. Feed in the summary statistics from the actual observed fitness data. The network immediately outputs the posterior distribution — which parameter values are most consistent with what was observed.
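The three-step pipeline above can be sketched end to end. The simulator below is a hypothetical toy stand-in for the Basener-Sanford model (its dynamics, prior ranges, and summary statistics are illustrative assumptions, and it uses 1,000 draws rather than 10,000 to keep it quick):

```python
import numpy as np

rng = np.random.default_rng(0)

def toy_simulator(mu, gamma_shape, gamma_scale, p_beneficial, sigma_env, n_gen=100):
    """Hypothetical stand-in for the Basener-Sanford simulator:
    returns a mean-fitness trajectory (NOT the real model)."""
    fitness = np.empty(n_gen)
    m = 1.0
    for t in range(n_gen):
        # Net fitness change: mutation effect drawn from a gamma DFE,
        # beneficial with probability p_beneficial, plus environmental noise.
        effect = rng.gamma(gamma_shape, gamma_scale)
        sign = 1.0 if rng.random() < p_beneficial else -1.0
        m += mu * sign * effect + rng.normal(0.0, sigma_env)
        fitness[t] = m
    return fitness

def summaries(traj):
    """Summary statistics of a fitness trajectory."""
    slope = np.polyfit(np.arange(len(traj)), traj, 1)[0]
    return np.array([traj.mean(), traj.std(), slope, traj[-1] - traj[0]])

# Step 1: draw parameters from the prior and build the training set.
# (Steps 2-3 would train a network on (thetas, xs) and condition it on
# the observed summaries.)
n_train = 1000  # the notebook uses 10,000
thetas, xs = [], []
for _ in range(n_train):
    theta = np.array([
        rng.uniform(1e-4, 1e-2),   # mu
        rng.uniform(0.1, 2.0),     # gamma shape
        rng.uniform(0.01, 0.5),    # gamma scale
        rng.uniform(0.0, 0.2),     # p_beneficial
        rng.uniform(0.0, 0.05),    # sigma_env
    ])
    thetas.append(theta)
    xs.append(summaries(toy_simulator(*theta)))

thetas, xs = np.array(thetas), np.array(xs)
print(thetas.shape, xs.shape)  # (1000, 5) (1000, 4)
```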

The Amortization Advantage

Why does instant inference matter for studying mutation-selection dynamics? The extended Fisher's Fundamental Theorem tells us that a population's fate depends on whether Var(m) > μ|Eg[s]|b̅ — whether selection can overcome mutational load. To map this critical boundary across the full parameter space (notebook 06), we need to evaluate posteriors at thousands of different parameter combinations.
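As a concrete check of the criterion, here is the arithmetic with entirely hypothetical values for each quantity (none of these numbers come from the notebook):

```python
# Hypothetical illustration of the extended-FFT criterion Var(m) > mu * |E[s]| * b_bar.
var_m = 1e-4    # fitness variance in the population (hypothetical)
mu = 1e-3       # per-genome mutation rate (hypothetical)
mean_s = -0.05  # mean selection coefficient of new mutations (hypothetical)
b_bar = 1.0     # mean offspring number (hypothetical)

mutational_load_rate = mu * abs(mean_s) * b_bar  # 5e-5
selection_rate = var_m                           # 1e-4

print(selection_rate > mutational_load_rate)  # True: selection wins here
```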

This also enables "what-if" analyses: what if we observed a steeper fitness decline? A different population size? A longer time series? Each scenario gets an instant answer.

1. NPE Posterior (Neural Posterior Estimation)

NPE uses a masked autoregressive flow (MAF) — a type of normalizing flow neural network. A normalizing flow learns an invertible transformation that warps a simple probability distribution (like a multidimensional bell curve) into the complex posterior distribution over model parameters. The "masked autoregressive" architecture processes parameters in sequence, each conditioned on the previous ones; this gives the transformation a triangular Jacobian, so it remains exactly invertible and its density can be evaluated efficiently.
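To make "each conditioned on the previous ones" concrete, here is a minimal affine autoregressive transform in plain NumPy. Fixed strictly-lower-triangular linear maps stand in for the trained masked networks; all names are illustrative, and this is not sbijax code:

```python
import numpy as np

# Minimal affine autoregressive transform (the core idea behind a MAF layer).
rng = np.random.default_rng(1)
D = 3
W_mu = np.tril(rng.normal(size=(D, D)), k=-1)     # strictly lower-triangular:
W_alpha = np.tril(rng.normal(size=(D, D)), k=-1)  # dim i depends only on dims < i

def forward(x):
    """x -> z, computable in one vectorized pass."""
    mu = W_mu @ x
    alpha = W_alpha @ x
    return (x - mu) * np.exp(-alpha)

def inverse(z):
    """z -> x, recovered one dimension at a time (autoregressive)."""
    x = np.zeros_like(z)
    for i in range(D):
        mu_i = W_mu[i] @ x      # uses only the already-computed x[<i]
        alpha_i = W_alpha[i] @ x
        x[i] = z[i] * np.exp(alpha_i) + mu_i
    return x

x = rng.normal(size=D)
z = forward(x)
print(np.allclose(inverse(z), x))  # True: the transform is exactly invertible
```

The triangular dependency structure is what makes one direction a single vectorized pass while the other is recovered dimension by dimension.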

After training on 10,000 simulated populations, the MAF has learned what fitness dynamics look like across the full range of plausible mutation rates, DFE shapes, and beneficial mutation fractions. Given our observed data, it can instantly tell us which parameter combinations are most likely.

sbijax NPE posterior distributions
Figure 1: Marginal posterior distributions from NPE, trained on 10,000 simulated populations. Each panel shows the inferred distribution for one parameter of the Basener-Sanford model. Red dashed lines mark the true parameter values used to generate the synthetic observed data. The posterior was sampled in milliseconds after a one-time training cost — the same trained network could produce posteriors for any other observed fitness trajectory without retraining.

2. FMPE Posterior (Flow-Matching Posterior Estimation)

FMPE uses a different training approach called flow matching, based on optimal transport theory. Where NPE's normalizing flow learns a fixed sequence of transformations (like folding a flat sheet into origami through prescribed steps), FMPE learns a continuous flow — a smooth path that gradually transports probability mass from a simple distribution to the posterior.

Think of optimal transport as finding the most efficient way to rearrange sand from one pile configuration into another. FMPE finds the most efficient way to transform a simple bell curve into the complex posterior shape, which often leads to more stable training and smoother posterior approximations.
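Under the commonly used linear (optimal-transport-style) probability path, the flow-matching training target is simple to write down. This NumPy sketch constructs one training pair; the distributions and names are toy stand-ins, not the sbijax implementation:

```python
import numpy as np

# Sketch of the linear conditional flow-matching target. A network v(x_t, t)
# would be trained to regress this target velocity; here we build one example.
rng = np.random.default_rng(2)

x0 = rng.normal(size=5)        # sample from the simple base distribution
x1 = rng.normal(size=5) + 3.0  # sample from the target (toy "posterior")
t = rng.uniform()              # random time in [0, 1]

x_t = (1 - t) * x0 + t * x1    # point on the straight-line probability path
target_velocity = x1 - x0      # constant along the path: the displacement

# Flow-matching loss for a candidate velocity field (here: a dummy guess).
v_guess = np.zeros(5)
loss = np.mean((v_guess - target_velocity) ** 2)

# Integrating the target field for unit time carries the base sample to x1.
print(np.allclose(x0 + 1.0 * target_velocity, x1))  # True
```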

Practical advantages over NPE:

  - Training is often more stable, because flow matching optimizes a simple regression objective on the velocity field rather than an exact-likelihood flow.
  - The learned continuous flow tends to yield smoother posterior approximations.
  - The approach scales better to higher-dimensional parameter spaces.

sbijax FMPE posterior distributions
Figure 2: Marginal posterior distributions from FMPE, trained on 10,000 simulated populations. Red dashed lines show true parameter values. Compare these posteriors with the NPE results (Figure 1) above: FMPE is the recommended method for future extensions of the model, as its flow-matching training approach scales better to higher-dimensional parameter spaces.

Comparison with ABC-SMC and BSL

How do these neural methods compare with the approaches in notebooks 02 and 03?

| Property | ABC-SMC (nb 02) | BSL (nb 03) | NPE / FMPE (this nb) |
|---|---|---|---|
| Simulations during inference | ~100,000+ | ~1,000 per evaluation | 0 (pre-trained) |
| Training cost | None | None | 10,000 sims (one-time) |
| Time per new dataset | Hours | Hours | Milliseconds |
| Scalability | Poor (>5 params) | Moderate | Good (especially FMPE) |
| Assumptions | Minimal | Gaussian summary stats | Network capacity sufficient |

The formal comparison of posterior distributions across all methods is in notebook 05. The amortized methods here enable the boundary analysis in notebook 06, which maps the selection/meltdown boundary across parameter space — a computation that would be impractical with ABC-SMC or BSL.