mirror of https://github.com/exo-explore/exo.git
add chatgpt-api-compatible tools for function calling
@@ -5,7 +5,7 @@ import json
 import os
 from pathlib import Path
 from transformers import AutoTokenizer
-from typing import List, Literal, Union, Dict
+from typing import List, Literal, Union, Dict, Optional
 from aiohttp import web
 import aiohttp_cors
 import traceback
@@ -23,23 +23,28 @@ from exo.download.hf.hf_helpers import get_hf_home, get_repo_root
 from exo.apputil import create_animation_mp4
 
 
 class Message:
-  def __init__(self, role: str, content: Union[str, List[Dict[str, Union[str, Dict[str, str]]]]]):
+  def __init__(self, role: str, content: Union[str, List[Dict[str, Union[str, Dict[str, str]]]]], tools: Optional[List[Dict]] = None):
     self.role = role
     self.content = content
+    self.tools = tools
 
   def to_dict(self):
-    return {"role": self.role, "content": self.content}
+    data = {"role": self.role, "content": self.content}
+    if self.tools:
+      data["tools"] = self.tools
+    return data
 
 
 class ChatCompletionRequest:
-  def __init__(self, model: str, messages: List[Message], temperature: float):
+  def __init__(self, model: str, messages: List[Message], temperature: float, tools: Optional[List[Dict]] = None):
     self.model = model
     self.messages = messages
     self.temperature = temperature
+    self.tools = tools
 
   def to_dict(self):
-    return {"model": self.model, "messages": [message.to_dict() for message in self.messages], "temperature": self.temperature}
+    return {"model": self.model, "messages": [message.to_dict() for message in self.messages], "temperature": self.temperature, "tools": self.tools}
 
 
 def generate_completion(
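With these changes a message can carry OpenAI-style tool definitions, and they survive serialization. A minimal usage sketch (not part of the commit; the get_weather schema is illustrative):

# Sketch only: exercises the new Message(..., tools=...) path from this commit.
get_weather_tool = {
  "type": "function",
  "function": {
    "name": "get_weather",
    "description": "Get the current weather for a city",
    "parameters": {
      "type": "object",
      "properties": {"city": {"type": "string"}},
      "required": ["city"],
    },
  },
}

msg = Message("user", "What's the weather in Berlin?", tools=[get_weather_tool])
print(msg.to_dict())
# -> {'role': 'user', 'content': "What's the weather in Berlin?", 'tools': [{...}]}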
@@ -119,20 +124,24 @@ def remap_messages(messages: List[Message]) -> List[Message]:
   return remapped_messages
 
 
-def build_prompt(tokenizer, _messages: List[Message]):
+def build_prompt(tokenizer, _messages: List[Message], tools: Optional[List[Dict]] = None):
   messages = remap_messages(_messages)
-  prompt = tokenizer.apply_chat_template([m.to_dict() for m in messages], tokenize=False, add_generation_prompt=True)
-  for message in messages:
-    if not isinstance(message.content, list):
-      continue
+  chat_template_args = {
+    "conversation": [m.to_dict() for m in messages],
+    "tokenize": False,
+    "add_generation_prompt": True
+  }
+  if tools: chat_template_args["tools"] = tools
+
+  prompt = tokenizer.apply_chat_template(**chat_template_args)
+  print(f"!!! Prompt: {prompt}")
   return prompt
 
 
 def parse_message(data: dict):
   if "role" not in data or "content" not in data:
     raise ValueError(f"Invalid message: {data}. Must have 'role' and 'content'")
-  return Message(data["role"], data["content"])
+  return Message(data["role"], data["content"], data.get("tools"))
 
 
 def parse_chat_request(data: dict, default_model: str):
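build_prompt now forwards the tool schemas to the tokenizer's chat template, so Transformers renders them directly into the prompt for templates that support tool use. A standalone sketch of the equivalent call, assuming a tool-aware tokenizer (the model repo is an illustrative placeholder):

# Sketch only: mirrors what build_prompt does after this commit, assuming a
# tool-aware chat template (e.g. a Llama 3.1 Instruct tokenizer).
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
conversation = [{"role": "user", "content": "What's the weather in Berlin?"}]

prompt = tokenizer.apply_chat_template(
  conversation,
  tools=[get_weather_tool],  # same OpenAI-style schema as above
  tokenize=False,
  add_generation_prompt=True,
)

Chat templates that do not reference tools simply leave them out of the rendered prompt, so the effect of this argument depends on the model's template.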
@@ -140,6 +149,7 @@ def parse_chat_request(data: dict, default_model: str):
     data.get("model", default_model),
     [parse_message(msg) for msg in data["messages"]],
     data.get("temperature", 0.0),
+    data.get("tools", None),
   )
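The parser therefore accepts the same top-level "tools" field as OpenAI's chat-completions API. An illustrative request body as a Python dict (the model name is a placeholder):

# Sketch only: a tools-enabled body that parse_chat_request now understands.
request_body = {
  "model": "llama-3.1-8b",
  "messages": [{"role": "user", "content": "What's the weather in Berlin?"}],
  "temperature": 0.0,
  "tools": [get_weather_tool],
}

chat_request = parse_chat_request(request_body, "llama-3.1-8b")
assert chat_request.tools == [get_weather_tool]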
@@ -287,7 +297,7 @@ class ChatGPTAPI:
     shard = build_base_shard(model, self.inference_engine_classname)
     messages = [parse_message(msg) for msg in data.get("messages", [])]
     tokenizer = await resolve_tokenizer(get_repo(shard.model_id, self.inference_engine_classname))
-    prompt = build_prompt(tokenizer, messages)
+    prompt = build_prompt(tokenizer, messages, data.get("tools", None))
     tokens = tokenizer.encode(prompt)
     return web.json_response({
       "length": len(prompt),
@@ -326,7 +336,7 @@ class ChatGPTAPI:
     tokenizer = await resolve_tokenizer(get_repo(shard.model_id, self.inference_engine_classname))
     if DEBUG >= 4: print(f"Resolved tokenizer: {tokenizer}")
 
-    prompt = build_prompt(tokenizer, chat_request.messages)
+    prompt = build_prompt(tokenizer, chat_request.messages, chat_request.tools)
     request_id = str(uuid.uuid4())
     if self.on_chat_completion_request:
       try:
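Putting it together, a client can now send tool definitions through the ChatGPT-compatible endpoint. A sketch using the requests library; the host, port, and model name are assumptions, so point it at wherever your exo API is listening:

# Sketch only: host/port and model name are assumptions, not from the commit.
import requests

resp = requests.post(
  "http://localhost:52415/v1/chat/completions",
  json={
    "model": "llama-3.1-8b",
    "messages": [{"role": "user", "content": "What's the weather in Berlin?"}],
    "temperature": 0.0,
    "tools": [get_weather_tool],
  },
)
print(resp.json())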