保姆级教程:用PyTorch从零实现SDE扩散模型(附完整代码与MNIST实战)

张开发
2026/4/21 3:14:58 15 分钟阅读

分享文章

保姆级教程:用PyTorch从零实现SDE扩散模型(附完整代码与MNIST实战)
从零构建SDE扩散模型PyTorch实战指南与MNIST生成艺术在生成式人工智能的浪潮中扩散模型以其出色的图像生成质量脱颖而出。不同于传统的GAN或VAE扩散模型通过模拟物理系统中的扩散过程来学习数据分布而基于随机微分方程SDE的扩散模型更是将这一过程推向连续化的新高度。本文将带您从PyTorch基础开始完整实现一个SDE扩散模型并在MNIST数据集上进行实战演练。1. 环境准备与理论基础1.1 核心数学概念SDE扩散模型建立在几个关键数学概念之上随机微分方程SDE描述系统在确定性漂移和随机扩散共同作用下的演化分数函数Score Function数据分布对数密度的梯度∇ₓlogp(x)福克-普朗克方程描述概率密度随时间的演化VP-SDEVariance Preserving SDE的数学表达为dx -\frac{1}{2}β(t)xdt \sqrt{β(t)}dW其中β(t)控制噪声调度W是标准布朗运动。1.2 开发环境配置推荐使用以下环境配置conda create -n sde python3.9 conda activate sde pip install torch1.13.1 torchvision0.14.1 matplotlib验证PyTorch安装import torch print(torch.__version__) # 应输出1.13.1 print(torch.cuda.is_available()) # 检查GPU可用性2. 模型架构设计2.1 VP-SDE类实现我们首先实现VP-SDE的核心计算逻辑class VPSDE: def __init__(self, beta_min0.1, beta_max20.0, T1.0): self.beta_min beta_min self.beta_max beta_max self.T T def beta(self, t): 线性噪声调度函数 return self.beta_min t * (self.beta_max - self.beta_min) def marginal_prob(self, x0, t): 计算前向过程的均值和标准差 integral_beta self.beta_min * t 0.5 * (self.beta_max - self.beta_min) * t**2 mean_coef torch.exp(-0.5 * integral_beta) std torch.sqrt(1 - torch.exp(-integral_beta)) return mean_coef * x0, std2.2 分数网络架构分数网络需要同时处理图像数据和时间信息class ScoreNet(nn.Module): def __init__(self): super().__init__() # 时间编码网络 self.time_embed nn.Sequential( nn.Linear(1, 128), nn.SiLU(), nn.Linear(128, 256) ) # 主干网络 self.conv1 nn.Conv2d(1, 64, 3, padding1) self.down1 nn.Conv2d(64, 128, 3, stride2, padding1) self.down2 nn.Conv2d(128, 256, 3, stride2, padding1) self.up1 nn.ConvTranspose2d(256, 128, 3, stride2, padding1, output_padding1) self.up2 nn.ConvTranspose2d(128, 64, 3, stride2, padding1, output_padding1) self.conv_out nn.Conv2d(64, 1, 3, padding1) self.act nn.SiLU() def forward(self, x, t): # 时间编码 t t.view(-1, 1) t_emb self.time_embed(t).view(-1, 256, 1, 1) # 下采样路径 h1 self.act(self.conv1(x)) h2 self.act(self.down1(h1)) h3 self.act(self.down2(h2)) # 加入时间信息 h3 h3 t_emb # 上采样路径 h self.act(self.up1(h3)) h self.act(self.up2(h h2)) return self.conv_out(h h1)3. 训练流程实现3.1 损失函数设计分数匹配损失的核心是预测噪声def loss_fn(model, x0, t, sde): 计算分数匹配损失 x_t_mean, std sde.marginal_prob(x0, t) noise torch.randn_like(x0) x_t x_t_mean std * noise # 网络预测的分数应与 -noise/std 接近 score model(x_t, t.view(-1, 1, 1, 1)) loss torch.mean((score * std noise)**2) return loss3.2 训练循环完整的训练过程实现def train(model, sde, train_loader, optimizer, device, epochs10): model.train() for epoch in range(epochs): total_loss 0 for x0, _ in train_loader: x0 x0.to(device) # 均匀采样时间点 t torch.rand(x0.shape[0], devicedevice) * (sde.T - 1e-5) 1e-5 # 计算损失并更新 optimizer.zero_grad() loss loss_fn(model, x0, t, sde) loss.backward() optimizer.step() total_loss loss.item() print(fEpoch {epoch1}, Loss: {total_loss/len(train_loader):.4f})4. 采样与生成4.1 反向SDE求解使用欧拉-丸山方法进行采样def generate_samples(model, sde, device, shape(16,1,28,28), steps1000): model.eval() with torch.no_grad(): # 初始化噪声 x torch.randn(shape, devicedevice) # 时间离散化 time_steps torch.linspace(sde.T, 1e-3, steps, devicedevice) dt time_steps[0] - time_steps[1] for t in time_steps: # 计算漂移项和扩散项 beta_t sde.beta(t) score model(x, t*torch.ones(shape[0],1,1,1,devicedevice)) drift -0.5 * beta_t * x - beta_t * score diffusion torch.sqrt(beta_t) # 欧拉-丸山更新 noise torch.randn_like(x) x x drift * dt diffusion * torch.sqrt(dt) * noise return x4.2 结果可视化生成样本并显示def plot_samples(samples): grid torchvision.utils.make_grid(samples, nrow4, normalizeTrue) plt.figure(figsize(8,8)) plt.imshow(grid.permute(1,2,0).cpu()) plt.axis(off) plt.show() # 生成并显示16个样本 samples generate_samples(model, sde, device) plot_samples(samples)5. 高级技巧与优化5.1 学习率调度添加学习率预热可以提高训练稳定性def get_lr_scheduler(optimizer, warmup5000): def lr_lambda(step): if step warmup: return float(step) / warmup return 1.0 return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)5.2 噪声调度优化尝试不同的β(t)调度策略调度类型公式特点线性β(t) βₘᵢₙ t(βₘₐₓ-βₘᵢₙ)简单直接余弦β(t) βₘᵢₙ 0.5(βₘₐₓ-βₘᵢₙ)(1-cos(πt))平滑过渡平方β(t) βₘᵢₙ t²(βₘₐₓ-βₘᵢₙ)后期变化快5.3 模型架构改进可以考虑以下改进方向添加注意力机制使用U-Net作为主干引入条件批归一化尝试残差连接class AttentionBlock(nn.Module): def __init__(self, channels): super().__init__() self.q nn.Conv2d(channels, channels, 1) self.k nn.Conv2d(channels, channels, 1) self.v nn.Conv2d(channels, channels, 1) self.proj nn.Conv2d(channels, channels, 1) def forward(self, x): B, C, H, W x.shape q self.q(x).view(B, C, -1) k self.k(x).view(B, C, -1) v self.v(x).view(B, C, -1) attn torch.softmax(q k.transpose(1,2) / (C**0.5), dim-1) out (attn v).view(B, C, H, W) return self.proj(out) x6. 实战MNIST生成完整的训练到生成流程def main(): device torch.device(cuda if torch.cuda.is_available() else cpu) # 数据加载 transform transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,)) ]) train_set datasets.MNIST(./data, trainTrue, downloadTrue, transformtransform) train_loader DataLoader(train_set, batch_size128, shuffleTrue) # 初始化模型和SDE model ScoreNet().to(device) sde VPSDE(beta_min0.1, beta_max20.0, T1.0) # 优化器 optimizer torch.optim.Adam(model.parameters(), lr1e-3) scheduler get_lr_scheduler(optimizer) # 训练 for epoch in range(10): train(model, sde, train_loader, optimizer, device) scheduler.step() # 每2个epoch生成示例 if epoch % 2 0: samples generate_samples(model, sde, device) plot_samples(samples) # 最终生成 final_samples generate_samples(model, sde, device, shape(64,1,28,28)) plot_samples(final_samples) if __name__ __main__: main()7. 常见问题排查在实际实现过程中可能会遇到以下问题生成质量差检查噪声调度是否合理增加采样步数1000步以上尝试更大的网络容量训练不稳定添加梯度裁剪使用学习率预热检查损失值是否正常下降显存不足减小批大小使用混合精度训练简化网络结构提示在MNIST上良好的训练损失通常在0.01-0.05之间如果损失不下降可能需要检查模型实现是否正确。通过本教程的完整实现您应该能够生成清晰的MNIST数字。在实际项目中可以尝试将此框架扩展到更高分辨率的图像生成任务中只需相应调整网络架构和训练参数即可。

更多文章