forked from CompVis/latent-diffusion
-
Notifications
You must be signed in to change notification settings - Fork 16
/
cogflare.py
130 lines (104 loc) · 3.92 KB
/
cogflare.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
from pydantic import ValidationError
from fastapi.responses import JSONResponse
import websocket
import logging
import _thread
import time
import json
import rel
import io, requests
from typing import Any, Callable, Dict, List, Optional, Tuple
from pathlib import Path
from cog.predictor import get_input_type, load_config, load_predictor
from cog.response import Status
from cog.server.runner import PredictionRunner
from cog.json import upload_files
logger = logging.getLogger("cog")
def run():
    """Serve predictions over a websocket queue until interrupted.

    Reads the module-level global ``predictor`` (set in ``__main__``) to
    derive the pydantic input model, then connects to the cog queue
    websocket and handles one prediction per incoming message, streaming
    logs/output back as JSON frames. Blocks forever in the ``rel``
    dispatch loop; SIGINT aborts it.
    """
    runner = PredictionRunner()
    InputType = get_input_type(predictor)
    runner.setup()

    upload_url = "https://cog.nmb.ai/v1/upload"

    def upload(obj: Any) -> Any:
        # Walk `obj` and replace any file handles with URLs of the
        # uploaded content, so the response is JSON-serializable.
        def upload_file(fh: io.IOBase) -> str:
            resp = requests.put(upload_url, files={"file": fh})
            resp.raise_for_status()
            return resp.json()["url"]

        return upload_files(obj, upload_file)

    def on_message(ws, message):
        request = json.loads(message)
        try:
            input_obj = InputType(**request["input"])
        except ValidationError as e:
            # Invalid input: log and drop the message rather than crash
            # the websocket callback. NOTE(review): the client is never
            # notified of the failure — confirm that is intended.
            logger.error(f"Input validation failed: {e}")
            return
        print(input_obj)

        logs: List[str] = []
        # `logs` is aliased inside `response`, so extending it updates
        # every subsequent frame sent to the client.
        response: Dict[str, Any] = {"logs": logs}
        runner.run(**input_obj.dict())

        # Stream setup-phase logs until the first output arrives.
        while runner.is_processing() and not runner.has_output_waiting():
            if runner.has_logs_waiting():
                logs.extend(runner.read_logs())
                ws.send(json.dumps(response))
            time.sleep(0.01)  # avoid a 100%-CPU busy-wait between polls

        if runner.error() is not None:
            logger.error(runner.error())
            response["status"] = Status.FAILED
            # BUG FIX: original assigned `e`, which is unbound here (it
            # only exists inside the ValidationError handler above).
            response["error"] = str(runner.error())
            ws.send(json.dumps(response))
            return

        if runner.is_output_generator():
            # Generator predictors: append each yielded item to
            # response["output"] and push an updated frame per batch.
            output = response["output"] = []
            while runner.is_processing():
                if runner.has_output_waiting() or runner.has_logs_waiting():
                    new_output = [upload(o) for o in runner.read_output()]
                    new_logs = runner.read_logs()
                    # Reads can race the flags and come back empty; skip
                    # sending a frame with nothing new in it.
                    if new_output == [] and new_logs == []:
                        continue
                    output.extend(new_output)
                    logs.extend(new_logs)
                    ws.send(json.dumps(response))
                time.sleep(0.01)  # avoid a 100%-CPU busy-wait between polls
            if runner.error() is not None:
                response["status"] = Status.FAILED
                # BUG FIX: original used `str(runner.error)` — the bound
                # method object, not the error — because the call
                # parentheses were missing.
                response["error"] = str(runner.error())
                ws.send(json.dumps(response))
                return
            response["status"] = Status.SUCCEEDED
            # Drain anything produced between the last poll and exit.
            output.extend(upload(o) for o in runner.read_output())
            logs.extend(runner.read_logs())
            ws.send(json.dumps(response))
        else:
            # Single-output predictors: stream logs while running, then
            # send exactly one output value.
            while runner.is_processing():
                if runner.has_logs_waiting():
                    logs.extend(runner.read_logs())
                    ws.send(json.dumps(response))
                time.sleep(0.01)  # avoid a 100%-CPU busy-wait between polls
            if runner.error() is not None:
                response["status"] = Status.FAILED
                response["error"] = str(runner.error())
                ws.send(json.dumps(response))
                # BUG FIX: without this return the handler fell through
                # to read_output() and the len==1 assert after a failure.
                return
            output = runner.read_output()
            assert len(output) == 1
            response["status"] = Status.SUCCEEDED
            response["output"] = upload(output[0])
            logs.extend(runner.read_logs())
            ws.send(json.dumps(response))

    def on_error(ws, error):
        print(error)

    def on_close(ws, close_status_code, close_msg):
        print("### closed ###")

    def on_open(ws):
        print("Opened connection")

    websocket.enableTrace(True)
    ws = websocket.WebSocketApp(
        "wss://cog.nmb.ai/v1/queue/TEST/websocket",
        on_open=on_open,
        on_message=on_message,
        on_error=on_error,
        on_close=on_close,
    )
    # `rel` as dispatcher makes run_forever return immediately and lets
    # rel.dispatch() drive the event loop; signal 2 (SIGINT) aborts it.
    ws.run_forever(dispatcher=rel)
    rel.signal(2, rel.abort)
    rel.dispatch()
if __name__ == "__main__":
    # Load the cog config and instantiate the predictor as module-level
    # globals before calling run(): run() reads the global `predictor`
    # (via get_input_type) rather than taking it as a parameter.
    config = load_config()
    predictor = load_predictor(config)
    run()