I'll walk you through standard flow matching step by step, building from the ground up.
1. The Core Idea
Flow matching transports samples from a simple noise distribution to a complex data distribution by learning a vector field that guides the transformation.
Think of it like learning the "flow" of a river that carries particles from one location to another.
2. Mathematical Setup
Goal: Transform noise $x_0 \sim \mathcal{N}(0, I)$ into data $x_1 \sim q_{\text{data}}$
Key objects:
- Probability path $p_t(x)$: how the distribution changes over time $t \in [0, 1]$
- Vector field $u_t(x)$: the "flow direction" at each point and time
- Flow map $\phi_t(x)$: where a particle starting at $x$ ends up at time $t$
3. The Relationship Between These Objects
The vector field $u_t(x)$ defines how particles move:
$$\frac{d}{dt}\phi_t(x) = u_t(\phi_t(x)), \qquad \phi_0(x) = x$$
This flow map pushes the initial distribution $p_0$ forward to create the probability path:
$$p_t = [\phi_t]_* \, p_0$$
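To make this relationship concrete, here is a minimal sketch (my own toy example, not part of flow matching itself) that traces the flow map generated by a hand-picked vector field using Euler steps; `toy_vector_field`, `euler_flow`, and the step count are illustrative choices:

```python
import torch

def toy_vector_field(x, t):
    # A hand-picked, time-independent field: every particle drifts toward (2, 2).
    return torch.tensor([2.0, 2.0]) - x

def euler_flow(x, vector_field, n_steps=100):
    # Approximate the flow map at t = 1 by integrating d/dt phi_t(x) = u_t(phi_t(x)).
    dt = 1.0 / n_steps
    for i in range(n_steps):
        x = x + dt * vector_field(x, i * dt)
    return x

x0 = torch.randn(8, 2)                   # particles from the initial distribution p_0
x1 = euler_flow(x0, toy_vector_field)    # their positions under the flow at t = 1
```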
4. The Flow Matching Objective
Problem: We want to learn a neural network $v_t(x; \theta)$ that approximates the true vector field $u_t(x)$.
Objective:
$$\mathcal{L}_{\text{FM}}(\theta) = \mathbb{E}_{t \sim \mathcal{U}[0,1],\, x \sim p_t(x)} \big\| v_t(x; \theta) - u_t(x) \big\|^2$$
Issue: We can't evaluate $p_t(x)$ or $u_t(x)$ in closed form; they depend on the unknown data distribution, so this objective is intractable as written.
5. The Conditional Flow Matching Solution
Key insight: Build the marginal path from simpler conditional paths.
For each data point $x_1$, define a conditional probability path:
$$p_t(x \mid x_1) = \mathcal{N}\big(x \mid \mu_t(x_1),\, \sigma_t(x_1)^2 I\big)$$
With boundary conditions:
- $p_0(x \mid x_1) = \mathcal{N}(x \mid 0, I)$ (standard noise)
- $p_1(x \mid x_1) = \mathcal{N}(x \mid x_1, \sigma_{\min}^2 I)$ (concentrated around the data point)
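As a sketch of what these conditional paths mean operationally, the helper below draws a sample from $p_t(x \mid x_1)$ for an arbitrary choice of $\mu_t$ and $\sigma_t$; `sample_conditional_path`, `mu_ot`, `sigma_ot`, and the value of `sigma_min` are hypothetical names and settings for illustration (the schedule shown anticipates the optimal-transport choice in step 9):

```python
import torch

def sample_conditional_path(x1, t, mu_t, sigma_t):
    # Draw x ~ p_t(x | x1) = N(mu_t(x1, t), sigma_t(x1, t)^2 I).
    eps = torch.randn_like(x1)
    return mu_t(x1, t) + sigma_t(x1, t) * eps

# One concrete schedule (the optimal-transport path from step 9 below).
sigma_min = 1e-4
mu_ot = lambda x1, t: t * x1
sigma_ot = lambda x1, t: 1.0 - (1.0 - sigma_min) * t

x1 = torch.randn(4, 2)    # stand-in for a batch of data points
t = torch.rand(4, 1)      # one time per sample, broadcastable against x1
x_t = sample_conditional_path(x1, t, mu_ot, sigma_ot)
```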
6. Constructing the Marginal Path
The marginal path is the mixture:
$$p_t(x) = \int p_t(x \mid x_1)\, q(x_1)\, dx_1$$
The marginal vector field is:
$$u_t(x) = \int u_t(x \mid x_1)\, \frac{p_t(x \mid x_1)\, q(x_1)}{p_t(x)}\, dx_1$$
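One step worth spelling out (a sketch of the standard continuity-equation argument, not stated above) is why this particular weighted average is the right marginal field: each conditional pair $(p_t(x \mid x_1), u_t(x \mid x_1))$ satisfies the continuity equation, and integrating over $x_1$ shows the marginal pair does too:
$$\partial_t p_t(x) = \int \partial_t p_t(x \mid x_1)\, q(x_1)\, dx_1 = -\nabla \cdot \int u_t(x \mid x_1)\, p_t(x \mid x_1)\, q(x_1)\, dx_1 = -\nabla \cdot \big(u_t(x)\, p_t(x)\big)$$
so $u_t(x)$ generates $p_t(x)$.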
7. The Conditional Vector Field
For the Gaussian conditional path, the conditional vector field is:
$$u_t(x \mid x_1) = \frac{\sigma_t'(x_1)}{\sigma_t(x_1)}\big(x - \mu_t(x_1)\big) + \mu_t'(x_1)$$
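Where this formula comes from (a short derivation added for completeness, under the same Gaussian setup): the conditional path is produced by the affine map $\psi_t(x_0) = \sigma_t(x_1)\, x_0 + \mu_t(x_1)$ applied to standard noise $x_0$; differentiating along that map and substituting $x_0 = (x - \mu_t(x_1)) / \sigma_t(x_1)$ gives
$$u_t(x \mid x_1) = \frac{d}{dt}\psi_t(x_0) = \sigma_t'(x_1)\, x_0 + \mu_t'(x_1) = \frac{\sigma_t'(x_1)}{\sigma_t(x_1)}\big(x - \mu_t(x_1)\big) + \mu_t'(x_1)$$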
8. The Tractable Training Objective
Conditional Flow Matching (CFM):
$$\mathcal{L}_{\text{CFM}}(\theta) = \mathbb{E}_{t \sim \mathcal{U}[0,1],\, x_1 \sim q(x_1),\, x \sim p_t(x \mid x_1)} \big\| v_t(x; \theta) - u_t(x \mid x_1) \big\|^2$$
Key theorem: $\nabla_\theta \mathcal{L}_{\text{CFM}}(\theta) = \nabla_\theta \mathcal{L}_{\text{FM}}(\theta)$ (the two losses differ only by a constant independent of $\theta$).
So we can optimize the tractable CFM objective instead!
9. Concrete Example: Optimal Transport Path
Choose simple linear interpolation:
- $\mu_t(x_1) = t\, x_1$
- $\sigma_t(x_1) = 1 - (1 - \sigma_{\min})\, t$
This gives:
$$u_t(x \mid x_1) = \frac{x_1 - (1 - \sigma_{\min})\, x}{1 - (1 - \sigma_{\min})\, t}$$
The training objective becomes:
$$\mathcal{L}_{\text{CFM}}(\theta) = \mathbb{E}_{t,\, x_1 \sim q(x_1),\, x_0 \sim \mathcal{N}(0, I)} \big\| v_t(\psi_t(x_0); \theta) - \big(x_1 - (1 - \sigma_{\min})\, x_0\big) \big\|^2$$
where $\psi_t(x_0) = \big(1 - (1 - \sigma_{\min})\, t\big)\, x_0 + t\, x_1$.
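A quick sanity check (an observation added here, not part of the derivation above): in the limit $\sigma_{\min} \to 0$ the conditional flow becomes exact straight-line, constant-velocity interpolation,
$$\psi_t(x_0) = (1 - t)\, x_0 + t\, x_1, \qquad \frac{d}{dt}\psi_t(x_0) = x_1 - x_0$$
so the regression target in the training loop below is, up to the small $\sigma_{\min}$ correction, just the displacement from noise to data.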
10. Training Algorithm
```python
import torch
import torch.nn.functional as F

# model, optimizer, dataloader, and sigma_min are assumed to be defined elsewhere.
for batch in dataloader:
    batch_size = batch.shape[0]

    # Sample time uniformly in [0, 1], shaped to broadcast against the data
    # (for image batches, add trailing singleton dimensions instead).
    t = torch.rand(batch_size, 1)

    # Sample noise
    x_0 = torch.randn_like(batch)

    # Create interpolated samples x_t = psi_t(x_0) along the OT path
    x_t = (1 - (1 - sigma_min) * t) * x_0 + t * batch

    # Target conditional vector field u_t(x_t | x_1)
    target = batch - (1 - sigma_min) * x_0

    # Predict the vector field and regress onto the target
    pred = model(x_t, t)
    loss = F.mse_loss(pred, target)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```
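The loop above assumes `model`, `optimizer`, `dataloader`, and `sigma_min` already exist. As one possible sketch (flow matching itself is architecture-agnostic), `VectorFieldMLP` below is a hypothetical minimal network that takes $(x_t, t)$ by concatenating the time onto the input:

```python
import torch
import torch.nn as nn

class VectorFieldMLP(nn.Module):
    """Hypothetical minimal v_t(x; theta): concatenate t onto x, map back to x's dimension."""

    def __init__(self, dim, hidden=256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim + 1, hidden), nn.SiLU(),
            nn.Linear(hidden, hidden), nn.SiLU(),
            nn.Linear(hidden, dim),
        )

    def forward(self, x, t):
        # x: (batch, dim), t: (batch, 1)
        return self.net(torch.cat([x, t], dim=-1))

model = VectorFieldMLP(dim=2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
```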
11. Sampling/Inference
To generate samples:
- Start with noise: $x_0 \sim \mathcal{N}(0, I)$
- Solve the ODE $\frac{dx}{dt} = v_t(x; \theta)$ from $t = 0$ to $t = 1$
- The result $x_1$ is your generated sample
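A minimal sketch of that ODE solve, using fixed-step Euler integration and the hypothetical `VectorFieldMLP` from the previous sketch (in practice an adaptive solver, e.g. from `torchdiffeq`, would also work):

```python
import torch

@torch.no_grad()
def generate(model, shape, n_steps=100):
    # Integrate dx/dt = v_t(x; theta) from t = 0 to t = 1 with Euler steps.
    x = torch.randn(shape)                    # x_0 ~ N(0, I)
    dt = 1.0 / n_steps
    for i in range(n_steps):
        t = torch.full((shape[0], 1), i * dt)
        x = x + dt * model(x, t)              # follow the learned vector field
    return x                                  # approximate sample from q_data

samples = generate(model, shape=(16, 2))
```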
12. Why This Works
The beauty is that:
- We never need to evaluate the complex marginal distribution $p_t(x)$
- We only work with simple Gaussian conditionals $p_t(x \mid x_1)$
- The conditional paths automatically "average out" to produce the correct marginal behavior
- With the path from step 9, each conditional trajectory is a straight-line (optimal-transport) path from noise to data, which tends to keep the learned ODE easy to integrate
This gives you a principled way to train generative models that produce high-quality samples with deterministic sampling: once the initial noise $x_0$ is drawn, generation is just solving an ODE, with no extra randomness injected along the way.