Mirror of https://github.com/exo-explore/exo.git (synced 2025-10-23 02:57:14 +03:00)

Commit: fix image api prompt encoding
@@ -117,7 +117,7 @@ For developers, exo also starts a ChatGPT-compatible API endpoint on http://localhost:8000/v1/chat/completions
 curl http://localhost:8000/v1/chat/completions \
   -H "Content-Type: application/json" \
   -d '{
-     "model": "llama-3-8b",
+     "model": "llama-3.1-8b",
      "messages": [{"role": "user", "content": "What is the meaning of exo?"}],
      "temperature": 0.7
    }'
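The hunk above fixes the README example to use the llama-3.1-8b model name. For context, here is a minimal sketch of the same request made from Python, plus the vision-style message shape (typed "text" and "image_url" parts, following the OpenAI vision convention cited later in this diff) whose encoding this commit fixes. The requests library, the image URL, and which models accept image input are assumptions, not part of the commit.

import requests  # third-party HTTP client, assumed available

resp = requests.post(
  "http://localhost:8000/v1/chat/completions",
  json={
    "model": "llama-3.1-8b",
    "messages": [{"role": "user", "content": "What is the meaning of exo?"}],
    "temperature": 0.7,
  },
)
print(resp.json())

# An image request: "content" becomes a list of typed parts instead of a
# plain string (illustrative URL; a vision-capable model is required).
image_message = {
  "role": "user",
  "content": [
    {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
    {"type": "text", "text": "What is in this image?"},
  ],
}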
@@ -50,16 +50,29 @@ shard_mappings = {
 
 
 class Message:
-  def __init__(self, role: str, content: Union[str, list]):
+  def __init__(self, role: str, content: Union[str, List[Dict[str, Union[str, Dict[str, str]]]]]):
     self.role = role
     self.content = content
 
+  def to_dict(self):
+    return {
+      "role": self.role,
+      "content": self.content
+    }
+
 
 class ChatCompletionRequest:
   def __init__(self, model: str, messages: List[Message], temperature: float):
     self.model = model
     self.messages = messages
     self.temperature = temperature
 
+  def to_dict(self):
+    return {
+      "model": self.model,
+      "messages": [message.to_dict() for message in self.messages],
+      "temperature": self.temperature
+    }
+
 
 def resolve_tinygrad_tokenizer(model_id: str):
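A quick usage sketch of the to_dict() helpers added above (assumes the classes and typing imports from the hunk are in scope; all values are illustrative):

msg = Message(role="user", content=[
  {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}},
  {"type": "text", "text": "Describe this image."},
])
req = ChatCompletionRequest(model="llama-3.1-8b", messages=[msg], temperature=0.7)
req.to_dict()
# -> {"model": "llama-3.1-8b",
#     "messages": [{"role": "user", "content": [...]}],
#     "temperature": 0.7}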
@@ -75,8 +88,12 @@ async def resolve_tokenizer(model_id: str):
   try:
     if DEBUG >= 2: print(f"Trying AutoProcessor for {model_id}")
     processor = AutoProcessor.from_pretrained(model_id, use_fast=False)
-    processor.eos_token_id = processor.tokenizer.eos_token_id
-    processor.encode = processor.tokenizer.encode
+    if not hasattr(processor, 'eos_token_id'):
+      processor.eos_token_id = getattr(processor, 'tokenizer', getattr(processor, '_tokenizer', processor)).eos_token_id
+    if not hasattr(processor, 'encode'):
+      processor.encode = getattr(processor, 'tokenizer', getattr(processor, '_tokenizer', processor)).encode
+    if not hasattr(processor, 'decode'):
+      processor.decode = getattr(processor, 'tokenizer', getattr(processor, '_tokenizer', processor)).decode
     return processor
   except Exception as e:
     if DEBUG >= 2: print(f"Failed to load processor for {model_id}. Error: {e}")
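All three patched attributes use the same fallback chain: keep the processor's own attribute when it already exists, otherwise delegate to its .tokenizer, then its ._tokenizer, then the processor itself. Isolated as a hypothetical helper (not part of the commit) it reads:

def _delegate(processor, name: str) -> None:
  # Note: getattr evaluates its default eagerly, so ._tokenizer is probed
  # even when .tokenizer exists; that is harmless, since the default is
  # only used when .tokenizer is actually absent.
  inner = getattr(processor, 'tokenizer', getattr(processor, '_tokenizer', processor))
  if not hasattr(processor, name):
    setattr(processor, name, getattr(inner, name))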
@@ -157,6 +174,10 @@ def remap_messages(messages: List[Message]) -> List[Message]:
   remapped_messages = []
   last_image = None
   for message in messages:
+    if not isinstance(message.content, list):
+      remapped_messages.append(message)
+      continue
+
     remapped_content = []
     for content in message.content:
       if isinstance(content, dict):
@@ -168,16 +189,17 @@ def remap_messages(messages: List[Message]) -> List[Message]:
         else:
           remapped_content.append(content)
       else:
-        remapped_content.append({"type": "text", "text": content})
+        remapped_content.append(content)
     remapped_messages.append(Message(role=message.role, content=remapped_content))
 
   if last_image:
     # Replace the last image placeholder with the actual image content
     for message in reversed(remapped_messages):
       for i, content in enumerate(message.content):
-        if content.get("type") == "text" and content.get("text") == "[An image was uploaded but is not displayed here]":
-          message.content[i] = last_image
-          return remapped_messages
+        if isinstance(content, dict):
+          if content.get("type") == "text" and content.get("text") == "[An image was uploaded but is not displayed here]":
+            message.content[i] = last_image
+            return remapped_messages
 
   return remapped_messages
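Together, the two remap_messages hunks make the function tolerant of mixed content: messages whose content is a plain string now pass through untouched, and the placeholder scan only calls .get() on entries that are actually dicts. A small illustration, assuming last_image was captured from an earlier image part in the conversation:

messages = [Message(role="user", content=[
  {"type": "text", "text": "[An image was uploaded but is not displayed here]"},
  {"type": "text", "text": "What is in this image?"},
])]
# remap_messages(messages) replaces the placeholder entry in-place with
# last_image, e.g. {"type": "image_url", "image_url": {"url": "..."}},
# so the model sees the actual image content instead of the UI placeholder.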
@@ -192,7 +214,7 @@ def build_prompt(tokenizer, _messages: List[Message]):
     for content in message.content:
       # note: we only support one image at a time right now. Multiple is possible. See: https://github.com/huggingface/transformers/blob/e68ec18ce224af879f22d904c7505a765fb77de3/docs/source/en/model_doc/llava.md?plain=1#L41
       # follows the convention in https://platform.openai.com/docs/guides/vision
-      if content.get("type", None) == "image":
+      if isinstance(content, dict) and content.get("type", None) == "image":
         image_str = content.get("image", None)
         break
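The added isinstance guard matters because, after the remap changes, content entries may be plain strings as well as dicts, and strings have no .get method, so a text entry would previously have raised AttributeError here. A minimal sketch:

for content in [{"type": "image", "image": "data:image/png;base64,..."}, "plain text"]:
  if isinstance(content, dict) and content.get("type", None) == "image":
    image_str = content.get("image", None)  # only reached for the dict entry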
@@ -79,7 +79,7 @@ document.addEventListener("alpine:init", () => {
         this.tokens_per_second = 0;
 
         // prepare messages for API request
-        const apiMessages = this.cstate.messages.map(msg => {
+        let apiMessages = this.cstate.messages.map(msg => {
          if (msg.content.startsWith('![Uploaded Image]')) {
            return {
              role: "user",
@@ -89,36 +89,40 @@ document.addEventListener("alpine:init", () => {
                   image_url: {
                     url: this.imageUrl
                   }
                 },
                 {
                   type: "text",
                   text: value // Use the actual text the user typed
                 }
               ]
             };
           } else {
             return {
               role: msg.role,
-              content: [
-                {
-                  type: "text",
-                  text: msg.content
-                }
-              ]
+              content: msg.content
             };
           }
         });
 
-        // If there's an image URL, add it to all messages
-        if (this.imageUrl) {
-          apiMessages.forEach(msg => {
-            if (!msg.content.some(content => content.type === "image_url")) {
-              msg.content.push({
-                type: "image_url",
-                image_url: {
-                  url: this.imageUrl
-                }
-              });
+        const containsImage = apiMessages.some(msg => Array.isArray(msg.content) && msg.content.some(item => item.type === 'image_url'));
+        if (containsImage) {
+          // Map all messages with string content to object with type text
+          apiMessages = apiMessages.map(msg => {
+            if (typeof msg.content === 'string') {
+              return {
+                ...msg,
+                content: [
+                  {
+                    type: "text",
+                    text: msg.content
+                  }
+                ]
+              };
             }
+            return msg;
           });
         }
 
         // start receiving server sent events
         let gottenFirstChunk = false;
         for await (
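The rework above wraps plain-string contents into typed text parts only when some message in the conversation actually carries an image_url part; text-only conversations keep their plain-string contents, which the server-side remap_messages changes now also accept. The same normalization rule, sketched in Python with illustrative names:

def normalize(api_messages):
  contains_image = any(
    isinstance(m["content"], list)
    and any(isinstance(p, dict) and p.get("type") == "image_url" for p in m["content"])
    for m in api_messages
  )
  if not contains_image:
    return api_messages
  # Wrap string contents as typed text parts; leave list contents alone.
  return [
    {**m, "content": [{"type": "text", "text": m["content"]}]}
    if isinstance(m["content"], str) else m
    for m in api_messages
  ]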
@@ -146,19 +150,37 @@ document.addEventListener("alpine:init", () => {
           }
         }
 
-        // update the state in histories or add it if it doesn't exist
-        const index = this.histories.findIndex((cstate) => {
-          return cstate.time === this.cstate.time;
-        });
-        this.cstate.time = Date.now();
+        // Clean the cstate before adding it to histories
+        const cleanedCstate = JSON.parse(JSON.stringify(this.cstate));
+        cleanedCstate.messages = cleanedCstate.messages.map(msg => {
+          if (Array.isArray(msg.content)) {
+            return {
+              ...msg,
+              content: msg.content.map(item =>
+                item.type === 'image_url' ? { type: 'image_url', image_url: { url: '[IMAGE_PLACEHOLDER]' } } : item
+              )
+            };
+          }
+          return msg;
+        });
+
+        // Update the state in histories or add it if it doesn't exist
+        const index = this.histories.findIndex((cstate) => cstate.time === cleanedCstate.time);
+        cleanedCstate.time = Date.now();
         if (index !== -1) {
-          // update the time
-          this.histories[index] = this.cstate;
+          // Update the existing entry
+          this.histories[index] = cleanedCstate;
         } else {
-          this.histories.push(this.cstate);
+          // Add a new entry
+          this.histories.push(cleanedCstate);
         }
-        console.log(this.histories)
         // update in local storage
-        localStorage.setItem("histories", JSON.stringify(this.histories));
+        try {
+          localStorage.setItem("histories", JSON.stringify(this.histories));
+        } catch (error) {
+          console.error("Failed to save histories to localStorage:", error);
+        }
 
         this.generating = false;
       },
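A likely motivation for both changes in this last hunk, for context: base64 data URLs for uploaded images can run to megabytes, while localStorage is quota-limited (typically on the order of 5 MB per origin) and setItem throws a QuotaExceededError when the quota is exceeded. Swapping stored images for the '[IMAGE_PLACEHOLDER]' string keeps saved histories small, and the try/catch turns any remaining quota failure into a logged error instead of an unhandled exception.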