Commit: support edge sam

wangzhaode committed Dec 13, 2023
1 parent b75c38f commit 1d7e44a
Showing 5 changed files with 76 additions and 25 deletions.
cpp/CMakeLists.txt: 2 changes (1 addition, 1 deletion)
@@ -19,5 +19,5 @@ add_executable(sam_demo ${SRCS})
 if (MSVC)
     target_link_libraries(sam_demo MNN)
 else()
-    target_link_libraries(sam_demo MNN MNN_Express MNNOpenCV)
+    target_link_libraries(sam_demo MNN MNN_Express MNNOpenCV log)
 endif()
cpp/README.md: 4 changes (4 additions, 0 deletions)
@@ -39,6 +39,8 @@ mkdir build && cd build
 cmake ..
 make -j4
 ./sam_demo embed.mnn segment.mnn ../../resource/truck.jpg
+# the edge model needs an extra `1` argument
+./sam_demo edge_embed.mnn edge_segment.mnn ../../resource/truck.jpg 1
 ```
 #### Windows
 ```bash
@@ -48,4 +50,6 @@ mkdir build && cd build
 cmake -G "Ninja" ..
 ninja
 ./sam_demo embed.mnn segment.mnn ../../resource/truck.jpg
+# the edge model needs an extra `1` argument
+./sam_demo edge_embed.mnn edge_segment.mnn ../../resource/truck.jpg 1
 ```
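For scripted runs, the same invocation can also be driven from Python; the optional positional arguments follow the updated usage string in sam_demo.cpp below (`[is_edge] [forwardType] [precision] [thread]`). A minimal sketch, assuming the binary and model files sit in the build directory:

```python
# Sketch only (not part of this commit): invoke the C++ demo from Python.
# Paths are assumptions; the positional order follows the updated usage string:
#   ./sam_demo embed.mnn sam.mnn input.jpg [is_edge] [forwardType] [precision] [thread]
import subprocess

subprocess.run(
    ["./sam_demo", "edge_embed.mnn", "edge_segment.mnn",
     "../../resource/truck.jpg",
     "1",    # is_edge: 1 selects the EdgeSAM decoder's input/output names
     "0",    # forwardType: 0 is MNN_FORWARD_CPU
     "0",    # precision (same default as the demo)
     "4"],   # thread count (same default as the demo)
    check=True,
)
```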
cpp/sam_demo.cpp: 53 changes (40 additions, 13 deletions)
Expand Up @@ -14,20 +14,24 @@ using namespace MNN::CV;

int main(int argc, const char* argv[]) {
if (argc < 4) {
MNN_PRINT("Usage: ./sam_demo.out embed.mnn sam.mnn input.jpg [forwardType] [precision] [thread]\n");
MNN_PRINT("Usage: ./sam_demo.out embed.mnn sam.mnn input.jpg [is_edge] [forwardType] [precision] [thread]\n");
return 0;
}
bool is_edge = false;
int thread = 4;
int precision = 0;
int forwardType = MNN_FORWARD_CPU;
if (argc >= 5) {
forwardType = atoi(argv[4]);
is_edge = atoi(argv[4]);
}
if (argc >= 6) {
precision = atoi(argv[5]);
forwardType = atoi(argv[5]);
}
if (argc >= 7) {
thread = atoi(argv[6]);
precision = atoi(argv[6]);
}
if (argc >= 8) {
thread = atoi(argv[7]);
}
float mask_threshold = 0;
MNN::ScheduleConfig sConfig;
@@ -43,9 +47,13 @@ int main(int argc, const char* argv[]) {
     }
     // rtmgr->setCache(".cachefile");
     std::shared_ptr<Module> embed(Module::load(std::vector<std::string>{}, std::vector<std::string>{}, argv[1], rtmgr));
-    std::shared_ptr<Module> sam(Module::load(
-        {"point_coords", "point_labels", "image_embeddings", "has_mask_input", "mask_input", "orig_im_size"},
-        {"iou_predictions", "low_res_masks", "masks"}, argv[2], rtmgr));
+    std::vector<std::string> sam_inputs = {"point_coords", "point_labels", "image_embeddings", "has_mask_input", "mask_input", "orig_im_size"};
+    std::vector<std::string> sam_outputs = {"iou_predictions", "low_res_masks", "masks"};
+    if (is_edge) {
+        sam_inputs = {"point_coords", "point_labels", "image_embeddings"};
+        sam_outputs = {"masks", "scores"};
+    }
+    std::shared_ptr<Module> sam(Module::load(sam_inputs, sam_outputs, argv[2], rtmgr));
     auto image = imread(argv[3]);
     // 1. preprocess
     auto dims = image->getInfo()->dim;
@@ -92,18 +100,37 @@ int main(int argc, const char* argv[]) {
     scale_points.push_back(0);
     auto point_coords = build_input(scale_points, {1, 2, 2});
     auto point_labels = build_input({1, -1}, {1, 2});
-    auto orig_im_size = build_input({static_cast<float>(origin_h), static_cast<float>(origin_w)}, {2});
-    auto has_mask_input = build_input({0}, {1});
-    std::vector<float> zeros(256*256, 0.f);
-    auto mask_input = build_input(zeros, {1, 1, 256, 256});
+    std::vector<VARP> input_vars;
+    if (is_edge) {
+        input_vars = {point_coords, point_labels, image_embedding};
+    } else {
+        auto orig_im_size = build_input({static_cast<float>(origin_h), static_cast<float>(origin_w)}, {2});
+        auto has_mask_input = build_input({0}, {1});
+        std::vector<float> zeros(256*256, 0.f);
+        auto mask_input = build_input(zeros, {1, 1, 256, 256});
+        input_vars = {point_coords, point_labels, image_embedding, has_mask_input, mask_input, orig_im_size};
+    }
     st = std::chrono::system_clock::now();
-    auto output_vars = sam->onForward({point_coords, point_labels, image_embedding, has_mask_input, mask_input, orig_im_size});
+    auto output_vars = sam->onForward(input_vars);
     et = std::chrono::system_clock::now();
     duration = std::chrono::duration_cast<std::chrono::microseconds>(et - st);
     printf("# 2. segment times: %f ms\n", duration.count() * 1e-3);
-    auto masks = _Convert(output_vars[2], NCHW);
     // 4. postprocess: draw mask and point
     // MobileSam has multi channel masks, get first
+    VARP masks;
+    if (is_edge) {
+        masks = output_vars[0];
+        auto dims = masks->getInfo()->dim;
+        int h = dims[2], w = dims[3];
+        masks = _Convert(masks, NC4HW4);
+        masks = _Resize(masks, length/w, length/h);
+        int sliceStartData[] = {0, 0, 0, 0}, sliceEndData[] = {-1, -1, new_h, new_w};
+        masks = _Slice(masks, _Const(sliceStartData, {4}, NCHW), _Const(sliceEndData, {4}, NCHW));
+        masks = _Resize(masks, (float)origin_w/new_w, (float)origin_h/new_h);
+    } else {
+        masks = output_vars[2];
+    }
+    masks = _Convert(masks, NCHW);
     masks = _Gather(_Squeeze(masks, {0}), _Scalar<int>(0));
     masks = _Greater(masks, _Scalar(mask_threshold));
     masks = _Reshape(masks, {origin_h, origin_w, 1});
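The edge branch above recovers a full-resolution mask by upscaling the decoder output to the encoder's input size, cropping away the square padding, and resizing to the original image; the Python demo below applies the same geometry. A rough sketch of that geometry in plain NumPy/OpenCV (not the MNN expression API), assuming `length` is the encoder's square input side (1024 in SAM-style pipelines) and that the preprocessing scales the longer image side to `length`:

```python
# Sketch only: the edge branch's mask-rescaling geometry, written with
# opencv-python and NumPy instead of MNN. `length`, the helper name, and the
# sizes in the example are assumptions for illustration.
import numpy as np
import cv2

def upscale_edge_mask(low_res_mask, origin_h, origin_w, length=1024, threshold=0.0):
    scale = length / max(origin_h, origin_w)           # same scale as the image preprocessing
    new_h, new_w = int(origin_h * scale), int(origin_w * scale)
    # 1. bring the decoder mask up to the encoder input resolution
    mask = cv2.resize(low_res_mask, (length, length), interpolation=cv2.INTER_LINEAR)
    # 2. drop the padding that made the encoder input square
    mask = mask[:new_h, :new_w]
    # 3. resize back to the original image size and binarize
    mask = cv2.resize(mask, (origin_w, origin_h), interpolation=cv2.INTER_LINEAR)
    return mask > threshold

# e.g. a 1200x1800 image and a 256x256 decoder mask -> (1200, 1800) boolean mask
print(upscale_edge_mask(np.zeros((256, 256), np.float32), 1200, 1800).shape)
```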
python/README.md: 2 changes (2 additions, 0 deletions)
@@ -8,4 +8,6 @@ pip install MNN
 ## Run Demo
 ```
 python segment_anything_example.py --embed embed.mnn --sam segment.mnn --img ../resource/truck.jpg
+# the edge model needs the extra `--edge` flag
+python segment_anything_example.py --embed edge_embed.mnn --sam edge_segment.mnn --img ../resource/truck.jpg --edge
 ```
python/segment_anything_example.py: 40 changes (29 additions, 11 deletions)
@@ -5,7 +5,7 @@
 import MNN.numpy as np
 import MNN.cv as cv2
 
-def inference(emed, sam, img, precision, backend, thread):
+def inference(emed, sam, img, precision, backend, thread, is_edge):
     mask_threshold = 0.0
     # 0. load model
     config = {}
@@ -14,9 +14,13 @@ def inference(emed, sam, img, precision, backend, thread):
     config['numThread'] = thread
     rt = MNN.nn.create_runtime_manager((config,))
     embed = MNN.nn.load_module_from_file(emed, [], [], runtime_manager=rt)
-    sam = MNN.nn.load_module_from_file(sam,
-        ['point_coords', 'point_labels', 'image_embeddings', 'has_mask_input', 'mask_input', 'orig_im_size'],
-        ['iou_predictions', 'low_res_masks', 'masks'], runtime_manager=rt)
+    if is_edge:
+        sam_inputs = ['point_coords', 'point_labels', 'image_embeddings']
+        sam_outputs = ['masks', 'scores']
+    else:
+        sam_inputs = ['point_coords', 'point_labels', 'image_embeddings', 'has_mask_input', 'mask_input', 'orig_im_size']
+        sam_outputs = ['iou_predictions', 'low_res_masks', 'masks']
+    sam = MNN.nn.load_module_from_file(sam, sam_inputs, sam_outputs, runtime_manager=rt)
     # 1. preprocess
     image = cv2.imread(img)
     origin_h, origin_w, _ = image.shape
@@ -46,16 +50,29 @@ def inference(emed, sam, img, precision, backend, thread):
     input_label = np.array([1])
     point_coords = np.concatenate([input_point, np.array([[0.0, 0.0]])], axis=0)[None, :, :]
     point_labels = np.concatenate([input_label, np.array([-1])], axis=0)[None, :].astype(np.float32)
-    orig_im_size = np.array([float(origin_h), float(origin_w)], dtype=np.float32)
-    mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)
-    has_mask_input = np.zeros(1, dtype=np.float32)
+    if is_edge:
+        input_vars = [point_coords, point_labels, image_embedding]
+    else:
+        orig_im_size = np.array([float(origin_h), float(origin_w)], dtype=np.float32)
+        mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)
+        has_mask_input = np.zeros(1, dtype=np.float32)
+        input_vars = [point_coords, point_labels, image_embedding, has_mask_input, mask_input, orig_im_size]
     t1 = time.time()
-    output_vars = sam.onForward([point_coords, point_labels, image_embedding, has_mask_input, mask_input, orig_im_size])
+    output_vars = sam.onForward(input_vars)
     t2 = time.time()
     print('# 1. segment times: {} ms'.format((t2 - t1) * 1000))
-    masks = MNN.expr.convert(output_vars[2], MNN.expr.NCHW)
-    masks = masks.squeeze([0])[0]
     # 4. postprocess: draw masks and point
+    if is_edge:
+        low_res_masks = output_vars[0]
+        low_res_masks = MNN.expr.convert(low_res_masks, MNN.expr.NC4HW4)
+        print(low_res_masks.data_format)
+        _, _, h, w, = low_res_masks.shape
+        masks = MNN.expr.resize(low_res_masks, length / w, length / h)
+        masks = masks[:, :, :new_h, :new_w]
+        masks = MNN.expr.resize(masks, origin_w / new_w, origin_h / new_h)
+    else:
+        masks = output_vars[2]
+        masks = MNN.expr.convert(masks, MNN.expr.NCHW).squeeze([0])[0]
     masks = (masks > mask_threshold).reshape([origin_h, origin_w, 1])
     color = np.array([30, 144, 255]).reshape([1, 1, -1])
     image = (image + masks * color).astype(np.uint8)
@@ -71,5 +88,6 @@ def inference(emed, sam, img, precision, backend, thread):
 parser.add_argument('--precision', type=str, default='normal', help='inference precision: normal, low, high, lowBF')
 parser.add_argument('--backend', type=str, default='CPU', help='inference backend: CPU, OPENCL, OPENGL, NN, VULKAN, METAL, TRT, CUDA, HIAI')
 parser.add_argument('--thread', type=int, default=4, help='inference using thread: int')
+parser.add_argument('--edge', action='store_true', help='using edge sam model.')
 args = parser.parse_args()
-inference(args.embed, args.sam, args.img, args.precision, args.backend, args.thread)
+inference(args.embed, args.sam, args.img, args.precision, args.backend, args.thread, args.edge)
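The updated helper can also be called directly from Python rather than through the command line. A minimal sketch, assuming the script's argument parsing runs under an `if __name__ == '__main__':` guard (so the import has no side effects) and that the converted .mnn files are in the working directory:

```python
# Sketch only: call the new inference() signature directly.
# File paths and the import-safety of the script are assumptions.
from segment_anything_example import inference

inference(
    emed='edge_embed.mnn',        # image-encoder model
    sam='edge_segment.mnn',       # mask-decoder model
    img='../resource/truck.jpg',
    precision='normal',           # argparse defaults from the script
    backend='CPU',
    thread=4,
    is_edge=True,                 # the flag added in this commit
)
```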
