diff --git a/README.md b/README.md index 639a80be0..eb31e4e50 100644 --- a/README.md +++ b/README.md @@ -117,7 +117,7 @@ For developers, exo also starts a ChatGPT-compatible API endpoint on http://loca curl http://localhost:8000/v1/chat/completions \ -H "Content-Type: application/json" \ -d '{ - "model": "llama-3-8b", + "model": "llama-3.1-8b", "messages": [{"role": "user", "content": "What is the meaning of exo?"}], "temperature": 0.7 }' diff --git a/exo/api/chatgpt_api.py b/exo/api/chatgpt_api.py index 08e4cd762..2d80f2d5a 100644 --- a/exo/api/chatgpt_api.py +++ b/exo/api/chatgpt_api.py @@ -50,16 +50,29 @@ class Message: - def __init__(self, role: str, content: Union[str, list]): - self.role = role - self.content = content + def __init__(self, role: str, content: Union[str, List[Dict[str, Union[str, Dict[str, str]]]]]): + self.role = role + self.content = content + + def to_dict(self): + return { + "role": self.role, + "content": self.content + } class ChatCompletionRequest: - def __init__(self, model: str, messages: List[Message], temperature: float): - self.model = model - self.messages = messages - self.temperature = temperature + def __init__(self, model: str, messages: List[Message], temperature: float): + self.model = model + self.messages = messages + self.temperature = temperature + + def to_dict(self): + return { + "model": self.model, + "messages": [message.to_dict() for message in self.messages], + "temperature": self.temperature + } def resolve_tinygrad_tokenizer(model_id: str): @@ -75,8 +88,12 @@ async def resolve_tokenizer(model_id: str): try: if DEBUG >= 2: print(f"Trying AutoProcessor for {model_id}") processor = AutoProcessor.from_pretrained(model_id, use_fast=False) - processor.eos_token_id = processor.tokenizer.eos_token_id - processor.encode = processor.tokenizer.encode + if not hasattr(processor, 'eos_token_id'): + processor.eos_token_id = getattr(processor, 'tokenizer', getattr(processor, '_tokenizer', processor)).eos_token_id + if not hasattr(processor, 'encode'): + processor.encode = getattr(processor, 'tokenizer', getattr(processor, '_tokenizer', processor)).encode + if not hasattr(processor, 'decode'): + processor.decode = getattr(processor, 'tokenizer', getattr(processor, '_tokenizer', processor)).decode return processor except Exception as e: if DEBUG >= 2: print(f"Failed to load processor for {model_id}. Error: {e}") @@ -157,6 +174,10 @@ def remap_messages(messages: List[Message]) -> List[Message]: remapped_messages = [] last_image = None for message in messages: + if not isinstance(message.content, list): + remapped_messages.append(message) + continue + remapped_content = [] for content in message.content: if isinstance(content, dict): @@ -168,16 +189,17 @@ def remap_messages(messages: List[Message]) -> List[Message]: else: remapped_content.append(content) else: - remapped_content.append({"type": "text", "text": content}) + remapped_content.append(content) remapped_messages.append(Message(role=message.role, content=remapped_content)) if last_image: # Replace the last image placeholder with the actual image content for message in reversed(remapped_messages): for i, content in enumerate(message.content): - if content.get("type") == "text" and content.get("text") == "[An image was uploaded but is not displayed here]": - message.content[i] = last_image - return remapped_messages + if isinstance(content, dict): + if content.get("type") == "text" and content.get("text") == "[An image was uploaded but is not displayed here]": + message.content[i] = last_image + return remapped_messages return remapped_messages @@ -192,7 +214,7 @@ def build_prompt(tokenizer, _messages: List[Message]): for content in message.content: # note: we only support one image at a time right now. Multiple is possible. See: https://github.com/huggingface/transformers/blob/e68ec18ce224af879f22d904c7505a765fb77de3/docs/source/en/model_doc/llava.md?plain=1#L41 # follows the convention in https://platform.openai.com/docs/guides/vision - if content.get("type", None) == "image": + if isinstance(content, dict) and content.get("type", None) == "image": image_str = content.get("image", None) break diff --git a/tinychat/examples/tinychat/index.js b/tinychat/examples/tinychat/index.js index 1a15be04f..1cbcf8701 100644 --- a/tinychat/examples/tinychat/index.js +++ b/tinychat/examples/tinychat/index.js @@ -79,7 +79,7 @@ document.addEventListener("alpine:init", () => { this.tokens_per_second = 0; // prepare messages for API request - const apiMessages = this.cstate.messages.map(msg => { + let apiMessages = this.cstate.messages.map(msg => { if (msg.content.startsWith('![Uploaded Image]')) { return { role: "user", @@ -89,36 +89,40 @@ document.addEventListener("alpine:init", () => { image_url: { url: this.imageUrl } + }, + { + type: "text", + text: value // Use the actual text the user typed } ] }; } else { return { role: msg.role, - content: [ - { - type: "text", - text: msg.content - } - ] + content: msg.content }; } }); - - // If there's an image URL, add it to all messages - if (this.imageUrl) { - apiMessages.forEach(msg => { - if (!msg.content.some(content => content.type === "image_url")) { - msg.content.push({ - type: "image_url", - image_url: { - url: this.imageUrl - } - }); + const containsImage = apiMessages.some(msg => Array.isArray(msg.content) && msg.content.some(item => item.type === 'image_url')); + if (containsImage) { + // Map all messages with string content to object with type text + apiMessages = apiMessages.map(msg => { + if (typeof msg.content === 'string') { + return { + ...msg, + content: [ + { + type: "text", + text: msg.content + } + ] + }; } + return msg; }); } + // start receiving server sent events let gottenFirstChunk = false; for await ( @@ -146,19 +150,37 @@ document.addEventListener("alpine:init", () => { } } - // update the state in histories or add it if it doesn't exist - const index = this.histories.findIndex((cstate) => { - return cstate.time === this.cstate.time; + // Clean the cstate before adding it to histories + const cleanedCstate = JSON.parse(JSON.stringify(this.cstate)); + cleanedCstate.messages = cleanedCstate.messages.map(msg => { + if (Array.isArray(msg.content)) { + return { + ...msg, + content: msg.content.map(item => + item.type === 'image_url' ? { type: 'image_url', image_url: { url: '[IMAGE_PLACEHOLDER]' } } : item + ) + }; + } + return msg; }); - this.cstate.time = Date.now(); + + // Update the state in histories or add it if it doesn't exist + const index = this.histories.findIndex((cstate) => cstate.time === cleanedCstate.time); + cleanedCstate.time = Date.now(); if (index !== -1) { - // update the time - this.histories[index] = this.cstate; + // Update the existing entry + this.histories[index] = cleanedCstate; } else { - this.histories.push(this.cstate); + // Add a new entry + this.histories.push(cleanedCstate); } + console.log(this.histories) // update in local storage - localStorage.setItem("histories", JSON.stringify(this.histories)); + try { + localStorage.setItem("histories", JSON.stringify(this.histories)); + } catch (error) { + console.error("Failed to save histories to localStorage:", error); + } this.generating = false; },