A few days ago, an interviewer asked me to run the Segment Anything Simple Web demo as a test. Until I ran it, I didn't realize how poor the documentation provided by Facebook was. I am writing this article to document my experience.
First, the project requires python>=3.8, pytorch>=1.7, and torchvision>=0.8. If you install PyTorch with the command from the official PyTorch website, torchvision is installed along with it.
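Next, clone the repository to your local machine and install the package (the demo and the export script live in the repository, so installing with pip alone is not enough):
git clone https://github.com/facebookresearch/segment-anything.git
cd segment-anything
pip install -e .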
Install yarn, which the demo in segment-anything\demo uses. This is a global install, so it works from any directory:
npm install -g yarn
Install other dependencies:
pip install opencv-python pycocotools matplotlib onnxruntime onnx
If you run into errors later on, such as missing libraries (I remember one called "sam"), I suggest asking ChatGPT which packages you need to install.
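Next, download a model checkpoint from the Model Checkpoints section of the README. The default model type is vit_h, whose checkpoint file is sam_vit_h_4b8939.pth.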
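If you prefer to script the download, here is a minimal sketch using only the Python standard library. The checkpoint filenames and the base URL are what the README's Model Checkpoints section linked to when I looked; treat them as assumptions and check the README before relying on them. The vit_h file is large (over 2 GB), so the helper skips the download when the file already exists.

```python
import urllib.request
from pathlib import Path

# ASSUMPTION: filenames and base URL as listed in the README's
# "Model Checkpoints" section; verify against the README before use.
CHECKPOINTS = {
    "vit_h": "sam_vit_h_4b8939.pth",
    "vit_l": "sam_vit_l_0b3195.pth",
    "vit_b": "sam_vit_b_01ec64.pth",
}
BASE_URL = "https://dl.fbaipublicfiles.com/segment_anything/"


def checkpoint_url(model_type: str) -> str:
    """Return the download URL for a model type such as "vit_h"."""
    return BASE_URL + CHECKPOINTS[model_type]


def download_checkpoint(model_type: str = "vit_h", dest_dir: str = ".") -> Path:
    """Download the checkpoint into dest_dir unless it is already there."""
    dest = Path(dest_dir) / CHECKPOINTS[model_type]
    if not dest.exists():
        dest.parent.mkdir(parents=True, exist_ok=True)
        urllib.request.urlretrieve(checkpoint_url(model_type), dest)
    return dest
```

Calling download_checkpoint("vit_h") in the directory where you plan to run Python gives you the sam_vit_h_4b8939.pth file that the model-loading code below refers to.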
The Facebook documentation then goes straight to exporting the image embedding, but that step has problems (described below). I recommend exporting the ONNX model first. SAM's lightweight mask decoder can be exported to ONNX format, so it can run in any environment that supports ONNX Runtime, including the browser. Replace the --checkpoint, --model-type, and --output values with your own checkpoint path, model type, and output path:
python scripts/export_onnx_model.py --checkpoint <path/to/checkpoint> --model-type <model_type> --output <path/to/output>
The script writes the ONNX model to the path you pass to --output. App.tsx expects a file named "sam_onnx_quantized_example.onnx", so use that name (or update MODEL_DIR in App.tsx accordingly) and place the file in the segment-anything\demo\model directory.
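To make it clearer what the exported decoder actually consumes, here is a sketch that packs an image embedding and a few clicks into its input dictionary with NumPy. The input names (image_embeddings, point_coords, point_labels, mask_input, has_mask_input, orig_im_size) follow the repository's ONNX example notebook; the helper build_decoder_inputs is hypothetical. Clicks must be rescaled from original-image pixels into the 1024-pixel space the model works in, and with no box prompt the example appends a padding point with label -1.

```python
import numpy as np


def preprocess_shape(h, w, long_side=1024):
    """Resized (h, w) with the longest side equal to long_side, rounded to the
    nearest integer (the same arithmetic as ResizeLongestSide)."""
    scale = long_side * 1.0 / max(h, w)
    return int(h * scale + 0.5), int(w * scale + 0.5)


def build_decoder_inputs(image_embedding, points, labels, orig_hw, long_side=1024):
    """Hypothetical helper: build the ONNX decoder's input dict.

    points: (N, 2) array-like of (x, y) pixels in the original image.
    labels: (N,) array-like, 1 = foreground click, 0 = background click.
    orig_hw: (height, width) of the original image.
    """
    h, w = orig_hw
    new_h, new_w = preprocess_shape(h, w, long_side)
    coords = np.asarray(points, dtype=np.float64).copy()
    coords[:, 0] *= new_w / w  # x scales with width
    coords[:, 1] *= new_h / h  # y scales with height
    # No box prompt: append a padding point (0, 0) with label -1, add batch dim.
    coords = np.concatenate([coords, [[0.0, 0.0]]], axis=0)[None].astype(np.float32)
    lbls = np.concatenate([np.asarray(labels, dtype=np.float32), [-1.0]])
    return {
        "image_embeddings": np.asarray(image_embedding, dtype=np.float32),
        "point_coords": coords,
        "point_labels": lbls[None].astype(np.float32),
        "mask_input": np.zeros((1, 1, 256, 256), dtype=np.float32),  # no prior mask
        "has_mask_input": np.zeros(1, dtype=np.float32),
        "orig_im_size": np.array([h, w], dtype=np.float32),
    }
```

With onnxruntime installed, you could feed this dict to onnxruntime.InferenceSession("sam_onnx_quantized_example.onnx").run(None, inputs); in the example notebook the outputs are the masks, their IoU predictions, and low-resolution mask logits.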
Next is generating the image embedding. The Facebook documentation skips the import step, but I include it here. I recommend opening a Python session in the directory that contains your model checkpoint file and importing the necessary modules:
import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2
The documentation's version of the following code has an indentation error. It displays masks, points, and boxes on a matplotlib image, which seems unnecessary if you are only running the web demo:
def show_mask(mask, ax, random_color=False):
    # Overlay a binary mask in a semi-transparent color (alpha 0.6).
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30/255, 144/255, 255/255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)

def show_points(coords, labels, ax, marker_size=375):
    # Foreground clicks (label 1) in green, background clicks (label 0) in red.
    pos_points = coords[labels==1]
    neg_points = coords[labels==0]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)

def show_box(box, ax):
    # Draw a box given as (x0, y0, x1, y1).
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2))
Load the SAM model and create the predictor:
import sys
sys.path.append("..")
from segment_anything import sam_model_registry, SamPredictor
sam_checkpoint = "sam_vit_h_4b8939.pth"
model_type = "vit_h"
device = "cuda" if torch.cuda.is_available() else "cpu"  # fall back to CPU without a GPU
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)
predictor = SamPredictor(sam)
The documentation's code for exporting the image embedding contains two errors:
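image = cv2.imread('src/assets/dogs.jpg')
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR; set_image expects RGB by default
predictor.set_image(image)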
image_embedding = predictor.get_image_embedding().cpu().numpy()
np.save("dogs_embedding.npy", image_embedding)
The first line of code above reads dogs.jpg from src/assets/, but the image is not in src/assets/; it is actually in src/assets/data/. Copy the image to the src/assets/ directory (or change the path in the code).
The second issue is that the Facebook documentation says to "Save the new image and embedding in src/assets/data", which implies the code saves both there. In fact, it does not save the image at all, and np.save writes the embedding .npy file to the directory where you run the Python script.
This matters because SAM works on images whose longest side is 1024 pixels: during preprocessing, the image is resized so that its longest side is 1024 pixels and then padded to a square. In Python, the resizing is done by the ResizeLongestSide transform that SamPredictor creates in its __init__. In the browser version, it is handled by helpers/scaleHelper.tsx, which rounds the values. The original image and the .npy file should be placed in src/assets/data, but if you open the image in the demo afterwards, you will notice that the mask is slightly offset.
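To see how a small offset can arise, here is a sketch comparing the per-axis scaling of ResizeLongestSide (each side rounded to an integer, so x and y get slightly different factors) with a single unrounded scale factor for both axes. Whether the browser helper does exactly the latter is my assumption, not something I confirmed; check helpers/scaleHelper.tsx in your checkout. The function names are hypothetical.

```python
def preprocess_shape(h, w, long_side=1024):
    """Mirror ResizeLongestSide.get_preprocess_shape: scale so the longest side
    equals long_side, rounding each side to the nearest integer."""
    scale = long_side * 1.0 / max(h, w)
    return int(h * scale + 0.5), int(w * scale + 0.5)


def per_axis_scale(x, y, h, w, long_side=1024):
    """Map a click the way ResizeLongestSide.apply_coords does: each axis is
    scaled by (rounded new size) / (original size)."""
    new_h, new_w = preprocess_shape(h, w, long_side)
    return x * new_w / w, y * new_h / h


def uniform_scale(x, y, h, w, long_side=1024):
    """ASSUMPTION: map a click with one unrounded factor for both axes, as a
    browser helper would if it only computes long_side / max(h, w)."""
    s = long_side / max(h, w)
    return x * s, y * s


def click_offset(x, y, h, w, long_side=1024):
    """(dx, dy) between the two mappings, in the model's 1024-pixel space."""
    px, py = per_axis_scale(x, y, h, w, long_side)
    ux, uy = uniform_scale(x, y, h, w, long_side)
    return px - ux, py - uy
```

For an 1800x1200 image, a click at (900, 600) maps to y = 341.5 with per-axis scaling but about 341.33 with the uniform factor, a difference of 1/6 of a model pixel. For an 800x600 image, which scales to exactly 1024x768, the two mappings agree.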
One fix is to add code to segment_anything/utils/transforms.py that saves the resized image. The simplest fix is to use a low-resolution image, which I have tested and found to work perfectly. Either way, the .npy file is generated in the directory where you run the Python script, so move the image and the .npy file into the src/assets/data directory.
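To avoid moving files by hand, the embedding can be written straight into src/assets/data. Below is a minimal sketch; save_embedding is a hypothetical helper that only relies on the two SamPredictor methods used above (set_image and get_image_embedding). It optionally copies the original image next to the embedding, skipping the copy when the image is already there.

```python
import shutil
from pathlib import Path

import numpy as np


def save_embedding(predictor, image_rgb, out_dir, stem, image_path=None):
    """Hypothetical helper: compute the SAM image embedding and save it as
    <out_dir>/<stem>_embedding.npy instead of the current working directory."""
    out = Path(out_dir)
    out.mkdir(parents=True, exist_ok=True)
    predictor.set_image(image_rgb)  # runs the (heavy) image encoder
    embedding = predictor.get_image_embedding().cpu().numpy()
    npy_path = out / f"{stem}_embedding.npy"
    np.save(npy_path, embedding)
    if image_path is not None:
        # Copy the original image next to its embedding so App.tsx can load both.
        target = out / Path(image_path).name
        if Path(image_path).resolve() != target.resolve():
            shutil.copy2(image_path, target)
    return npy_path
```

Relative paths such as "src/assets/data" resolve against the directory you start Python from, so either run it from segment-anything/demo or pass an absolute path. For the dogs example, pass the RGB image loaded earlier, "src/assets/data" as out_dir, and "dogs" as stem to produce dogs_embedding.npy.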
Finally, configure the file paths in App.tsx:
const IMAGE_PATH = "/assets/data/dogs.jpg";
const IMAGE_EMBEDDING = "/assets/data/dogs_embedding.npy";
const MODEL_DIR = "/model/sam_onnx_quantized_example.onnx";
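Before starting the app, a quick check that the three files are where App.tsx will look for them can save a confusing blank page. This sketch assumes the layout described in this article: "/assets/data/..." is served from demo/src/assets/data and "/model/..." from demo/model. The helper name and the EXPECTED list are my own.

```python
from pathlib import Path

# ASSUMPTION: paths relative to segment-anything/demo, matching the
# IMAGE_PATH, IMAGE_EMBEDDING, and MODEL_DIR values set in App.tsx.
EXPECTED = [
    "src/assets/data/dogs.jpg",
    "src/assets/data/dogs_embedding.npy",
    "model/sam_onnx_quantized_example.onnx",
]


def missing_demo_files(demo_dir, expected=EXPECTED):
    """Return the expected files that do not exist under demo_dir."""
    root = Path(demo_dir)
    return [rel for rel in expected if not (root / rel).is_file()]
```

Running it with the path to segment-anything/demo should print an empty list when everything is in place.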
Start the application:
yarn && yarn start
Open the browser and go to http://localhost:8081/.
Congratulations, you're done!
In addition, while researching how to enable multi-threading in the demo, I also learned about cross-origin issues.
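Reasons for cross-origin issues (an origin is the combination of protocol, domain, and port, so a request is cross-origin if any of them differs):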
- Different domains: For example, making an Ajax request from a page at http://www.example.com to http://api.example.com is considered cross-origin because the domains are different.
- Different protocols: For example, making an Ajax request from a page at https://www.example.com to http://www.example.com is considered cross-origin because the protocols (HTTP and HTTPS) are different.
- Different ports: For example, making an Ajax request from a page at http://www.example.com:8080 to http://www.example.com:3000 is considered cross-origin because the ports are different.
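The three cases above boil down to comparing (protocol, host, port) tuples. Here is a small sketch with Python's urllib.parse that fills in the default ports 80 and 443 when a URL omits them; it is an illustration of the rule, not how a browser implements it.

```python
from urllib.parse import urlsplit

DEFAULT_PORTS = {"http": 80, "https": 443}


def origin(url):
    """Return the (scheme, host, port) tuple that defines a URL's origin."""
    parts = urlsplit(url)
    scheme = parts.scheme.lower()
    host = parts.hostname or ""  # urlsplit lowercases the hostname
    port = parts.port or DEFAULT_PORTS.get(scheme)
    return scheme, host, port


def same_origin(a, b):
    """True if two URLs share protocol, host, and port."""
    return origin(a) == origin(b)
```

For example, http://www.example.com/a and http://www.example.com:80/b are same-origin, because port 80 is the HTTP default, while every pair in the list above is not.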
Methods to handle cross-origin issues:
- JSONP (JSON with Padding): The page dynamically creates a script tag that requests the server with the GET method, and the server returns the data wrapped as the argument of a JavaScript callback function. It only supports GET requests.
- CORS (Cross-Origin Resource Sharing): Set the appropriate response headers on the server to allow cross-origin requests. This is done by setting Access-Control-Allow-Origin and other related headers to specify which domains are allowed.
- Proxy server: Run a proxy server on the same origin as the page; it forwards client requests to the target server and returns the responses to the client, so the browser only ever talks to its own origin.
- WebSocket: Use the WebSocket protocol for bidirectional communication, which is not subject to the same-origin policy.
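For the CORS approach, the server's job is deciding which headers to send back for a given Origin. Below is a minimal sketch of an allowlist check; cors_headers is a hypothetical helper, and in a real server (for example an http.server handler) you would emit each entry with send_header.

```python
def cors_headers(request_origin, allowed_origins, methods=("GET", "POST")):
    """Build CORS response headers for a request carrying an Origin header.

    Only origins on the allowlist are echoed back; any other origin gets no
    CORS headers, so the browser will not expose the response to the page.
    """
    if request_origin not in allowed_origins:
        return {}
    return {
        "Access-Control-Allow-Origin": request_origin,
        # Consulted by the browser on preflight (OPTIONS) responses.
        "Access-Control-Allow-Methods": ", ".join(methods),
        # The response varies by Origin, so caches must not mix them up.
        "Vary": "Origin",
    }
```

Echoing the specific origin instead of "*" keeps the allowlist explicit, which also matters if you later need to allow credentials.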
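In the demo, the page is put into a "cross-origin isolated" state so that SharedArrayBuffer, which multi-threaded WebAssembly in ONNX Runtime Web depends on, becomes available. Isolation is enabled by serving the page with two HTTP response headers: Cross-Origin-Opener-Policy: same-origin, which separates the page from cross-origin windows, and Cross-Origin-Embedder-Policy (require-corp or credentialless), which restricts how cross-origin resources may be loaded into the page.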
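If you want to serve a built copy of the demo outside the webpack dev server, the same two headers can be added with Python's standard library. This is a sketch for local testing only: make_server is a hypothetical helper, and I chose require-corp for COEP (credentialless is the alternative mentioned above). With require-corp, any cross-origin resource the page loads must opt in via CORS or a Cross-Origin-Resource-Policy header.

```python
import functools
from http.server import SimpleHTTPRequestHandler, ThreadingHTTPServer

# Headers that make the browser treat the page as cross-origin isolated.
ISOLATION_HEADERS = {
    "Cross-Origin-Opener-Policy": "same-origin",
    "Cross-Origin-Embedder-Policy": "require-corp",
}


class IsolatedHandler(SimpleHTTPRequestHandler):
    """Static file handler that adds the isolation headers to every response."""

    def end_headers(self):
        for name, value in ISOLATION_HEADERS.items():
            self.send_header(name, value)
        super().end_headers()


def make_server(directory, port=8081, host="127.0.0.1"):
    """Serve files from directory with cross-origin isolation enabled."""
    handler = functools.partial(IsolatedHandler, directory=directory)
    return ThreadingHTTPServer((host, port), handler)
```

Calling make_server("path/to/built/demo").serve_forever() serves that (hypothetical) directory; in the browser console, crossOriginIsolated should then evaluate to true.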