Beyond Voxels and Convolutions: Implementing Point Transformer for 3D Point Clouds from Scratch in PyTorch (with Code)

张开发
2026/4/21 6:42:25 · 15 min read


Implementing Point Transformer from Scratch · PyTorch in Practice · 3D Point Cloud Processing

## 1. A Paradigm Shift in 3D Point Cloud Processing

In computer vision, 3D point cloud processing is undergoing a paradigm shift from traditional convolutions to self-attention. A point cloud is fundamentally an unordered set of points, each carrying 3D coordinates and possibly extra features such as color or surface normals. This structure is nothing like a regular 2D pixel grid, which makes conventional CNNs hard to apply directly.

Traditional point cloud methods fall into three categories:

- **Voxelization**: convert the point cloud into a regular 3D grid. Strength: 3D convolutions apply directly. Weakness: heavy compute and memory cost, plus noticeable quantization error.
- **Graph-based methods**: model the point cloud as a graph. Strength: captures spatial relations between points. Weakness: graph construction is expensive, and large scenes are hard to handle.
- **Point-based methods**: operate directly on raw points (the PointNet family). Strength: preserves the original geometry. Weakness: limited ability to extract local features.

```python
# Basic point cloud representation
import torch

# A typical point cloud tensor has shape [B, N, 3 + C]
# B: batch size, N: number of points, 3: xyz coordinates, C: extra features
points = torch.randn(16, 1024, 6)  # xyz coordinates plus rgb color per point
```

Point Transformer's innovation is to marry self-attention with the properties of point clouds:

- **Permutation invariance**: a natural fit for unordered point sets
- **Cardinality invariance**: handles a varying number of input points
- **Local attention**: computed within a k-nearest-neighbor region, keeping cost low

## 2. Point Transformer Core Architecture

### 2.1 Vector Attention

Point Transformer uses *vector* attention rather than conventional *scalar* attention, and this is key to its strong performance: vector attention modulates each feature channel independently, giving it greater expressive power. Mathematically:

y_i = Σ_{x_j ∈ N(i)} γ(φ(x_i) − ψ(x_j) + δ) ⊙ (α(x_j) + δ)

where:

- φ, ψ, α: feature-transform MLPs
- γ: attention-generation MLP
- δ: position encoding
- ⊙: element-wise multiplication

```python
import torch
import torch.nn as nn

class VectorAttention(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.to_qkv = nn.Linear(dim, dim * 3)  # produce Q, K, V
        self.gamma = nn.Sequential(
            nn.Linear(dim, dim),
            nn.ReLU(),
            nn.Linear(dim, dim)
        )
        self.pos_enc = nn.Sequential(
            nn.Linear(3, dim),
            nn.ReLU(),
            nn.Linear(dim, dim)
        )

    def forward(self, x, pos):
        q, k, v = self.to_qkv(x).chunk(3, dim=-1)
        pos_enc = self.pos_enc(pos)  # relative position encoding

        # attention weights
        attn = self.gamma(q[:, None] - k + pos_enc)
        attn = torch.sigmoid(attn)  # per-channel normalization

        # feature aggregation
        out = attn * (v + pos_enc)
        return out.sum(dim=1)
```

### 2.2 Position Encoding Design

Positional information is critical for point clouds. Point Transformer uses a learnable relative position encoding:

δ = θ(p_i − p_j)

where θ is an MLP with two linear layers and one ReLU. This design:

- preserves translation invariance
- captures fine-grained geometric relations
- adapts to the task through end-to-end learning

### 2.3 Local Neighborhood Construction

Global attention over a point cloud costs O(N²) and does not scale. Point Transformer instead builds a local neighborhood from the k nearest neighbors (k = 16), reducing the cost to O(N × k).

```python
def knn_query(points, k=16):
    """Fast k-nearest-neighbor query."""
    dist = torch.cdist(points, points)
    _, indices = torch.topk(dist, k, largest=False)
    return indices
```
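To make the semantics of the k-NN query explicit, the same brute-force logic can be sketched in plain NumPy (this standalone helper and its name `knn_query_np` are illustrative, not part of the original pipeline):

```python
import numpy as np

def knn_query_np(points, k=3):
    """Brute-force k-nearest-neighbor indices for an [N, 3] point array."""
    # pairwise squared Euclidean distances, shape [N, N]
    diff = points[:, None, :] - points[None, :, :]
    dist = (diff ** 2).sum(-1)
    # indices of the k smallest distances per row (each point's own index comes first)
    return np.argsort(dist, axis=1)[:, :k]

pts = np.array([[0.0, 0.0, 0.0],
                [0.1, 0.0, 0.0],
                [5.0, 5.0, 5.0],
                [5.1, 5.0, 5.0]])
idx = knn_query_np(pts, k=2)  # two well-separated pairs of points
```

Note that, as with `torch.cdist` + `torch.topk`, each point's nearest neighbor is itself, so downstream grouping code must decide whether to keep or skip the self-index.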
## 3. Full Network Implementation

### 3.1 Basic Building Block

```python
class PointTransformerBlock(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.attn = VectorAttention(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim * 2),
            nn.ReLU(),
            nn.Linear(dim * 2, dim)
        )
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

    def forward(self, x, pos):
        # residual connections + layer normalization
        x = self.norm1(x + self.attn(x, pos))
        x = self.norm2(x + self.mlp(x))
        return x
```

### 3.2 Downsampling Module

```python
class TransitionDown(nn.Module):
    def __init__(self, in_dim, out_dim, k=16):
        super().__init__()
        self.k = k
        self.mlp = nn.Sequential(
            nn.Linear(in_dim, out_dim),
            nn.BatchNorm1d(out_dim),
            nn.ReLU()
        )

    def forward(self, x, pos):
        # farthest point sampling (FPS)
        idx = farthest_point_sample(pos, x.shape[1] // 4)
        new_pos = torch.gather(pos, 1, idx.unsqueeze(-1).expand(-1, -1, 3))

        # k-NN grouping
        knn_idx = knn_query(pos, self.k)
        group_idx = knn_idx.gather(1, idx.unsqueeze(-1).expand(-1, -1, self.k))

        # feature aggregation
        grouped_features = gather_neighbors(x, group_idx)
        grouped_features = self.mlp(grouped_features.view(-1, x.size(-1)))
        new_features = grouped_features.view(
            x.size(0), idx.size(1), self.k, -1
        ).max(dim=2)[0]
        return new_features, new_pos
```

### 3.3 Upsampling Module

```python
class TransitionUp(nn.Module):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(in_dim, out_dim),
            nn.BatchNorm1d(out_dim),
            nn.ReLU()
        )

    def forward(self, x, skip_x, pos, skip_pos):
        # interpolate coarse features back onto the denser point set
        interp_feats = trilinear_interpolation(x, pos, skip_pos)
        out = self.mlp(torch.cat([interp_feats, skip_x], dim=-1))
        return out
```

(`farthest_point_sample`, `gather_neighbors`, and `trilinear_interpolation` are helper functions assumed to be defined elsewhere.)
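`farthest_point_sample`, used by `TransitionDown` above, is not defined in the post. A minimal single-batch NumPy sketch of the standard greedy FPS algorithm (the function name, signature, and fixed starting index are my assumptions; production code typically randomizes the start and runs on the GPU):

```python
import numpy as np

def farthest_point_sample_np(points, n_samples):
    """Greedy farthest point sampling over an [N, 3] array; returns selected indices."""
    n = points.shape[0]
    selected = np.zeros(n_samples, dtype=np.int64)
    # distance from each point to its closest already-selected point
    min_dist = np.full(n, np.inf)
    selected[0] = 0  # start from the first point (often randomized in practice)
    for i in range(1, n_samples):
        d = ((points - points[selected[i - 1]]) ** 2).sum(-1)
        min_dist = np.minimum(min_dist, d)
        selected[i] = np.argmax(min_dist)  # pick the point farthest from the selection
    return selected

# two tight clusters: FPS should pick one point from each
pts = np.array([[0.0, 0, 0], [0.1, 0, 0], [10.0, 0, 0], [10.1, 0, 0]])
idx = farthest_point_sample_np(pts, 2)
```

Unlike uniform random sampling, FPS spreads the selected points evenly over the geometry, which is why it is the standard choice for building coarser levels of the hierarchy.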
## 4. Practical Training Tricks

### 4.1 Data Augmentation

```python
import numpy as np

class PointCloudAugment:
    def __init__(self):
        self.rot_range = (-45, 45)
        self.scale_range = (0.8, 1.2)

    def __call__(self, points):
        # random rotation
        angles = np.random.uniform(*self.rot_range, size=3)
        points = rotate_point_cloud(points, angles)

        # random scaling
        scale = np.random.uniform(*self.scale_range)
        points[:, :3] *= scale

        # random translation
        translation = np.random.normal(0, 0.02, size=3)
        points[:, :3] += translation

        # random point dropout
        if np.random.rand() > 0.7:
            mask = np.random.choice([True, False], size=len(points))
            points = points[mask]
        return points
```

### 4.2 Learning Rate Scheduling

```python
def get_lr_scheduler(optimizer, total_epochs):
    def lr_lambda(epoch):
        if epoch < total_epochs * 0.6:
            return 1.0
        elif epoch < total_epochs * 0.8:
            return 0.1
        else:
            return 0.01
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
```

### 4.3 Mixed-Precision Training

```python
scaler = torch.cuda.amp.GradScaler()

for epoch in range(epochs):
    for points, labels in train_loader:
        optimizer.zero_grad()
        with torch.cuda.amp.autocast():
            preds = model(points)
            loss = criterion(preds, labels)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
```
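The loss scaling performed by `GradScaler` exists because small gradient values underflow to zero in float16. A tiny NumPy demonstration of the effect (65536 matches `GradScaler`'s default initial scale; the specific gradient value here is just for illustration):

```python
import numpy as np

grad = 1e-8  # a gradient smaller than float16's smallest subnormal (~6e-8)

underflowed = np.float16(grad)        # flushes to exactly zero in half precision
scaled = np.float16(grad * 65536.0)   # representable after loss scaling
recovered = float(scaled) / 65536.0   # unscale back in float32 before the optimizer step
```

This is exactly the `scale → backward → unscale → step` dance the training loop above performs: gradients flow through the backward pass scaled up so they survive float16, then are divided back down before updating the weights.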
## 5. Performance Optimization and Deployment

### 5.1 Memory Optimization

```python
# gradient checkpointing: recompute activations in the backward pass to save memory
from torch.utils.checkpoint import checkpoint

class MemoryEfficientBlock(nn.Module):
    def forward(self, x, pos):
        def create_custom_forward(module):
            def custom_forward(*inputs):
                return module(inputs[0], inputs[1])
            return custom_forward
        return checkpoint(create_custom_forward(self.attn), x, pos)
```

### 5.2 ONNX Export

```python
torch.onnx.export(
    model,
    (dummy_input, dummy_pos),
    "point_transformer.onnx",
    opset_version=12,
    input_names=["points", "positions"],
    output_names=["output"],
    dynamic_axes={
        "points": {0: "batch", 1: "num_points"},
        "positions": {0: "batch", 1: "num_points"},
        "output": {0: "batch", 1: "num_points"},
    }
)
```

### 5.3 TensorRT Acceleration

```python
import tensorrt as trt

# build a TensorRT engine from the exported ONNX model
builder = trt.Builder(logger)
network = builder.create_network(
    1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
)
parser = trt.OnnxParser(network, logger)
with open("point_transformer.onnx", "rb") as f:
    parser.parse(f.read())

config = builder.create_builder_config()
config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 30)  # 1 GiB workspace
engine = builder.build_engine(network, config)
```
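The memory saving from checkpointing can be estimated with simple arithmetic: instead of storing every layer's activations, you store one activation per segment boundary and recompute at most one segment during the backward pass. This back-of-the-envelope model (not a measurement of `torch.utils.checkpoint` itself; the uniform-segment assumption is mine) shows why checkpointing every √N layers gives O(√N) peak activation memory:

```python
import math

def peak_activations(n_layers, n_segments):
    """Stored activation count under uniform checkpoint segments:
    one saved activation per segment boundary, plus one segment's worth
    of recomputed activations live at a time during backward."""
    seg_len = math.ceil(n_layers / n_segments)
    return n_segments + seg_len

layers = 100
no_ckpt = layers                           # baseline: store every activation
with_ckpt = peak_activations(layers, 10)   # 10 boundaries + a 10-layer segment
```

Here a 100-layer stack drops from 100 stored activations to 20, at the price of one extra forward pass through each segment.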
## 6. Applications and Extensions

### 6.1 Semantic Segmentation

```python
class PointTransformerSeg(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        # encoder
        self.enc1 = PointTransformerBlock(64)
        self.down1 = TransitionDown(3, 64)
        # decoder
        self.up1 = TransitionUp(256, 128)
        self.dec1 = PointTransformerBlock(128)
        # segmentation head
        self.seg_head = nn.Sequential(
            nn.Linear(128, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Linear(128, num_classes)
        )

    def forward(self, pos):
        # encoding path
        x1, p1 = self.down1(pos, pos)
        x1 = self.enc1(x1, p1)
        # decoding path
        x = self.up1(x1, None, p1, pos)
        x = self.dec1(x, pos)
        return self.seg_head(x)
```

### 6.2 Extension to Point Cloud Completion

```python
class PointCompletionNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = PointTransformerEncoder()
        self.decoder = PointTransformerDecoder()
        self.folding_net = FoldingNet()

    def forward(self, partial_pc):
        latent = self.encoder(partial_pc)
        coarse = self.decoder(latent)
        fine = self.folding_net(coarse)
        return torch.cat([coarse, fine], dim=1)
```

In real projects Point Transformer delivers excellent performance. On the S3DIS dataset it improves clearly over earlier methods:

| Method | mIoU (%) | Params (M) |
| --- | --- | --- |
| PointNet | 54.5 | 1.4 |
| KPConv | 67.1 | 14.9 |
| Point Transformer | 70.4 | 4.9 |

A few key findings from the implementation:

- A k value between 16 and 32 for local attention works best.
- Dual position encoding (in both the attention branch and the feature branch) outperforms a single encoding.
- Vector attention yields roughly a 5% mIoU improvement over scalar attention.
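The `trilinear_interpolation` helper used by `TransitionUp` in the decoding path is never defined in the post. In the PointNet++/Point Transformer lineage this step is typically inverse-distance-weighted interpolation over the 3 nearest coarse points; a single-batch NumPy sketch under that assumption (function name, k = 3, and the epsilon are my choices):

```python
import numpy as np

def interpolate_features(feats, src_pos, dst_pos, k=3, eps=1e-8):
    """Inverse-distance-weighted interpolation: estimate [M, C] source
    features at [N, 3] target positions from each target's k nearest sources."""
    # pairwise squared distances from every target point to every source point
    d = ((dst_pos[:, None, :] - src_pos[None, :, :]) ** 2).sum(-1)
    idx = np.argsort(d, axis=1)[:, :k]                 # k nearest coarse points
    w = 1.0 / (np.take_along_axis(d, idx, axis=1) + eps)  # inverse-distance weights
    w = w / w.sum(axis=1, keepdims=True)               # normalize per target point
    return (feats[idx] * w[:, :, None]).sum(axis=1)

src = np.array([[0.0, 0, 0], [1.0, 0, 0], [2.0, 0, 0]])
f = np.array([[0.0], [10.0], [20.0]])
out = interpolate_features(f, src, src)  # querying at the sources recovers the features
```

A useful sanity check, encoded in the last line: when a target coincides with a source point, its inverse-distance weight dominates and the interpolated feature collapses to that source's feature.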
