AE-OT

github link

研究背景#

现有的主流生成模型虽能生成视觉真实的图像,但存在明显缺陷:

  • GAN:训练对于超参数敏感,容易出现模式崩溃,即生成器仅学习数据分布中少数模式,遗漏其他模式
  • VAE:虽能捕捉所有模式,但在多模态真是数据上常生成模糊、虚假的图像,出现模式混合问题

其根本原因在于:Encoder-Decoder需学习从白噪声分布到数据分布的传输映射,而数据分布多模态或者支持域非凸时,该传输映射不连续;但深度神经网络仅能表示连续映射,这种内在冲突导致了模式崩溃和混合。

AE-OT模型将流形嵌入与概率分布传输分离,分别通过自编码器(AE)和扩展半离散OT实现,避免DNN直接学习不连续映射:

[模块1]流形嵌入:自编码器#

  • 编码器$f_\theta$:将图像从图像空间映射到latent空间,将数据分布转化为latent码分布
  • 解码器$g_\xi$:将latent码映射回数据流形,实现图像重构

[模块2]分布传输:最优运输#

  • 计算Brenier势:给予凸优化,以训练样本的latent码为离散目标集合,通过最小化传输成本找到Brenier势,其梯度即为半离散OT映射
  • 分段线性扩展 PL:将半离散OT映射扩展为全局连续映射$\tilde T$,通过latent码的$\mu$-质量中心进行三角剖分,生成新的latent码,确保不丢失任何模式
  • 奇点集检测与规避:根据Figalli理论,计算支撑平面间的二面角,若角大于阈值则判定为奇点集,生成时丢入该区域的噪声样本,避免模式混合

代码阅读与重构#

Solver: SemiDiscreteSolver#

SemiDiscreteSolver 旨在估计半离散最优传输(semi-discrete OT)中每个目标点对应的体积或质量$\mu(C_i(h))$:用Sobol序在单位立方体均匀采样大量点$x$,对每个点计算$\arg\max_i\langle y_i,x\rangle+h_i$,统计每个$i$被选中的频率,得到经验估计的每个Laguerre Cell的质量。

init#

class SemiDiscreteOTSolver:
    def __init__(self, h_P, dim, bat_size_P, bat_size_n, device):
        self.h_P = h_P.to(device)
        self.dim = dim
        self.bat_size_P = bat_size_P
        self.bat_size_n = bat_size_n
        self.num_P = h_P.shape[0]
        self.device = device

        self.qrng = torch.quasirandom.SobolEngine(dimension=dim)
        self._allocate_buffers()
变量含义
h_P目标点的位置数组,即$y_i$
dim空间维度
batch_size_P对P做批处理的批大小
batch_size_mu对连续分布$\mu$采样的样本数量,每次用$n$个样本来估计$\mu(C_i)$的分配情况
qrngSobol低差异序列生成器(quasi-random),用于在单位立方体上生成均匀样本点,比纯随机MC在低维时方差更小,收敛更快

_allocate_buffers()#

用于预分配中间变量,节省重复分配来带的开销。

def _allocate_buffers(self):
    n,d=self.batch_size_n,self.dim
    p=self.num_P
    
    self.volP=torch.empty((n,d),device=self.device)
    self.g=torch.zeros(p,device=self.device)
    self.U=torch.empty((self.batch_size_P,n),device=self.device)
    self.temp_P=torch.empty((self.batch_size_P,d),device=self.device)
    self.temp_h=torch.empty(self.batch_size_P,device=self.device)
    self.ind=torch.empty(n,dtype=torch.long,device=self.device)
    self.ind_val=torch.empty(n,device=self.device)
    self.tot_ind=torch.empty(n,dtype=torch.long,device=self.device)
    self.tot_ind_val=torch.empty(n,device=self.device)
    self.tot_ind_val_argmax=torch.empty(n,dtype=torch.long,device=self.device)
    return
变量含义形状
volP存储采样得到的$n$个点,形状为(n,d),这些点代表从$\mu$采样的样本
g保存每个目标点对应的$\hat \mu(C_i)$,即经验估计的质量/体积分数
U保存分数$\langle y_i,x_j\rangle+h_i$的值(batch_size_P,n)
temp_P临时拷贝的一小批目标点(batch_size_P,d)
temp_h对应temp_P的高度参数h切片(batch_size_P,)
ind,ind_val在当前P批次上,对每个样本x_jind存放当前批中到达最大值的索引,ind_val存放对应的最大值
tot_ind全局/跨批次对每个样本最终选择的目标点索引
tot_ind_valtot_val对应的最大分数
total_ind_val_argmax用于记录在两两比较中选择哪一方mask,用于后续索引重建

sample_mu()#

def sample_mu(self):
  self.qrng.draw(self.batch_size_n,out=self.volP)
  self.volP.add_(-.5)
  return self.volP

生成采样点集合$[x_j]_{j=1}^{n}\sim [-0.5,0.5]^d$

compute_measure()#

给定当前高度参数向量$h$,估计每个Languerre cell的$\mu$质量。

  def compute_measure(self,h):
    r'''Estimate push-forward measure of mu under current h'''
    p,n=self.num_P,self.batch_size_n
    self.tot_ind_val.fill_(-1e30)
    self.tot_ind.fill_(-1)

    for i in range(p//self.batch_size_P):
        p_batch=self.h_P[i*self.batch_size_P:(i+1)*self.batch_size_P]
        self.temp_P=h[i*self.batch_size_P:(i+1)*self.batch_size_P]
        self.temp_P.copy_(p_batch)

        # compute scores: <P,x>+h
        torch.mm(self.temp_P,self.volP.t(),out=self.U)
        self.U.add_(self.temp_h[:,None])

        # find max indices
        torch.max(self.U,0,out=(self.ind_val,self.ind))
        self.ind.add_(i*self.batch_size_P)

        # aggregate maxima
        torch.max(
            torch.stack((self.tot_ind_val,self.ind_val)),
            0,
            out=(self.tot_ind_val,self.tot_ind_val_argmax)
        )
        self.tot_ind=torch.stack((self.tot_ind,self.ind))[self.tot_ind_val_argmax,torch.arange(n)]
    
    self.g.zero_()
    self.g.scatter_add_(0,self.tot_ind,
                        torch.ones_like(self.tot_ind,dtype=torch.float32))
    self.g.div_(n)
    return  self.g

关键行解读如下:

torch.mm(self.temp_P, self.volP.t(), out=self.U)
self.U.add_(self.temp_h[:, None])

对当前批次所有的$y_i$和全部样本点$x_j$,计算内积矩阵:$$U_{ij}=\langle y_i,x_j\rangle$$返回形状为(batch_size_P,n)。将对应的h_i加到每一行,变为:$$U_{ij}\leftarrow \langle y_i,x_j\rangle+h_i$$,即实现势函数$$\phi(x)=\max_i(\langle y_i,x_j\rangle+h_i)$$的内部项。

# perserve batch max
torch.max(self.U, 0, out=(self.ind_val, self.ind))
self.ind.add_(i * self.bat_size_P)

# perserve global max
torch.max(
    torch.stack((self.tot_ind_val, self.ind_val)), 
    0, out=(self.tot_ind_val, self.tot_ind_val_argmax)
)
self.tot_ind = torch.stack((self.tot_ind, self.ind))[
    self.tot_ind_val_argmax, torch.arange(n)
]

保存当前batch以及global下的最大值,及其索引。

self.g.zero_()
self.g.scatter_add_(0, self.tot_ind, torch.ones_like(self.tot_ind, dtype=torch.float))
self.g.div_(n)
return self.g

将计数除以样本总数,转成比例,获得$\hat\mu(C_i(h))$的比例。

Opt:Adam#

代码中实现了一个“不完全Adam”,省略了bias correction项。详细Adam见:Building-Blocks/Adam-Optimization

class AdamOptimizer:
    def __init__(self,size,lr,device):
        self.lr=lr
        self.m=torch.zeros(size,device=device)
        self.v=torch.zeros(size,device=device)

        self.beta1=.9
        self.beta2=.999
        self.eps=1e-8
        
    def step(self,grad,param):
        self.m=self.beta1*self.m+(1-self.beta1)*grad
        self.v=self.beta2*self.v+(1-self.beta2)*(grad*grad)
        step=-self.lr*self.m/(torch.sqrt(self.v)+self.eps)
        param.add_(step)
        return param

变量解读#

变量名类型 / 含义数学符号说明
size参数的形状(例如一个向量、矩阵等)Adam 需要和参数、梯度形状一致的动量变量
lr学习率(learning rate)$\alpha$控制每步更新幅度
m一阶动量向量$m_t$梯度的指数加权平均(类似“速度”)
v二阶动量向量$v_t$梯度平方的指数加权平均(类似“方差”)
beta1一阶动量衰减系数$\beta_1$通常取 0.9
beta2二阶动量衰减系数$\beta_2$通常取 0.999
eps小常数$\epsilon$防止除零(数值稳定性)

更新一阶动量m#

将当前的梯度grad与过去动量m做指数平均,使梯度更具惯性,平滑震荡。

$$m_t=\beta_1 m_{t-1}+(1-\beta_1) g_t$$

self.m = self.beta1 * self.m + (1 - self.beta1) * grad

更新二阶动量v#

记录梯度平方的指数平均,反映各参数维度上梯度方差,旨在调整不同参数维度的步长(adaptive learning rate)。

$$v_t=\beta_2 v_{t-1}+(1-\beta_2)g_t^2$$

self.v = self.beta2 * self.v + (1 - self.beta2) * (grad * grad)