
Commit 096c8a5

Refactor mlc_chat into a formal package (octoml#266)
This PR refactors mlc_chat into a formal package. Some follow-up TODOs remain for cleaning up the REST and Gradio APIs.
1 parent c409ca0 commit 096c8a5

File tree

9 files changed, +246 −49 lines changed


CMakeLists.txt

Lines changed: 6 additions & 10 deletions
@@ -114,15 +114,11 @@ else()
 target_link_libraries(mlc_chat_cli PUBLIC mlc_llm)
 endif()
 
-if (UNIX OR APPLE)
-  add_library(mlc_llm_module MODULE $<TARGET_OBJECTS:mlc_llm_objs>)
-  target_link_libraries(mlc_llm_module PRIVATE tokenizers_cpp)
-  if (APPLE)
-    set_property(TARGET mlc_llm_module APPEND PROPERTY LINK_OPTIONS -undefined dynamic_lookup)
-  else()
-    set_property(TARGET mlc_llm_module APPEND PROPERTY LINK_OPTIONS)
-  endif()
-endif()
+# create a dummy libtvm.so, so mlc_llm_module can be loaded by tvm smoothly
+add_library(libtvm_dummy SHARED $<TARGET_OBJECTS:tvm_runtime_objs> $<TARGET_OBJECTS:tvm_libinfo_objs>)
+set_target_properties(libtvm_dummy PROPERTIES OUTPUT_NAME "tvm")
+add_library(mlc_llm_module MODULE $<TARGET_OBJECTS:mlc_llm_objs>)
+target_link_libraries(mlc_llm_module PRIVATE tokenizers_cpp libtvm_dummy)
 
 # when this option is on,
 # we install all static lib deps into lib
@@ -145,7 +141,7 @@ if (MLC_LLM_INSTALL_STATIC_LIB)
 )
 endif()
 else()
-install(TARGETS mlc_chat_cli tvm_runtime mlc_llm
+install(TARGETS mlc_chat_cli tvm_runtime mlc_llm mlc_llm_module
   mlc_llm_static
   tokenizers_cpp
   sentencepiece-static

mlc_llm/utils.py

Lines changed: 1 addition & 1 deletion
@@ -437,7 +437,7 @@ def compile_metal(src, target):
 "supports_int8": 1,
 "supports_8bit_buffer": 1,
 "supports_16bit_buffer": 1,
-"supports_storage_buffer_storage_class": 1
+"supports_storage_buffer_storage_class": 1,
 }
 ),
 host="llvm",

python/README.md

Lines changed: 1 addition & 1 deletion
@@ -19,4 +19,4 @@ There is currently a dependency to build from source in order to use the [REST A
 
 To launch the Gradio API, navigate to the root of the mlc-llm repo. The arguments you need to provide are `artifact-path`, which stores your pre-built models, `device-name` (default cuda), and `device-id` (default 0). Once launched, the Gradio API lets you select different models and quantization types in its interface.
 
-python3 -m python.mlc_chat.gradio --artifact-path /path/to/your/models --device-name cuda --device-id 0
+PYTHONPATH=python python3 -m mlc_chat.gradio --artifact-path /path/to/your/models --device-name cuda --device-id 0

python/mlc_chat/__init__.py

Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
+"""MLC Chat python package.
+
+MLC Chat is the app runtime of MLC LLM.
+"""
+from .libinfo import __version__
+from .chat_module import ChatModule
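With this `__init__.py`, `ChatModule` and `__version__` are exposed at the package level. A minimal sketch of the new import path, assuming `python/` is on `PYTHONPATH` (or the package is installed) and the compiled mlc_llm library is discoverable, since importing the package loads it via chat_module.py below:

import mlc_chat

print(mlc_chat.__version__)     # "0.1.dev0", defined in libinfo.py
chat_cls = mlc_chat.ChatModule  # re-exported from chat_module.py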

python/mlc_chat/chat_module.py

Lines changed: 17 additions & 6 deletions
@@ -1,12 +1,24 @@
 """Python runtime for MLC chat."""
-
+#! pylint: disable=unused-import
+import os
+import sys
 import ctypes
-
 import tvm
+import tvm._ffi.base
+from . import libinfo
+
+
+def _load_mlc_llm_lib():
+    """Load mlc llm lib"""
+    if sys.platform.startswith("win32") and sys.version_info >= (3, 8):
+        for path in libinfo.get_dll_directories():
+            os.add_dll_directory(path)
+    lib_name = "mlc_llm" if tvm._ffi.base._RUNTIME_ONLY else "mlc_llm_module"
+    lib_path = libinfo.find_lib_path(lib_name, optional=False)
+    return ctypes.CDLL(lib_path[0]), lib_path[0]
 
 
-def load_llm_chat(mlc_lib_path):
-    return ctypes.CDLL(mlc_lib_path)
+_LIB, _LIB_PATH = _load_mlc_llm_lib()
 
 
 def supported_models():
@@ -17,9 +29,8 @@ def quantization_keys():
     return ["q3f16_0", "q4f16_0", "q4f32_0", "q0f32", "q0f16"]
 
 
-class LLMChatModule:
+class ChatModule:
     def __init__(self, mlc_lib_path, target="cuda", device_id=0):
-        load_llm_chat(mlc_lib_path)
         fcreate = tvm.get_global_func("mlc.llm_chat_create")
         assert fcreate is not None
         if target == "cuda":
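Given the constructor signature above, a hedged usage sketch of the renamed class; the library path is a hypothetical placeholder, and since the module now locates and loads `mlc_llm`/`mlc_llm_module` itself at import time through `libinfo`, the `mlc_lib_path` argument mainly keeps existing callers (gradio, REST server) working:

from mlc_chat import ChatModule

# "/path/to/libmlc_llm_module.so" is a placeholder, not a path from this commit
chat_mod = ChatModule("/path/to/libmlc_llm_module.so", target="cuda", device_id=0)
# a compiled model is then attached with chat_mod.reload(lib=..., model_path=...), as in the REST server below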

python/mlc_chat/gradio.py

Lines changed: 2 additions & 2 deletions
@@ -7,7 +7,7 @@
 import gradio as gr
 import tvm
 
-from python.mlc_chat.chat_module import LLMChatModule
+from .chat_module import ChatModule
 
 model_keys = ["vicuna-v1-7b"]
 quantization_keys = ["q3f16_0", "q4f16_0", "q4f32_0", "q0f32", "q0f16"]
@@ -23,7 +23,7 @@ def _parse_args():
     return parsed
 
 
-class GradioChatModule(LLMChatModule):
+class GradioChatModule(ChatModule):
     def __init__(self, ARGS):
         super().__init__(ARGS.mlc_lib_path, ARGS.device_name, ARGS.device_id)
         self.artifact_path = ARGS.artifact_path

python/mlc_chat/libinfo.py

Lines changed: 70 additions & 0 deletions
@@ -0,0 +1,70 @@
+"""Library information. This is a standalone file that can be used to get various info"""
+#! pylint: disable=protected-access
+import os
+import sys
+
+__version__ = "0.1.dev0"
+
+
+def get_env_paths(env_var, splitter):
+    """Get path in env variable"""
+    if os.environ.get(env_var, None):
+        return [p.strip() for p in os.environ[env_var].split(splitter)]
+    return []
+
+
+def get_dll_directories():
+    """Get extra mlc llm dll directories"""
+    curr_dir = os.path.dirname(os.path.realpath(os.path.expanduser(__file__)))
+    source_dir = os.path.abspath(os.path.join(curr_dir, "..", ".."))
+    dll_path = [
+        curr_dir,
+        os.path.join(source_dir, "build"),
+        os.path.join(source_dir, "build", "Release"),
+    ]
+
+    if "MLC_LIBRARY_PATH" in os.environ:
+        dll_path.append(os.environ["MLC_LIBRARY_PATH"])
+
+    if sys.platform.startswith("linux") or sys.platform.startswith("freebsd"):
+        dll_path.extend(get_env_paths("LD_LIBRARY_PATH", ":"))
+    elif sys.platform.startswith("darwin"):
+        dll_path.extend(get_env_paths("DYLD_LIBRARY_PATH", ":"))
+    elif sys.platform.startswith("win32"):
+        dll_path.extend(get_env_paths("PATH", ";"))
+
+    return dll_path
+
+
+def find_lib_path(name, optional=False):
+    """Find mlc llm library
+
+    Parameters
+    ----------
+    name : str
+        The name of the library
+
+    optional: boolean
+        Whether the library is optional; when False, a missing library raises an error
+    """
+    if sys.platform.startswith("linux") or sys.platform.startswith("freebsd"):
+        lib_name = f"lib{name}.so"
+    elif sys.platform.startswith("win32"):
+        lib_name = f"{name}.dll"
+    elif sys.platform.startswith("darwin"):
+        lib_name = f"lib{name}.dylib"
+    else:
+        lib_name = f"lib{name}.so"
+
+    dll_paths = get_dll_directories()
+    lib_dll_path = [os.path.join(p, lib_name) for p in dll_paths]
+    lib_found = [p for p in lib_dll_path if os.path.exists(p) and os.path.isfile(p)]
+    if not lib_found:
+        if not optional:
+            message = (
+                f"Cannot find libraries: {lib_name}\n"
+                + "List of candidates:\n"
+                + "\n".join(lib_dll_path)
+            )
+            raise RuntimeError(message)
+    return lib_found
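A short sketch of how the search logic above can be exercised directly; the build directory is a placeholder, and importing the package assumes the shared library is already discoverable:

import os
from mlc_chat import libinfo

# MLC_LIBRARY_PATH is honored by get_dll_directories(); the value here is hypothetical
os.environ["MLC_LIBRARY_PATH"] = "/path/to/mlc-llm/build"

print(libinfo.get_dll_directories())  # candidate directories, in search order
print(libinfo.find_lib_path("mlc_llm_module", optional=True))  # [] when nothing is found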
Lines changed: 36 additions & 29 deletions
@@ -1,4 +1,4 @@
-from chat_module import LLMChatModule, supported_models, quantization_keys
+from .chat_module import ChatModule, supported_models, quantization_keys
 
 from pydantic import BaseModel
 from fastapi import FastAPI, HTTPException
@@ -15,20 +15,14 @@
 
 session = {}
 
+
 @asynccontextmanager
 async def lifespan(app: FastAPI):
 
     ARGS = _parse_args()
 
-    chat_mod = LLMChatModule(
-        ARGS.mlc_lib_path,
-        ARGS.device_name,
-        ARGS.device_id
-    )
-    model_path = os.path.join(
-        ARGS.artifact_path,
-        ARGS.model + "-" + ARGS.quantization
-    )
+    chat_mod = ChatModule(ARGS.mlc_lib_path, ARGS.device_name, ARGS.device_id)
+    model_path = os.path.join(ARGS.artifact_path, ARGS.model + "-" + ARGS.quantization)
     model_dir = ARGS.model + "-" + ARGS.quantization
     model_lib = model_dir + "-" + ARGS.device_name + ".so"
     lib_dir = os.path.join(model_path, model_lib)
@@ -38,16 +32,22 @@ async def lifespan(app: FastAPI):
     elif os.path.exists(prebuilt_lib_dir):
         lib = tvm.runtime.load_module(prebuilt_lib_dir)
     else:
-        raise ValueError(f"Unable to find {model_lib} at {lib_dir} or {prebuilt_lib_dir}.")
+        raise ValueError(
+            f"Unable to find {model_lib} at {lib_dir} or {prebuilt_lib_dir}."
+        )
 
     local_model_path = os.path.join(model_path, "params")
-    prebuilt_model_path = os.path.join(ARGS.artifact_path, "prebuilt", f"mlc-chat-{model_dir}")
+    prebuilt_model_path = os.path.join(
+        ARGS.artifact_path, "prebuilt", f"mlc-chat-{model_dir}"
+    )
     if os.path.exists(local_model_path):
         chat_mod.reload(lib=lib, model_path=local_model_path)
     elif os.path.exists(prebuilt_model_path):
         chat_mod.reload(lib=lib, model_path=prebuilt_model_path)
     else:
-        raise ValueError(f"Unable to find model params at {local_model_path} or {prebuilt_model_path}.")
+        raise ValueError(
+            f"Unable to find model params at {local_model_path} or {prebuilt_model_path}."
+        )
     session["chat_mod"] = chat_mod
 
     yield
@@ -57,13 +57,11 @@ async def lifespan(app: FastAPI):
 
 app = FastAPI(lifespan=lifespan)
 
+
 def _parse_args():
     args = argparse.ArgumentParser()
     args.add_argument(
-        "--model",
-        type=str,
-        choices=supported_models(),
-        default="vicuna-v1-7b"
+        "--model", type=str, choices=supported_models(), default="vicuna-v1-7b"
     )
     args.add_argument("--artifact-path", type=str, default="dist")
     args.add_argument(
@@ -85,65 +83,74 @@ def _parse_args():
 """
 List the currently supported models and provides basic information about each of them.
 """
+
+
 @app.get("/models")
 async def read_models():
-    return {
-        "data": [{
-            "id": model,
-            "object":"model"
-        } for model in supported_models()]
-    }
+    return {"data": [{"id": model, "object": "model"} for model in supported_models()]}
+
 
 """
 Retrieve a model instance with basic information about the model.
 """
+
+
 @app.get("/models/{model}")
 async def read_model(model: str):
     if model not in supported_models():
         raise HTTPException(status_code=404, detail=f"Model {model} is not supported.")
-    return {
-        "id": model,
-        "object":"model"
-    }
+    return {"id": model, "object": "model"}
+
 
 class ChatRequest(BaseModel):
     prompt: str
     stream: bool = False
 
+
 """
 Creates model response for the given chat conversation.
 """
+
+
 @app.post("/chat/completions")
 def request_completion(request: ChatRequest):
     session["chat_mod"].prefill(input=request.prompt)
     if request.stream:
+
         def iter_response():
             while not session["chat_mod"].stopped():
                 session["chat_mod"].decode()
                 msg = session["chat_mod"].get_message()
                 yield json.dumps({"message": msg})
-        return StreamingResponse(iter_response(), media_type='application/json')
+
+        return StreamingResponse(iter_response(), media_type="application/json")
     else:
         msg = None
         while not session["chat_mod"].stopped():
             session["chat_mod"].decode()
             msg = session["chat_mod"].get_message()
         return {"message": msg}
 
+
 """
 Reset the chat for the currently initialized model.
 """
+
+
 @app.post("/chat/reset")
 def reset():
     session["chat_mod"].reset_chat()
 
+
 """
 Get the runtime stats.
 """
+
+
 @app.get("/stats")
 def read_stats():
     return session["chat_mod"].runtime_stats_text()
 
 
 if __name__ == "__main__":
-    uvicorn.run("server:app", port=8000, reload=True, access_log=False)
+    uvicorn.run("mlc_chat.server:app", port=8000, reload=True, access_log=False)
