Add LLama CPP Support #335

Closed · wants to merge 6 commits
13 changes: 8 additions & 5 deletions exo/api/chatgpt_api.py
@@ -240,11 +240,14 @@ async def handle_post_chat_completions(self, request):
{"detail": f"Unsupported model: {chat_request.model} with inference engine {self.inference_engine_classname}. Supported models for this engine: {supported_models}"},
status=400,
)

tokenizer = await resolve_tokenizer(shard.model_id)
if DEBUG >= 4: print(f"Resolved tokenizer: {tokenizer}")

prompt, image_str = build_prompt(tokenizer, chat_request.messages)
if shard.model_id.endswith("GGUF"):
prompt = chat_request.messages[0].content
image_str = None
else:
tokenizer = await resolve_tokenizer(shard.model_id)
if DEBUG >= 4: print(f"Resolved tokenizer: {tokenizer}")

prompt, image_str = build_prompt(tokenizer, chat_request.messages)
request_id = str(uuid.uuid4())
if self.on_chat_completion_request:
try:
3 changes: 3 additions & 0 deletions exo/download/hf/hf_helpers.py
@@ -406,6 +406,9 @@ def get_allow_patterns(weight_map: Dict[str, str], shard: Shard) -> List[str]:
shard_specific_patterns.add(sorted_file_names[0])
elif shard.is_last_layer():
shard_specific_patterns.add(sorted_file_names[-1])
# TODO: Support more models in a cleaner manner
elif shard.model_id.endswith("GGUF"):
return ["Meta-Llama-3.1-8B-Instruct-Q8_0.gguf"]
else:
shard_specific_patterns = set("*.safetensors")
if DEBUG >= 2: print(f"get_allow_patterns {weight_map=} {shard=} {shard_specific_patterns=}")
4 changes: 4 additions & 0 deletions exo/inference/inference_engine.py
@@ -3,6 +3,7 @@

from typing import Tuple, Optional
from abc import ABC, abstractmethod

from .shard import Shard


@@ -27,5 +28,8 @@ def get_inference_engine(inference_engine_name: str, shard_downloader: 'ShardDownloader'):
tinygrad.helpers.DEBUG.value = int(os.getenv("TINYGRAD_DEBUG", default="0"))

return TinygradDynamicShardInferenceEngine(shard_downloader)
elif inference_engine_name == "llama_cpp":
from exo.inference.llama_cpp.inference import LLamaInferenceEngine
return LLamaInferenceEngine(shard_downloader)
else:
raise ValueError(f"Inference engine {inference_engine_name} not supported")
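For context, a minimal usage sketch of the new branch; the `downloader` argument stands in for whatever `ShardDownloader` instance the node already holds, and the wrapper function is purely illustrative:

```python
# Hypothetical sketch: selecting the llama.cpp engine by name, mirroring how the
# MLX and tinygrad engines are chosen in exo/inference/inference_engine.py.
from exo.inference.inference_engine import get_inference_engine

def make_llama_cpp_engine(downloader):
    # Returns a LLamaInferenceEngine wired to the given shard downloader.
    return get_inference_engine("llama_cpp", downloader)
```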
1 change: 1 addition & 0 deletions exo/inference/llama_cpp/__init__.py
@@ -0,0 +1 @@

17 changes: 17 additions & 0 deletions exo/inference/llama_cpp/inference.py
@@ -0,0 +1,17 @@
from exo.inference.shard import Shard
from exo.inference.inference_engine import InferenceEngine
from exo.download.shard_download import ShardDownloader
from typing import Tuple, Optional
import numpy as np
from concurrent.futures import ThreadPoolExecutor

class LLamaInferenceEngine(InferenceEngine):
def __init__(self, shard_downloader: ShardDownloader):
self.shard_downloader = shard_downloader
self.executor = ThreadPoolExecutor(max_workers=1)

def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
pass

def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
pass
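Both methods are still stubs. One way the single-worker `ThreadPoolExecutor` created in `__init__` could be used is to push the blocking ggml call off the event loop; a minimal sketch, assuming the methods end up `async` like the other exo engines (the `_run_model` helper is hypothetical):

```python
import asyncio
from concurrent.futures import ThreadPoolExecutor

import numpy as np

class ExecutorSketch:
    """Illustration only: run a blocking llama.cpp/ggml call on the worker thread."""

    def __init__(self):
        self.executor = ThreadPoolExecutor(max_workers=1)

    def _run_model(self, prompt: str) -> np.ndarray:
        # Placeholder for the blocking ggml forward pass over the prompt tokens.
        return np.zeros((1, 1), dtype=np.float32)

    async def infer_prompt(self, request_id: str, prompt: str):
        loop = asyncio.get_running_loop()
        # Keep the asyncio event loop responsive while ggml computes on its own thread.
        output = await loop.run_in_executor(self.executor, self._run_model, prompt)
        return output, "", False  # (output, inference_state, is_finished)
```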
7 changes: 7 additions & 0 deletions exo/inference/llama_cpp/models/llama.py
@@ -0,0 +1,7 @@
import ggml

# TODO: adjust size to fit whatever GGUF model is being used
# Q8 LLama 3.1 8b
mem_size = int(1e9) * 8

PARAMS = ggml.ggml.ggml_init_params(mem_size, mem_buffer=None)
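A rough sanity check on the hard-coded size: Q8_0 packs 32 weights plus one fp16 scale into 34 bytes (about 8.5 bits per weight), so an ~8B-parameter model lands in this ballpark. The parameter count below is approximate:

```python
# Back-of-the-envelope check of the hard-coded mem_size (approximation only).
params = 8.03e9              # rough parameter count of Llama 3.1 8B
bytes_per_weight = 34 / 32   # Q8_0 block: 32 int8 weights + one fp16 scale
print(f"~{params * bytes_per_weight / 1e9:.1f} GB of Q8_0 weights")  # ~8.5 GB
```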
65 changes: 65 additions & 0 deletions exo/inference/llama_cpp/models/ops.py
@@ -0,0 +1,65 @@
from dataclasses import dataclass
from importlib.resources import contents
import ggml
import numpy as np
from typing import Optional

@dataclass
class OpsOutput:
out_tensor: ggml.ggml_tensor_p
out_shape: tuple

class Linear:
def __init__(self, ctx: ggml.ggml_context_p, graph: ggml.ggml_cgraph_p, w: np.ndarray, bias: Optional[np.ndarray], backend: ggml.ggml_backend_t):
self.ctx = ctx
self.gf = graph
self.w = w
self.backend = backend
# ggml_new_tensor_2d takes dimensions as (ne0=columns, ne1=rows), hence the swapped shape indices
self.w_tensor = ggml.ggml_new_tensor_2d(self.ctx, ggml.GGML_TYPE_F32, self.w.shape[1], self.w.shape[0])
if bias is not None:
self.bias = bias
self.bias_tensor = ggml.ggml_new_tensor_2d(self.ctx, ggml.GGML_TYPE_F32, self.bias.shape[0], self.bias.shape[1])

def forward(self, x: np.ndarray) -> OpsOutput:
# TODO: improve this shape check; ggml_mul_mat transposes its first operand internally,
# so x may need transposing here for the shared dimension to line up
if self.w.shape[0] != x.shape[0]:
x = x.T
x_tensor = ggml.ggml_new_tensor_2d(self.ctx, ggml.GGML_TYPE_F32, x.shape[1], x.shape[0])
if hasattr(self, "bias") and hasattr(self, "bias_tensor"):
x2_tensor = ggml.ggml_mul_mat(self.ctx, self.w_tensor, x_tensor)
forward_tensor = ggml.ggml_add(self.ctx, x2_tensor, self.bias_tensor)
else:
forward_tensor = ggml.ggml_mul_mat(self.ctx, self.w_tensor, x_tensor)

ggml.ggml_build_forward_expand(self.gf, forward_tensor)
buffer = ggml.ggml_backend_alloc_ctx_tensors(self.ctx, self.backend)
ggml.ggml_backend_tensor_set(self.w_tensor, self.w.ctypes.data, 0, ggml.ggml_nbytes(self.w_tensor))
ggml.ggml_backend_tensor_set(x_tensor, x.ctypes.data, 0, ggml.ggml_nbytes(x_tensor))

if hasattr(self, "bias") and hasattr(self, "bias_tensor"):
ggml.ggml_backend_tensor_set(self.bias_tensor, self.bias.ctypes.data, 0, ggml.ggml_nbytes(self.bias_tensor))

out_shape = (forward_tensor.contents.ne[0], forward_tensor.contents.ne[1])
print(out_shape)
return OpsOutput(out_tensor=forward_tensor, out_shape=out_shape)

# Quick Test to see if ops work
if __name__ == "__main__":
backend = ggml.ggml_backend_cpu_init()
params = ggml.ggml_init_params(
mem_size=ggml.ggml_tensor_overhead() * 6 + ggml.ggml_graph_overhead() + 10000,
no_alloc=True,
)
ctx = ggml.ggml_init(params)
gf = ggml.ggml_new_graph(ctx)
x = np.array([[4, 5]], dtype=np.float32)
w = np.array([[2, 3]], dtype=np.float32)

linear = Linear(ctx, gf, w, None, backend)
output_op = linear.forward(x)
out_array = np.empty(output_op.out_shape, dtype=np.float32)
ggml.ggml_backend_graph_compute(backend, gf)
ggml.ggml_backend_tensor_get(output_op.out_tensor, out_array.ctypes.data, 0, ggml.ggml_nbytes(output_op.out_tensor))
print(out_array)
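For reference, the quick test should print a single value: with these 1x2 inputs the graph reduces to the dot product 2*4 + 3*5 = 23. A NumPy-only cross-check:

```python
# NumPy cross-check for the quick test above (no ggml required).
import numpy as np

x = np.array([[4, 5]], dtype=np.float32)
w = np.array([[2, 3]], dtype=np.float32)
print(x @ w.T)  # [[23.]] -- the value the ggml graph should produce
```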
8 changes: 8 additions & 0 deletions exo/inference/llama_cpp/models/utils.py
@@ -0,0 +1,8 @@
import gguf
import ggml
import numpy as np

def load_weights(file_data, weight_name) -> np.ndarray:
_, tensorinfo = gguf.load_gguf(file_data)
numpy_tensor = gguf.load_gguf_tensor(file_data, tensorinfo, weight_name)
return numpy_tensor
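A hedged usage sketch for `load_weights`; the file path and tensor name below are placeholders (`token_embd.weight` is a conventional GGUF name for the embedding table, not something this diff references):

```python
# Illustration only: load a single named tensor out of a GGUF file via pygguf.
from exo.inference.llama_cpp.models.utils import load_weights

with open("Meta-Llama-3.1-8B-Instruct-Q8_0.gguf", "rb") as f:
    embeddings = load_weights(f, "token_embd.weight")
    print(embeddings.shape, embeddings.dtype)
```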
1 change: 1 addition & 0 deletions exo/models.py
@@ -11,6 +11,7 @@
"llama-3.1-8b": {
"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
"TinygradDynamicShardInferenceEngine": Shard(model_id="mlabonne/Meta-Llama-3.1-8B-Instruct-abliterated", start_layer=0, end_layer=0, n_layers=32),
"LLamaInferenceEngine": Shard(model_id="bartowski/Meta-Llama-3.1-8B-Instruct-GGUF", start_layer=0, end_layer=80, n_layers=32)
},
"llama-3.1-70b": {
"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
2 changes: 2 additions & 0 deletions install.sh
@@ -2,4 +2,6 @@

python3 -m venv .venv
source .venv/bin/activate
# ggml has to be installed separately due to a shared library dependency
pip install ggml-python --config-settings=cmake.args='-DGGML_CUDA=ON;-DGGML_METAL=ON'
pip install -e .
4 changes: 4 additions & 0 deletions pyrightconfig.json
@@ -0,0 +1,4 @@
{
"venvPath": ".",
"venv": ".venv"
}
1 change: 1 addition & 0 deletions setup.py
@@ -7,6 +7,7 @@
"aiohttp==3.10.2",
"aiohttp_cors==0.7.0",
"aiofiles==24.1.0",
"gguf @ git+https://github.com/99991/pygguf@f304361f69ce795ad06103b697fe6eaf44262259",
"grpcio==1.64.1",
"grpcio-tools==1.64.1",
"Jinja2==3.1.4",