-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathapp.js
66 lines (55 loc) · 1.67 KB
/
app.js
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
// Define the HTTP server
import cors from "cors";
import express from "express";
import {
MyClassificationPipeline,
SegmentAnythingSingleton,
GroundingDinoSingleton,
} from "./pipeline.js";
// Network binding for the HTTP server: every interface, fixed port.
const hostname = "0.0.0.0";
const port = 3001;

// Express application that all routes below are registered on.
const app = express();

// Middleware runs in registration order for every route:
//   1. cors() with its default options (CORS enabled for all routes),
//   2. express.json() to parse JSON request bodies.
app.use(cors());
app.use(express.json());
// Liveness probe: always answers with the fixed JSON payload {"status":"ok"}.
const handlePing = (_req, res) => {
  res.json({ status: "ok" });
};
app.get("/ping", handlePing);
/**
 * True when `value` is usable as model text input: a non-empty string, or a
 * non-empty array of strings (what a repeated `?text=a&text=b` produces).
 * Rejects missing values and the nested objects that bracket query syntax
 * (`?text[a]=b`) can produce.
 *
 * @param {unknown} value - raw `req.query.text` value.
 * @returns {boolean}
 */
function isTextInput(value) {
  if (typeof value === "string") return value !== "";
  return (
    Array.isArray(value) &&
    value.length > 0 &&
    value.every((item) => typeof item === "string")
  );
}

/**
 * Normalize the `text` query value into the label list passed to Grounding DINO.
 * A single string gets a trailing "." (if missing) and is wrapped in an array;
 * an array of strings is passed through unchanged.
 *
 * @param {unknown} text - raw `req.query.text` value.
 * @returns {string[] | null} the labels, or null when `text` is not valid input.
 */
function toDetectionLabels(text) {
  if (!isTextInput(text)) return null;
  if (Array.isArray(text)) return text;
  return [text.endsWith(".") ? text : `${text}.`];
}

/**
 * GET / — run one of the models exported by ./pipeline.js.
 *
 * Query parameters:
 *   model_name  required; "classifier", "object-detection" or "sam".
 *   text        input for "classifier"; label prompt(s) for "object-detection".
 *   image_uri   image passed to "object-detection".
 *
 * Responses:
 *   200  JSON output of the selected pipeline.
 *   400  missing/invalid parameters, or a model with no handler here ("sam").
 *   500  loading or running the pipeline threw.
 */
app.get("/", async (req, res) => {
  console.log(`[${new Date().toISOString()}] ${req.method} ${req.url}`);

  const { model_name, text, image_uri } = req.query;
  if (!model_name) {
    return res.status(400).json({
      error:
        "model_name is required and must be one of: sam, classifier, object-detection",
    });
  }

  // Errors thrown by an async handler are not routed to Express 4's error
  // middleware; without this try/catch they become unhandled rejections.
  try {
    if (model_name === "object-detection") {
      const labels = toDetectionLabels(text);
      if (!labels || typeof image_uri !== "string" || image_uri === "") {
        return res.status(400).json({
          error: "object-detection requires text and image_uri",
        });
      }
      const grounding_dino = await GroundingDinoSingleton.getInstance();
      const features = await grounding_dino(image_uri, labels, {
        threshold: 0.3,
      });
      return res.json(features);
    }

    console.log(text, model_name, image_uri);

    // The handler is mounted on "/", so dispatch on model_name; the previous
    // `req.path === "/classify"` check could never be true here.
    if (model_name === "classifier" && isTextInput(text)) {
      const classifier = await MyClassificationPipeline.getInstance();
      const response = await classifier(text);
      return res.json(response);
    }

    if (model_name === "sam") {
      // TODO: segmentation is not wired up. The singleton is still loaded, as
      // before, and the request then falls through to the 400 below.
      await SegmentAnythingSingleton.getInstance();
    }

    return res.status(400).json({ error: "Bad request" });
  } catch (err) {
    console.error(
      `[${new Date().toISOString()}] ${model_name} request failed:`,
      err,
    );
    if (!res.headersSent) {
      res.status(500).json({ error: "Internal server error" });
    }
  }
});
// Start accepting connections and report the bound address once listening.
app.listen(port, hostname, function onListening() {
  const url = `http://${hostname}:${port}/`;
  console.log(`Server running at ${url}`);
});