Skip to content

Commit

Permalink
Fix image API prompt encoding
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexCheema committed Jul 30, 2024
1 parent 2d20000 commit 178fb75
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 41 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ For developers, exo also starts a ChatGPT-compatible API endpoint on http://loca
curl http://localhost:8000/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "llama-3-8b",
"model": "llama-3.1-8b",
"messages": [{"role": "user", "content": "What is the meaning of exo?"}],
"temperature": 0.7
}'
Expand Down
50 changes: 36 additions & 14 deletions exo/api/chatgpt_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,16 +50,29 @@


class Message:
    """A single chat message for the ChatGPT-compatible API.

    content is either a plain string, or — for multimodal (e.g. image)
    requests — a list of content-part dicts following the OpenAI vision
    convention (e.g. {"type": "text", "text": ...} or
    {"type": "image_url", "image_url": {"url": ...}}).
    """

    def __init__(self, role: str, content: Union[str, List[Dict[str, Union[str, Dict[str, str]]]]]):
        # role: e.g. "user", "assistant", "system"
        self.role = role
        # content: string or list of content-part dicts (see class docstring)
        self.content = content

    def to_dict(self) -> Dict[str, Union[str, list]]:
        """Serialize to the wire format expected by the chat-completions API."""
        return {
            "role": self.role,
            "content": self.content
        }


class ChatCompletionRequest:
    """A parsed /v1/chat/completions request body.

    Holds the model id, the ordered list of chat messages, and the
    sampling temperature.
    """

    def __init__(self, model: str, messages: List["Message"], temperature: float):
        # model: model identifier, e.g. "llama-3.1-8b"
        self.model = model
        # messages: conversation history, oldest first
        self.messages = messages
        # temperature: sampling temperature passed through to generation
        self.temperature = temperature

    def to_dict(self) -> Dict[str, Union[str, float, list]]:
        """Serialize to the wire format, recursively serializing each message."""
        return {
            "model": self.model,
            "messages": [message.to_dict() for message in self.messages],
            "temperature": self.temperature
        }


def resolve_tinygrad_tokenizer(model_id: str):
Expand All @@ -75,8 +88,12 @@ async def resolve_tokenizer(model_id: str):
try:
if DEBUG >= 2: print(f"Trying AutoProcessor for {model_id}")
processor = AutoProcessor.from_pretrained(model_id, use_fast=False)
processor.eos_token_id = processor.tokenizer.eos_token_id
processor.encode = processor.tokenizer.encode
if not hasattr(processor, 'eos_token_id'):
processor.eos_token_id = getattr(processor, 'tokenizer', getattr(processor, '_tokenizer', processor)).eos_token_id
if not hasattr(processor, 'encode'):
processor.encode = getattr(processor, 'tokenizer', getattr(processor, '_tokenizer', processor)).encode
if not hasattr(processor, 'decode'):
processor.decode = getattr(processor, 'tokenizer', getattr(processor, '_tokenizer', processor)).decode
return processor
except Exception as e:
if DEBUG >= 2: print(f"Failed to load processor for {model_id}. Error: {e}")
Expand Down Expand Up @@ -157,6 +174,10 @@ def remap_messages(messages: List[Message]) -> List[Message]:
remapped_messages = []
last_image = None
for message in messages:
if not isinstance(message.content, list):
remapped_messages.append(message)
continue

remapped_content = []
for content in message.content:
if isinstance(content, dict):
Expand All @@ -168,16 +189,17 @@ def remap_messages(messages: List[Message]) -> List[Message]:
else:
remapped_content.append(content)
else:
remapped_content.append({"type": "text", "text": content})
remapped_content.append(content)
remapped_messages.append(Message(role=message.role, content=remapped_content))

if last_image:
# Replace the last image placeholder with the actual image content
for message in reversed(remapped_messages):
for i, content in enumerate(message.content):
if content.get("type") == "text" and content.get("text") == "[An image was uploaded but is not displayed here]":
message.content[i] = last_image
return remapped_messages
if isinstance(content, dict):
if content.get("type") == "text" and content.get("text") == "[An image was uploaded but is not displayed here]":
message.content[i] = last_image
return remapped_messages

return remapped_messages

Expand All @@ -192,7 +214,7 @@ def build_prompt(tokenizer, _messages: List[Message]):
for content in message.content:
# note: we only support one image at a time right now. Multiple is possible. See: https://github.com/huggingface/transformers/blob/e68ec18ce224af879f22d904c7505a765fb77de3/docs/source/en/model_doc/llava.md?plain=1#L41
# follows the convention in https://platform.openai.com/docs/guides/vision
if content.get("type", None) == "image":
if isinstance(content, dict) and content.get("type", None) == "image":
image_str = content.get("image", None)
break

Expand Down
74 changes: 48 additions & 26 deletions tinychat/examples/tinychat/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ document.addEventListener("alpine:init", () => {
this.tokens_per_second = 0;

// prepare messages for API request
const apiMessages = this.cstate.messages.map(msg => {
let apiMessages = this.cstate.messages.map(msg => {
if (msg.content.startsWith('![Uploaded Image]')) {
return {
role: "user",
Expand All @@ -89,36 +89,40 @@ document.addEventListener("alpine:init", () => {
image_url: {
url: this.imageUrl
}
},
{
type: "text",
text: value // Use the actual text the user typed
}
]
};
} else {
return {
role: msg.role,
content: [
{
type: "text",
text: msg.content
}
]
content: msg.content
};
}
});

// If there's an image URL, add it to all messages
if (this.imageUrl) {
apiMessages.forEach(msg => {
if (!msg.content.some(content => content.type === "image_url")) {
msg.content.push({
type: "image_url",
image_url: {
url: this.imageUrl
}
});
const containsImage = apiMessages.some(msg => Array.isArray(msg.content) && msg.content.some(item => item.type === 'image_url'));
if (containsImage) {
// Map all messages with string content to object with type text
apiMessages = apiMessages.map(msg => {
if (typeof msg.content === 'string') {
return {
...msg,
content: [
{
type: "text",
text: msg.content
}
]
};
}
return msg;
});
}


// start receiving server sent events
let gottenFirstChunk = false;
for await (
Expand Down Expand Up @@ -146,19 +150,37 @@ document.addEventListener("alpine:init", () => {
}
}

// update the state in histories or add it if it doesn't exist
const index = this.histories.findIndex((cstate) => {
return cstate.time === this.cstate.time;
// Clean the cstate before adding it to histories
const cleanedCstate = JSON.parse(JSON.stringify(this.cstate));
cleanedCstate.messages = cleanedCstate.messages.map(msg => {
if (Array.isArray(msg.content)) {
return {
...msg,
content: msg.content.map(item =>
item.type === 'image_url' ? { type: 'image_url', image_url: { url: '[IMAGE_PLACEHOLDER]' } } : item
)
};
}
return msg;
});
this.cstate.time = Date.now();

// Update the state in histories or add it if it doesn't exist
const index = this.histories.findIndex((cstate) => cstate.time === cleanedCstate.time);
cleanedCstate.time = Date.now();
if (index !== -1) {
// update the time
this.histories[index] = this.cstate;
// Update the existing entry
this.histories[index] = cleanedCstate;
} else {
this.histories.push(this.cstate);
// Add a new entry
this.histories.push(cleanedCstate);
}
console.log(this.histories)
// update in local storage
localStorage.setItem("histories", JSON.stringify(this.histories));
try {
localStorage.setItem("histories", JSON.stringify(this.histories));
} catch (error) {
console.error("Failed to save histories to localStorage:", error);
}

this.generating = false;
},
Expand Down

0 comments on commit 178fb75

Please sign in to comment.