前两天参加面试,面试官喊我运行Segment Anything Simple Web demo作为笔试题,不跑不知道,原来 fb 文档写的那么差,写这篇文章记录一下这次经历
首先项目要求是
python>=3.8 pytorch>=1.7 torchvision>=0.8
torchvision 应该是安装 pytorch 时候默认给你一起安装的
接着克隆项目到本地
pip install git+https://github.com/facebookresearch/segment-anything.git
在 segment-anything\demo 目录下安装 yarn
npm install --g yarn
还有些其他的环境依赖
pip install opencv-python pycocotools matplotlib onnxruntime onnx
如果后续运行还有报错说缺少库(印象里好像还有一个 sam),建议直接让 chatgpt 告诉你应该装啥
接下来下载 Model Checkpoints
默认的是vit_h
然后 fb 的文档直接喊你 Export the image embedding 导出图像遮罩,实际上是有问题的
我建议先导出 ONNX model,因为 SAM 的轻量级掩码解码器可以导出为 ONNX 格式,它可以在任何支持 ONNX 运行时的环境中运行,后面浏览器运行就是用到它
把下面的 --checkpoint --model-type -output 替换成你下载的就行了
python scripts/export_onnx_model.py --checkpoint <path/to/checkpoint> --model-type <model_type> --output <path/to/output>
运行完以后会产生一个叫 sam_onnx_quantized_example.onnx 的模型文件,把这个文件放到 segment-anything\demo\model 目录下
下面就是正式的生成遮罩过程了,这里 fb 的文档把导入给省了,但是笔记里是有的
首先是必要的函数导入,建议是在你 Model Checkpoints 文件的目录下打开 python 运行
import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2
下面这段在文档里是有缩进错误的,用于在 matplotlib 图中展示掩码、点和框,如果你单纯跑 web 似乎不需要
def show_mask(mask, ax, random_color=False):
if random_color:
color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
else:
color = np.array([30/255, 144/255, 255/255, 0.6])
h, w = mask.shape[-2:]
mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
ax.imshow(mask_image)
def show_points(coords, labels, ax, marker_size=375):
pos_points = coords[labels==1]
neg_points = coords[labels==0]
ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
def show_box(box, ax):
x0, y0 = box[0], box[1]
w, h = box[2] - box[0], box[3] - box[1]
ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2))
加载 sam 模型和预处理器
import sys
sys.path.append("..")
from segment_anything import sam_model_registry, SamPredictor
sam_checkpoint = "sam_vit_h_4b8939.pth"
model_type = "vit_h"
device = "cuda"
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)
predictor = SamPredictor(sam)
后面就是这个文档中两个错误的地方
image = cv2.imread('src/assets/dogs.jpg')
predictor.set_image(image)
image_embedding = predictor.get_image_embedding().cpu().numpy()
np.save("dogs_embedding.npy", image_embedding)
上面第一行代码意思是从 src/assets/ 读取 dogs.jpg,但是实际上 src/assets/ 里面根本没有图片,图片是在 src/assets/data 里,应该把那张图片复制到 src/assets/ 目录下
第二个问题是 fb 文档说Save the new image and embedding in src/assets/data
,会把处理完的图片和遮罩保存到 src/assets/data 里,实际上第一它根本不会保留处理过的图片,第二它遮罩的那个.npy 文件实际上保存在你运行 python 的那个目录
这就导致一个 bug,sam 接受的最长边长是 1024 像素,进行预处理的时候对图片进行一定程度的缩放拉伸,具体处理的应该是 SamPredictor 里__init__的时候会调用的 ResizeLongestSide 函数,在浏览器运行的时候是 helpers/scaleHelper.tsx 处理,前者有四舍五入后者没有,并且执行完以后并没有保存图片,把原图和 npy 放到 src/assets/data 的话,最后打开图片会发现遮罩出现一定程度的偏移
一种解决方法是你在 segment_anything/utils/transforms.py 里加一段保存图片的代码
最简单解决方法你直接拿一张小分辨率图片,我测试过是完全没问题的,结束以后会在你运行 python 的目录生成.npy 文件,如果你懒得搞,直接剪切到 src/assets/data 目录下
最后配置 App.tsx 中的文件路径
const IMAGE_PATH = "/assets/data/dogs.jpg";
const IMAGE_EMBEDDING = "/assets/data/dogs_embedding.npy";
const MODEL_DIR = "/model/sam_onnx_quantized_example.onnx";
启动
yarn && yarn start
打开浏览器的http://localhost:8081/
大功告成
额外的学习内容,在研究这个 demo 如何在浏览器中启用多线程的时候我还学到了跨域
跨域原因
域名不同:比如从http://www.example.com 的页面向 http://api.example.com 发起 Ajax 请求,因为域名不同而被认为是跨域。
协议不同:比如从https://www.example.com 的页面向 http://www.example.com 发起 Ajax 请求,因为协议不同(HTTP 和 HTTPS)而被认为是跨域。
端口号不同:比如从http://www.example.com:8080 的页面向 http://www.example.com:3000 发起 Ajax 请求,因为端口号不同而被认为是跨域。
处理方法是
JSONP(JSON with Padding):通过动态创建 script 标签,将数据包装为 JavaScript 回调函数的参数,并以 GET 方式请求服务器,服务器返回一段可执行的 JavaScript 代码。
CORS(跨域资源共享):在服务端设置相应的响应头,允许跨域请求。通过在服务器端设置 Access-Control-Allow-Origin 等相关响应头,来告诉浏览器允许哪些域名的访问。
代理服务器:在同源策略下,通过在同一域名下建立一个代理服务器,将客户端请求发送到目标服务器,并将响应返回给客户端。
WebSocket:使用 WebSocket 协议进行双向通信,WebSocket 不受同源策略的限制。
demo 里具体是创建一个 “跨源隔离状态”(Cross-Origin Isolation State)。设置一组特定的 HTTP 响应标头来指示浏览器将网页视为独立的原始来源,以启用 SharedArrayBuffer 的使用,实现跨域