SAM 2图像分割实战:从环境搭建到跑通第一个AI示例(含改进版代码)

张开发
2026/5/3 23:29:59 15 分钟阅读
SAM 2图像分割实战:从环境搭建到跑通第一个AI示例(含改进版代码)
SAM 2图像分割实战从环境搭建到跑通第一个AI示例含改进版代码在计算机视觉领域图像分割一直是核心技术之一。最近Meta推出的Segment Anything Model 2SAM 2凭借其出色的泛化能力和易用性迅速成为开发者关注的焦点。本文将带你从零开始快速搭建SAM 2开发环境并通过改进后的代码示例实现第一个有效的图像分割应用。1. 环境准备与安装在开始之前我们需要确保系统满足以下基本要求操作系统Windows 10/11或Linux推荐Ubuntu 20.04GPUNVIDIA显卡建议RTX 3060及以上显存≥8GBCUDA11.7或12.1需与PyTorch版本匹配Python3.10或更高版本推荐使用conda创建独立环境避免与其他项目产生依赖冲突conda create -n sam2 python3.10 -y conda activate sam2接下来安装PyTorch注意选择与CUDA版本匹配的安装命令# CUDA 11.7 pip install torch2.5.1 torchvision0.15.2 torchaudio2.5.1 --index-url https://download.pytorch.org/whl/cu117 # CUDA 12.1 pip install torch2.5.1 torchvision0.15.2 torchaudio2.5.1 --index-url https://download.pytorch.org/whl/cu121安装SAM 2及其依赖git clone https://github.com/facebookresearch/sam2.git cd sam2 pip install -e .注意如果遇到CUDA扩展编译失败警告可以暂时忽略。大多数基础功能仍可正常使用。2. 模型下载与初始化SAM 2提供了多种预训练模型根据硬件条件选择合适的版本模型名称参数量显存需求适用场景sam2.1_hiera_tiny50M4GB快速验证/移动端sam2.1_hiera_base150M6GB平衡性能与速度sam2.1_hiera_large500M10GB高精度需求下载大型模型推荐wget https://dl.fbaipublicfiles.com/sam2/sam2.1_hiera_large.pt -P ./checkpoints初始化预测器的Python代码import torch from sam2.build_sam import build_sam2 from sam2.sam2_image_predictor import SAM2ImagePredictor checkpoint ./checkpoints/sam2.1_hiera_large.pt model_cfg configs/sam2.1/sam2.1_hiera_l.yaml # 初始化模型 predictor SAM2ImagePredictor( build_sam2(model_cfg, checkpoint).to(cuda) )3. 改进版图像分割实战官方示例代码往往过于简化实际使用时需要调整。以下是经过优化的完整流程from PIL import Image import numpy as np import matplotlib.pyplot as plt def show_mask(mask, ax, random_colorFalse): if random_color: color np.concatenate([np.random.random(3), np.array([0.6])], axis0) else: color np.array([30/255, 144/255, 255/255, 0.6]) h, w mask.shape[-2:] mask_image mask.reshape(h, w, 1) * color.reshape(1, 1, -1) ax.imshow(mask_image) # 加载测试图像 image_path test_image.jpg image Image.open(image_path).convert(RGB) # 设置交互点格式[x,y] input_points np.array([[500, 375]]) # 图像中心区域 input_labels np.array([1]) # 1表示前景点 with torch.inference_mode(), torch.autocast(cuda, dtypetorch.bfloat16): predictor.set_image(np.array(image)) masks, scores, _ predictor.predict( point_coordsinput_points, point_labelsinput_labels, multimask_outputTrue # 输出多个候选mask ) # 可视化结果 plt.figure(figsize(15, 10)) plt.imshow(image) for i, (mask, score) in enumerate(zip(masks, scores)): show_mask(mask, plt.gca(), random_colorTrue) plt.title(fMask {i1}, Score: {score:.3f}, fontsize18) plt.axis(off) plt.show()关键改进点多mask输出设置multimask_outputTrue获取多个分割方案可视化增强为不同mask添加随机颜色和置信度评分类型转换确保图像为RGB格式避免通道问题4. 高级技巧与问题排查4.1 交互式分割优化SAM 2支持多种提示方式组合使用# 组合使用点和框提示 input_box np.array([100, 100, 400, 400]) # [x1,y1,x2,y2] masks, _, _ predictor.predict( point_coordsinput_points, point_labelsinput_labels, boxinput_box, multimask_outputFalse )4.2 常见错误处理错误现象可能原因解决方案CUDA out of memory显存不足换用更小模型或减小输入图像尺寸分割结果不准确提示点位置不当尝试在目标物体不同位置添加多个点预测速度慢未启用混合精度确保使用torch.autocast(cuda)4.3 性能优化技巧# 图像预处理优化 def resize_long_edge(image, max_size1024): width, height image.size if max(width, height) max_size: scale max_size / max(width, height) new_size (int(width*scale), int(height*scale)) return image.resize(new_size, Image.BILINEAR) return image optimized_image resize_long_edge(image) predictor.set_image(np.array(optimized_image))5. 自定义数据集测试要测试自己的图片只需修改图像路径并调整提示点custom_image Image.open(your_image.jpg).convert(RGB) # 通过matplotlib交互获取坐标 plt.imshow(custom_image) points plt.ginput(n-1, timeout-1) # 点击获取点坐标 plt.close() input_points np.array(points) input_labels np.array([1]*len(points)) # 全部设为前景点 with torch.inference_mode(): predictor.set_image(np.array(custom_image)) masks, _, _ predictor.predict( point_coordsinput_points, point_labelsinput_labels, multimask_outputTrue )实际项目中我发现对复杂场景添加3-5个分布均匀的点通常能得到最佳分割效果。对于细长物体如电线沿物体走向均匀布点比集中布点效果更好。

更多文章