AE-OT
研究背景#
现有的主流生成模型虽能生成视觉真实的图像,但存在明显缺陷:
- GAN:训练对于超参数敏感,容易出现模式崩溃,即生成器仅学习数据分布中少数模式,遗漏其他模式
- VAE:虽能捕捉所有模式,但在多模态真是数据上常生成模糊、虚假的图像,出现模式混合问题
其根本原因在于:Encoder-Decoder需学习从白噪声分布到数据分布的传输映射,而数据分布多模态或者支持域非凸时,该传输映射不连续;但深度神经网络仅能表示连续映射,这种内在冲突导致了模式崩溃和混合。
AE-OT模型将流形嵌入与概率分布传输分离,分别通过自编码器(AE)和扩展半离散OT实现,避免DNN直接学习不连续映射:
[模块1]流形嵌入:自编码器#
- 编码器$f_\theta$:将图像从图像空间映射到latent空间,将数据分布转化为latent码分布
- 解码器$g_\xi$:将latent码映射回数据流形,实现图像重构
[模块2]分布传输:最优运输#
- 计算Brenier势:给予凸优化,以训练样本的latent码为离散目标集合,通过最小化传输成本找到Brenier势,其梯度即为半离散OT映射
- 分段线性扩展 PL:将半离散OT映射扩展为全局连续映射$\tilde T$,通过latent码的$\mu$-质量中心进行三角剖分,生成新的latent码,确保不丢失任何模式
- 奇点集检测与规避:根据Figalli理论,计算支撑平面间的二面角,若角大于阈值则判定为奇点集,生成时丢入该区域的噪声样本,避免模式混合
代码阅读与重构#
Solver: SemiDiscreteSolver#
SemiDiscreteSolver 旨在估计半离散最优传输(semi-discrete OT)中每个目标点对应的体积或质量$\mu(C_i(h))$:用Sobol序在单位立方体均匀采样大量点$x$,对每个点计算$\arg\max_i\langle y_i,x\rangle+h_i$,统计每个$i$被选中的频率,得到经验估计的每个Laguerre Cell的质量。
init#
class SemiDiscreteOTSolver:
def __init__(self, h_P, dim, bat_size_P, bat_size_n, device):
self.h_P = h_P.to(device)
self.dim = dim
self.bat_size_P = bat_size_P
self.bat_size_n = bat_size_n
self.num_P = h_P.shape[0]
self.device = device
self.qrng = torch.quasirandom.SobolEngine(dimension=dim)
self._allocate_buffers()| 变量 | 含义 |
|---|---|
h_P | 目标点的位置数组,即$y_i$ |
dim | 空间维度 |
batch_size_P | 对P做批处理的批大小 |
batch_size_mu | 对连续分布$\mu$采样的样本数量,每次用$n$个样本来估计$\mu(C_i)$的分配情况 |
qrng | Sobol低差异序列生成器(quasi-random),用于在单位立方体上生成均匀样本点,比纯随机MC在低维时方差更小,收敛更快 |
_allocate_buffers()#
用于预分配中间变量,节省重复分配来带的开销。
def _allocate_buffers(self):
n,d=self.batch_size_n,self.dim
p=self.num_P
self.volP=torch.empty((n,d),device=self.device)
self.g=torch.zeros(p,device=self.device)
self.U=torch.empty((self.batch_size_P,n),device=self.device)
self.temp_P=torch.empty((self.batch_size_P,d),device=self.device)
self.temp_h=torch.empty(self.batch_size_P,device=self.device)
self.ind=torch.empty(n,dtype=torch.long,device=self.device)
self.ind_val=torch.empty(n,device=self.device)
self.tot_ind=torch.empty(n,dtype=torch.long,device=self.device)
self.tot_ind_val=torch.empty(n,device=self.device)
self.tot_ind_val_argmax=torch.empty(n,dtype=torch.long,device=self.device)
return| 变量 | 含义 | 形状 |
|---|---|---|
volP | 存储采样得到的$n$个点,形状为(n,d),这些点代表从$\mu$采样的样本 | |
g | 保存每个目标点对应的$\hat \mu(C_i)$,即经验估计的质量/体积分数 | |
U | 保存分数$\langle y_i,x_j\rangle+h_i$的值 | (batch_size_P,n) |
temp_P | 临时拷贝的一小批目标点 | (batch_size_P,d) |
temp_h | 对应temp_P的高度参数h切片 | (batch_size_P,) |
ind,ind_val | 在当前P批次上,对每个样本x_j,ind存放当前批中到达最大值的索引,ind_val存放对应的最大值 | |
tot_ind | 全局/跨批次对每个样本最终选择的目标点索引 | |
tot_ind_val | 与tot_val对应的最大分数 | |
total_ind_val_argmax | 用于记录在两两比较中选择哪一方mask,用于后续索引重建 |
sample_mu()#
def sample_mu(self):
self.qrng.draw(self.batch_size_n,out=self.volP)
self.volP.add_(-.5)
return self.volP生成采样点集合$[x_j]_{j=1}^{n}\sim [-0.5,0.5]^d$
compute_measure()#
给定当前高度参数向量$h$,估计每个Languerre cell的$\mu$质量。
def compute_measure(self,h):
r'''Estimate push-forward measure of mu under current h'''
p,n=self.num_P,self.batch_size_n
self.tot_ind_val.fill_(-1e30)
self.tot_ind.fill_(-1)
for i in range(p//self.batch_size_P):
p_batch=self.h_P[i*self.batch_size_P:(i+1)*self.batch_size_P]
self.temp_P=h[i*self.batch_size_P:(i+1)*self.batch_size_P]
self.temp_P.copy_(p_batch)
# compute scores: <P,x>+h
torch.mm(self.temp_P,self.volP.t(),out=self.U)
self.U.add_(self.temp_h[:,None])
# find max indices
torch.max(self.U,0,out=(self.ind_val,self.ind))
self.ind.add_(i*self.batch_size_P)
# aggregate maxima
torch.max(
torch.stack((self.tot_ind_val,self.ind_val)),
0,
out=(self.tot_ind_val,self.tot_ind_val_argmax)
)
self.tot_ind=torch.stack((self.tot_ind,self.ind))[self.tot_ind_val_argmax,torch.arange(n)]
self.g.zero_()
self.g.scatter_add_(0,self.tot_ind,
torch.ones_like(self.tot_ind,dtype=torch.float32))
self.g.div_(n)
return self.g关键行解读如下:
torch.mm(self.temp_P, self.volP.t(), out=self.U)
self.U.add_(self.temp_h[:, None])对当前批次所有的$y_i$和全部样本点$x_j$,计算内积矩阵:$$U_{ij}=\langle y_i,x_j\rangle$$返回形状为(batch_size_P,n)。将对应的h_i加到每一行,变为:$$U_{ij}\leftarrow \langle y_i,x_j\rangle+h_i$$,即实现势函数$$\phi(x)=\max_i(\langle y_i,x_j\rangle+h_i)$$的内部项。
# perserve batch max
torch.max(self.U, 0, out=(self.ind_val, self.ind))
self.ind.add_(i * self.bat_size_P)
# perserve global max
torch.max(
torch.stack((self.tot_ind_val, self.ind_val)),
0, out=(self.tot_ind_val, self.tot_ind_val_argmax)
)
self.tot_ind = torch.stack((self.tot_ind, self.ind))[
self.tot_ind_val_argmax, torch.arange(n)
]保存当前batch以及global下的最大值,及其索引。
self.g.zero_()
self.g.scatter_add_(0, self.tot_ind, torch.ones_like(self.tot_ind, dtype=torch.float))
self.g.div_(n)
return self.g将计数除以样本总数,转成比例,获得$\hat\mu(C_i(h))$的比例。
Opt:Adam#
代码中实现了一个“不完全Adam”,省略了bias correction项。详细Adam见:Building-Blocks/Adam-Optimization
class AdamOptimizer:
def __init__(self,size,lr,device):
self.lr=lr
self.m=torch.zeros(size,device=device)
self.v=torch.zeros(size,device=device)
self.beta1=.9
self.beta2=.999
self.eps=1e-8
def step(self,grad,param):
self.m=self.beta1*self.m+(1-self.beta1)*grad
self.v=self.beta2*self.v+(1-self.beta2)*(grad*grad)
step=-self.lr*self.m/(torch.sqrt(self.v)+self.eps)
param.add_(step)
return param变量解读#
| 变量名 | 类型 / 含义 | 数学符号 | 说明 |
|---|---|---|---|
size | 参数的形状(例如一个向量、矩阵等) | — | Adam 需要和参数、梯度形状一致的动量变量 |
lr | 学习率(learning rate) | $\alpha$ | 控制每步更新幅度 |
m | 一阶动量向量 | $m_t$ | 梯度的指数加权平均(类似“速度”) |
v | 二阶动量向量 | $v_t$ | 梯度平方的指数加权平均(类似“方差”) |
beta1 | 一阶动量衰减系数 | $\beta_1$ | 通常取 0.9 |
beta2 | 二阶动量衰减系数 | $\beta_2$ | 通常取 0.999 |
eps | 小常数 | $\epsilon$ | 防止除零(数值稳定性) |
更新一阶动量m#
将当前的梯度grad与过去动量m做指数平均,使梯度更具惯性,平滑震荡。
$$m_t=\beta_1 m_{t-1}+(1-\beta_1) g_t$$
self.m = self.beta1 * self.m + (1 - self.beta1) * grad更新二阶动量v#
记录梯度平方的指数平均,反映各参数维度上梯度方差,旨在调整不同参数维度的步长(adaptive learning rate)。
$$v_t=\beta_2 v_{t-1}+(1-\beta_2)g_t^2$$
self.v = self.beta2 * self.v + (1 - self.beta2) * (grad * grad)