Flow Matching

Overview#

Flow Matching trains the model's predicted vector field to match the vector field that describes how data points actually move, which in turn ensures that the final distribution obtained by integrating a CNF (Continuous Normalizing Flow) agrees with the desired target distribution.

Notation#

  • $p_0$: simple prior distribution
  • $p_1$: complex target (data) distribution
  • $f(x,t)$: velocity field

Training Objective#

Goal: learn a time-dependent velocity field $f(x,t)$ (parameterized as $f_\theta(x,t)$) that describes the instantaneous velocity at each point along the trajectory.

The loss function is

$$\mathcal{L}=\mathbb{E}_{x_0,x_1,t}\left[\left\| f_\theta(x(t),t)-\frac{d}{dt}x(t)\right\|^2\right]$$

where

  • $x(t)=(1-t)x_0+tx_1$: the interpolation between a noise sample and a data sample
  • $\frac{d}{dt} x(t)=x_1-x_0$: the true velocity
  • $\mathbb{E}_{x_0,x_1,t}$: expectation over (noise, data) pairs and over the interpolation time $t$
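
A minimal PyTorch sketch of this objective (plain tensors only; `velocity_model` is a hypothetical network mapping $(x,t)$ to a velocity, not part of any library):

import torch

def flow_matching_loss(velocity_model, x_1: torch.Tensor) -> torch.Tensor:
    # Sample noise, a uniform time, and the interpolation point x(t).
    x_0 = torch.randn_like(x_1)              # x_0 ~ p_0
    t = torch.rand(x_1.shape[0], 1)          # t ~ U[0, 1], broadcast over features
    x_t = (1 - t) * x_0 + t * x_1            # x(t) = (1 - t) x_0 + t x_1
    target = x_1 - x_0                       # d/dt x(t) = x_1 - x_0
    # Regress the predicted velocity onto the true one.
    return ((velocity_model(x_t, t) - target) ** 2).mean()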

Theoretical Derivation#

Code Walkthrough#

Reference: facebookresearch/flow_matching
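
For orientation, the pieces read below fit together roughly as in the repository README (a sketch; the import paths and the `sample` signature follow the README and may differ across versions):

import torch
from flow_matching.path import AffineProbPath
from flow_matching.path.scheduler import CondOTScheduler

path = AffineProbPath(scheduler=CondOTScheduler())

x_0 = torch.randn(32, 2)   # noise batch
x_1 = torch.randn(32, 2)   # data batch (placeholder)
t = torch.rand(32)

# sample() returns, among other fields, the path point x_t and the
# target velocity dx_t used in the regression loss above
path_sample = path.sample(t=t, x_0=x_0, x_1=x_1)
# loss = ((model(path_sample.x_t, t) - path_sample.dx_t) ** 2).mean()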

Probability Paths#

Affine Probability Paths#

An Affine Probability Path defines a probability path $X_t$: $$X_t=\alpha_t X_1+\sigma_t X_0$$

where

  • $X_0\sim p_0$: noise
  • $X_1\sim p_1$: a real data sample
  • $\alpha_t,\sigma_t$: time-dependent functions controlled by a scheduler

Differentiating with respect to time: $$\dot X_t=\frac{dX_t}{dt}=\dot \alpha_t X_1+\dot \sigma_t X_0$$

Together, these two equations form a linear system.
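
In matrix form:

$$\begin{pmatrix} X_t \\ \dot X_t \end{pmatrix}=\begin{pmatrix}\alpha_t & \sigma_t \\ \dot\alpha_t & \dot\sigma_t\end{pmatrix}\begin{pmatrix}X_1 \\ X_0\end{pmatrix}$$

which is invertible whenever $\dot\alpha_t\sigma_t-\dot\sigma_t\alpha_t\neq 0$; this determinant reappears in the conversion formulas below.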

To connect with diffusion-style training, one usually defines $$\epsilon:=X_0$$ so that $$X_t=\alpha_t X_1+\sigma_t \epsilon,\qquad\dot X_t=\dot \alpha_t X_1+\dot \sigma_t \epsilon$$

Several helper methods convert between these representations, so that $(X_1,\epsilon,\dot X_t)$ can be expressed in terms of one another.

| Method | Inputs | Output | Meaning |
|---|---|---|---|
| target_to_velocity | $x_1,x_t,t$ | $\dot X_t$ | velocity from the target data point |
| epsilon_to_velocity | $\epsilon,x_t,t$ | $\dot X_t$ | velocity from the noise representation |
| velocity_to_target | $v_t,x_t,t$ | $x_1$ | target sample recovered from the velocity |
| target_to_epsilon | $x_1,x_t,t$ | $\epsilon$ | noise inferred from the data sample |
| velocity_to_epsilon | $v_t,x_t,t$ | $\epsilon$ | noise inferred from the velocity |

target_to_velocity: $(X_t,X_1)\rightarrow \dot X_t$

Following the goal above, we want to express $\dot X_t$ as a linear combination of $X_t$ and $X_1$: $$\dot X_t=a_t X_t+b_t X_1$$

From $$X_t=\alpha_t X_1+\sigma_t \epsilon,$$ we have $$\epsilon=\frac{X_t-\alpha_t X_1}{\sigma_t}$$

Substituting into $$\dot X_t=\dot \alpha_t X_1+\dot \sigma_t \epsilon,$$ we get $$\dot X_t=\dot \alpha_t X_1+\dot \sigma_t \frac{X_t-\alpha_t X_1}{\sigma_t}=\left(\dot \alpha_t -\frac{\dot\sigma_t\,\alpha_t}{\sigma_t}\right)X_1+\frac{\dot \sigma_t}{\sigma_t}X_t$$

Hence $$a_t=\frac{\dot \sigma_t}{\sigma_t},\qquad b_t=\dot \alpha_t -\frac{\dot\sigma_t\,\alpha_t}{\sigma_t}$$

def target_to_velocity(self,x_t:Tensor,x_1:Tensor,t:Tensor)->Tensor:
    scheduler_output=self.scheduler(t)

    alpha_t=scheduler_output.alpha_t
    sigma_t=scheduler_output.sigma_t
    d_alpha_t=scheduler_output.d_alpha_t
    d_sigma_t=scheduler_output.d_sigma_t

    a_t=d_sigma_t/sigma_t
    b_t=d_alpha_t-d_sigma_t*alpha_t/sigma_t

    return a_t*x_t+b_t*x_1 

epsilon_to_velocity: $(X_t,\epsilon)\rightarrow \dot X_t$

As before, set $$\dot X_t=a_t X_t+b_t \epsilon.$$

From $X_t=\alpha_t X_1+\sigma_t \epsilon$, we get $$X_1=\frac{X_t-\sigma_t \epsilon}{\alpha_t}$$

Substituting into $\dot X_t=\dot \alpha_t X_1+\dot \sigma_t \epsilon$, we get $$\dot X_t=\dot \alpha_t\frac{X_t-\sigma_t\epsilon}{\alpha_t}+\dot \sigma_t \epsilon=\frac{\dot \alpha_t}{\alpha_t}X_t+\left(\dot \sigma_t-\frac{\dot \alpha_t\sigma_t}{\alpha_t}\right)\epsilon$$

Therefore $$a_t=\frac{\dot \alpha_t}{\alpha_t},\qquad b_t=\dot \sigma_t-\frac{\dot \alpha_t\sigma_t}{\alpha_t}$$

def epsilon_to_velocity(self,epsilon:Tensor,x_t:Tensor,t:Tensor)->Tensor:
  scheduler_output=self.scheduler(t)

  alpha_t=scheduler_output.alpha_t
  sigma_t=scheduler_output.sigma_t
  d_alpha_t=scheduler_output.d_alpha_t
  d_sigma_t=scheduler_output.d_sigma_t

  a_t=d_alpha_t/alpha_t
  b_t=d_sigma_t-d_alpha_t*sigma_t/alpha_t

  return a_t*x_t+b_t*epsilon

velocity_to_target: $(X_t,\dot X_t)\rightarrow X_1$

Consider the linear system:

$$X_t=\alpha_t X_1+\sigma_t\epsilon,\dot X_t=\dot \alpha_t X_1+\dot \sigma_t\epsilon$$

Eliminate $\epsilon$:

From the first equation, $\epsilon=(X_t-\alpha_t X_1)/\sigma_t$; substituting into the second:

$$(\dot \alpha_t\sigma_t-\dot \sigma_t\alpha_t)X_1=\sigma_t \dot X_t-\dot \sigma_t X_t$$

Rearranging: $$X_1=\frac{\sigma_t \dot X_t-\dot \sigma_t X_t}{\dot \alpha_t \sigma_t-\dot \sigma_t\alpha_t}$$

Writing this in the form $a_t X_t+b_t\dot X_t$, we have $$a_t=-\frac{\dot\sigma_t}{\dot \alpha_t \sigma_t-\dot \sigma_t\alpha_t}$$

$$b_t=\frac{\sigma_t}{\dot \alpha_t \sigma_t-\dot \sigma_t\alpha_t}$$

def velocity_to_target(self,dx_t:Tensor,x_t:Tensor,t:Tensor)->Tensor:
  scheduler_output=self.scheduler(t)

  alpha_t=scheduler_output.alpha_t
  sigma_t=scheduler_output.sigma_t
  d_alpha_t=scheduler_output.d_alpha_t
  d_sigma_t=scheduler_output.d_sigma_t

  a_t=-d_sigma_t/(d_alpha_t*sigma_t-d_sigma_t*alpha_t)
  b_t=sigma_t/(d_alpha_t*sigma_t-d_sigma_t*alpha_t)
  return  a_t*x_t+b_t*dx_t

target_to_epsilon: $(X_t,X_1)\rightarrow \epsilon$

$$X_t=\alpha_t X_1+\sigma_t \epsilon\Rightarrow \epsilon=\frac{X_t-\alpha_t X_1}{\sigma_t}$$

That is, $$a_t=\frac{1}{\sigma_t},\qquad b_t=-\frac{\alpha_t}{\sigma_t}$$

def target_to_epsilon(self,x_t:Tensor,x_1:Tensor,t:Tensor)->Tensor:
  scheduler_output=self.scheduler(t)

  alpha_t=scheduler_output.alpha_t
  sigma_t=scheduler_output.sigma_t

  a_t=1/sigma_t
  b_t=-alpha_t/sigma_t
  return  a_t*x_t+b_t*x_1

velocity_to_epsilon: $(X_t,\dot X_t)\rightarrow \epsilon$

Similarly, solving the linear system, this time eliminating $X_1$: $$(\dot \sigma_t\alpha_t -\dot\alpha_t\sigma_t)\epsilon=\alpha_t \dot X_t-\dot\alpha_t X_t$$

which gives

$$\epsilon=\frac{\alpha_t \dot X_t-\dot \alpha_t X_t}{\dot\sigma_t\alpha_t-\dot \alpha_t\sigma_t}=a_tX_t+b_t\dot X_t$$

where $$a_t=-\frac{\dot \alpha_t}{\dot\sigma_t\alpha_t-\dot \alpha_t\sigma_t}$$

$$b_t=\frac{\alpha_t}{\dot\sigma_t\alpha_t-\dot \alpha_t\sigma_t}$$

def velocity_to_epsilon(self,x_t:Tensor,dx_t:Tensor,t:Tensor)->Tensor:
  scheduler_output=self.scheduler(t)

  alpha_t=scheduler_output.alpha_t
  sigma_t=scheduler_output.sigma_t
  d_alpha_t=scheduler_output.d_alpha_t
  d_sigma_t=scheduler_output.d_sigma_t

  a_t=-d_alpha_t/(d_sigma_t*alpha_t-d_alpha_t*sigma_t)
  b_t=alpha_t/(d_sigma_t*alpha_t-d_alpha_t*sigma_t)
  return  a_t*x_t+b_t*dx_t
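
A quick numerical check of the conversions above, using hand-written CondOT coefficients ($\alpha_t=t$, $\sigma_t=1-t$) instead of the library classes (a sketch for verification only):

import torch

t = torch.tensor(0.3)
alpha_t, sigma_t = t, 1 - t
d_alpha_t, d_sigma_t = torch.tensor(1.0), torch.tensor(-1.0)

x_0 = torch.randn(4)                       # epsilon ~ p_0
x_1 = torch.randn(4)                       # data sample
x_t = alpha_t * x_1 + sigma_t * x_0        # point on the path
dx_t = d_alpha_t * x_1 + d_sigma_t * x_0   # true velocity x_1 - x_0

# velocity_to_target should recover x_1 from (x_t, dx_t)
x_1_rec = (sigma_t * dx_t - d_sigma_t * x_t) / (d_alpha_t * sigma_t - d_sigma_t * alpha_t)
assert torch.allclose(x_1_rec, x_1, atol=1e-6)

# velocity_to_epsilon should recover x_0 from (x_t, dx_t)
eps_rec = (alpha_t * dx_t - d_alpha_t * x_t) / (d_sigma_t * alpha_t - d_alpha_t * sigma_t)
assert torch.allclose(eps_rec, x_0, atol=1e-6)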

Solvers#
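
At sampling time the learned velocity field is integrated as an ODE from $t=0$ to $t=1$. A minimal fixed-step Euler integrator sketches the idea (`velocity_model(x, t)` is a hypothetical trained network; the library itself provides dedicated solver classes for this):

import torch

def euler_sample(velocity_model, x_0: torch.Tensor, steps: int = 100) -> torch.Tensor:
    # Integrate dx/dt = f_theta(x, t) from noise (t=0) to data (t=1).
    x = x_0
    dt = 1.0 / steps
    for i in range(steps):
        t = torch.full((x.shape[0], 1), i * dt)
        x = x + velocity_model(x, t) * dt   # one explicit Euler step
    return x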


Schedulers#

| Scheduler | $\alpha_t$ | $\sigma_t$ | $\alpha_t^2+\sigma_t^2=1$? | Character | Path shape |
|---|---|---|---|---|---|
| CondOT | $t$ | $1-t$ | No | linear transition | linear |
| Polynomial | $t^n$ | $1-t^n$ | No | tunable curvature | polynomial |
| VPScheduler | $e^{-0.5T}$ | $\sqrt{1-e^{-T}}$ | Yes | continuous-time DDPM | exponential |
| LinearVPScheduler | $t$ | $\sqrt{1-t^2}$ | Yes | simplified VP | circular arc |
| CosineScheduler | $\sin(\pi t/2)$ | $\cos(\pi t/2)$ | Yes | smooth transition | cosine |

Basic Definitions#

Every scheduler returns $$(\alpha_t,\sigma_t,\frac{d}{dt}\alpha_t,\frac{d}{dt}\sigma_t)$$

where

  • $\alpha_t$: signal multiplier
  • $\sigma_t$: noise multiplier
  • $t=0$: pure-noise end; $t=1$: pure-signal end
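
The snippets below all return these four quantities through a small container; a plausible sketch of `SchedulerOutput` (the library's actual definition may differ):

from dataclasses import dataclass
from torch import Tensor

@dataclass
class SchedulerOutput:
    alpha_t: Tensor    # signal coefficient alpha_t
    sigma_t: Tensor    # noise coefficient sigma_t
    d_alpha_t: Tensor  # d(alpha_t)/dt
    d_sigma_t: Tensor  # d(sigma_t)/dt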

CondOT#

Definition: $$\alpha_t=t,\quad\sigma_t=1-t$$ $$\frac{d\alpha_t}{dt}=1,\quad\frac{d\sigma_t}{dt}=-1$$

Signal-to-noise ratio (SNR): $$\text{SNR}=\frac{\alpha_t}{\sigma_t}=\frac{t}{1-t}$$

Inverting the SNR for $t$: $$t=\frac{\text{SNR}}{1+\text{SNR}}$$

class CondOTScheduler(ConvexScheduler):
  def __call__(self, t:Tensor)->SchedulerOutput:
      return SchedulerOutput(
          alpha_t=t,
          sigma_t=1-t,
          d_alpha_t=torch.ones_like(t),
          d_sigma_t=-torch.ones_like(t)
      )

  def kappa_inverse(self, kappa:Tensor)->Tensor:
      return kappa

Polynomial#

Definition: $$\alpha_t=t^n,\quad\sigma_t=1-t^n$$ $$\frac{d\alpha_t}{dt}=nt^{n-1},\quad\frac{d\sigma_t}{dt}=-nt^{n-1}$$

Characteristics:

  • $n$ controls the curvature of the transition
  • $n>1$: slow change early, fast change late
  • $0<n<1$: fast change early, slow change late (e.g., at $t=0.5$, $n=2$ gives $\alpha_t=0.25$, while $n=1/2$ gives $\alpha_t\approx 0.71$)

SNR inversion: $$\text{SNR}=\frac{t^n}{1-t^n}\Rightarrow t=\left(\frac{\text{SNR}}{1+\text{SNR}}\right)^{1/n}$$

class PolynomialConvexScheduler(ConvexScheduler):
  def __init__(self,n:Union[float,int])->None:
      assert isinstance(n,(float,int))
      assert n>0
      self.n=n

  def __call__(self, t:Tensor)->SchedulerOutput:
      return SchedulerOutput(
          alpha_t=t**self.n,
          sigma_t=1-t**self.n,
          d_alpha_t=self.n*t**(self.n-1),
          d_sigma_t=-self.n*t**(self.n-1)
      )

  def kappa_inverse(self, kappa:Tensor)->Tensor:
      return torch.pow(kappa,1./self.n)

VPScheduler (Variance-Preserving Scheduler)#

Definition: with parameters

  • $\beta_{\min}, \beta_{\max}$: controlling how fast the noise grows,

define

$$ T(t) = \tfrac{1}{2}(1-t)^2(\beta_{\max} - \beta_{\min}) + (1-t)\beta_{\min} $$

$$ \alpha_t = e^{-0.5 T(t)},\quad \sigma_t = \sqrt{1 - e^{-T(t)}} $$

$$ \frac{dT}{dt} = - (1-t)(\beta_{\max}-\beta_{\min}) - \beta_{\min} $$

$$ \frac{d\alpha_t}{dt} = -0.5 \frac{dT}{dt} e^{-0.5T(t)},\quad \frac{d\sigma_t}{dt} = 0.5 \frac{dT}{dt} \frac{e^{-T(t)}}{\sqrt{1 - e^{-T(t)}}} $$

Interpretation:

  • Simulates a continuous-time variance-preserving diffusion process;
  • At $t=0$, $T$ is maximal → mostly noise; at $t=1$, $T=0$ → pure signal.

SNR inversion:

$$ \text{SNR} = \frac{\alpha_t}{\sigma_t} = \frac{e^{-0.5T}}{\sqrt{1 - e^{-T}}} \Rightarrow T = -\ln\frac{\text{SNR}^2}{1+\text{SNR}^2} $$ then solve the definition of $T(t)$ for $t$.
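
Concretely, with $u=1-t$, $b=\beta_{\min}$, $B=\beta_{\max}$, the definition $T=\tfrac{1}{2}(B-b)u^2+bu$ is a quadratic in $u$ whose positive root gives

$$u=\frac{-b+\sqrt{b^2+2(B-b)T}}{B-b},\qquad t=1-u$$

which is what `snr_inverse` below computes.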

class VPScheduler(Scheduler):
  def __init__(self,beta_min:float=0.1,beta_max:float=20.0)->None:
      self.beta_min=beta_min
      self.beta_max=beta_max
      super().__init__()
  
  def __call__(self, t:Tensor)->SchedulerOutput:
      b=self.beta_min
      B=self.beta_max
      T=(1-t)**2*(B-b)/2+(1-t)*b 
      alpha_t=torch.exp(-T/2)
      sigma_t=torch.sqrt(1-torch.exp(-T))

      d_T=-(1-t)*(B-b)-b 
      d_alpha_t=-0.5*d_T*torch.exp(-0.5*T)
      d_sigma_t=0.5*d_T*(torch.exp(-T)/torch.sqrt(1-torch.exp(-T)))
      return  SchedulerOutput(
          alpha_t=alpha_t,
          sigma_t=sigma_t,
          d_alpha_t=d_alpha_t,
          d_sigma_t=d_sigma_t
      )

  def snr_inverse(self, snr:Tensor)->Tensor:
      T=-torch.log(snr**2/(1+snr**2))
      b=self.beta_min;B=self.beta_max
      t=1 - ((-b + torch.sqrt(b**2 + 2 * (B - b) * T)) / (B - b))
      return t 

LinearVPScheduler#

Definition:

$$ \alpha_t = t,\quad \sigma_t = \sqrt{1 - t^2} $$ $$ \frac{d\alpha_t}{dt} = 1,\quad \frac{d\sigma_t}{dt} = -\frac{t}{\sqrt{1 - t^2}} $$

Properties:

  • Satisfies $\alpha_t^2 + \sigma_t^2 = 1$: genuinely variance preserving
  • The path traces a quarter of a circle

SNR inversion:

$$ \text{SNR} = \frac{t}{\sqrt{1 - t^2}} \Rightarrow t = \frac{\text{SNR}}{\sqrt{1 + \text{SNR}^2}}$$

class LinearVPScheduler(Scheduler):
  def __call__(self, t:Tensor)->SchedulerOutput:
      return SchedulerOutput(
          alpha_t=t,
          sigma_t=torch.sqrt(1-t**2),
          d_alpha_t=torch.ones_like(t),
          d_sigma_t=-t/torch.sqrt(1-t**2)
      )
  
  def snr_inverse(self, snr:Tensor)->Tensor:
      return torch.sqrt(snr**2/(1.+snr**2))

CosineScheduler#

Definition:

$$ \alpha_t = \sin\left(\frac{\pi}{2} t\right),\quad \sigma_t = \cos\left(\frac{\pi}{2} t\right) $$ $$ \frac{d\alpha_t}{dt} = \frac{\pi}{2}\cos\left(\frac{\pi}{2}t\right),\quad \frac{d\sigma_t}{dt} = -\frac{\pi}{2}\sin\left(\frac{\pi}{2}t\right) $$

Properties:

  • Satisfies $\alpha_t^2 + \sigma_t^2 = 1$: strictly variance preserving
  • Corresponds to the "cosine schedule": the noise decays slowly early on and quickly later
  • Commonly used in high-quality diffusion models (e.g., Imagen, EDM)

SNR inversion:

$$ \text{SNR} = \frac{\sin(\frac{\pi}{2}t)}{\cos(\frac{\pi}{2}t)} = \tan(\frac{\pi}{2}t) \Rightarrow t = \frac{2}{\pi}\arctan(\text{SNR}) $$

class CosineScheduler(Scheduler):
  def __call__(self, t:Tensor)->SchedulerOutput:
      return SchedulerOutput(
          alpha_t=torch.sin(torch.pi*t/2),
          sigma_t=torch.cos(torch.pi*t/2),
          d_alpha_t=torch.pi*torch.cos(torch.pi*t/2)/2,
          d_sigma_t=-torch.pi*torch.sin(torch.pi*t/2)/2
      )
  
  def snr_inverse(self, snr:Tensor)->Tensor:
      return 2.*torch.atan(snr)/torch.pi
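
A small usage check of the `CosineScheduler` above (a sketch; assumes the class definitions above are in scope):

import torch

scheduler = CosineScheduler()
t = torch.linspace(0.1, 0.9, 5)
out = scheduler(t)

# variance preserving: alpha_t^2 + sigma_t^2 == 1 for all t
assert torch.allclose(out.alpha_t**2 + out.sigma_t**2, torch.ones_like(t))

# snr_inverse recovers t from the signal-to-noise ratio alpha_t / sigma_t
snr = out.alpha_t / out.sigma_t
assert torch.allclose(scheduler.snr_inverse(snr), t, atol=1e-6)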

Experimental Results#

References#

[1] Flow Matching生成模型:从理论基础到Pytorch代码实现

[2] Lipman, Yaron, et al. "Flow matching for generative modeling." arXiv preprint arXiv:2210.02747 (2022).