mirror of
https://github.com/exo-explore/exo.git
synced 2025-10-23 02:57:14 +03:00
fix stream: false completion
This commit is contained in:
@@ -204,7 +204,7 @@ class ChatGPTAPI:
|
||||
|
||||
# Get the callback system and register our handler
|
||||
self.token_callback = node.on_token.register("chatgpt-api-token-handler")
|
||||
self.token_callback.on_next(lambda _request_id, token, is_finished: asyncio.create_task(self.handle_token(_request_id, token, is_finished)))
|
||||
self.token_callback.on_next(lambda _request_id, tokens, is_finished: asyncio.create_task(self.handle_tokens(_request_id, tokens, is_finished)))
|
||||
self.system_prompt = system_prompt
|
||||
|
||||
cors = aiohttp_cors.setup(self.app)
|
||||
@@ -463,17 +463,17 @@ class ChatGPTAPI:
|
||||
else:
|
||||
tokens = []
|
||||
while True:
|
||||
token, is_finished = await asyncio.wait_for(self.token_queues[request_id].get(), timeout=self.response_timeout)
|
||||
tokens.append(token)
|
||||
_tokens, is_finished = await asyncio.wait_for(self.token_queues[request_id].get(), timeout=self.response_timeout)
|
||||
tokens.extend(_tokens)
|
||||
if is_finished:
|
||||
break
|
||||
finish_reason = "length"
|
||||
eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if isinstance(getattr(tokenizer, "_tokenizer", None), AutoTokenizer) else tokenizer.eos_token_id
|
||||
if DEBUG >= 2: print(f"Checking if end of tokens result {token=} is {eos_token_id=}")
|
||||
if token == eos_token_id:
|
||||
if DEBUG >= 2: print(f"Checking if end of tokens result {tokens[-1]=} is {eos_token_id=}")
|
||||
if tokens[-1] == eos_token_id:
|
||||
finish_reason = "stop"
|
||||
|
||||
return web.json_response(generate_completion(chat_request, tokenizer, prompt, request_id, [token], stream, finish_reason, "chat.completion"))
|
||||
return web.json_response(generate_completion(chat_request, tokenizer, prompt, request_id, tokens, stream, finish_reason, "chat.completion"))
|
||||
except asyncio.TimeoutError:
|
||||
return web.json_response({"detail": "Response generation timed out"}, status=408)
|
||||
except Exception as e:
|
||||
@@ -678,8 +678,8 @@ class ChatGPTAPI:
|
||||
if DEBUG >= 2: traceback.print_exc()
|
||||
return web.json_response({"detail": f"Error getting topology: {str(e)}"}, status=500)
|
||||
|
||||
async def handle_token(self, request_id: str, token: int, is_finished: bool):
|
||||
await self.token_queues[request_id].put((token, is_finished))
|
||||
async def handle_tokens(self, request_id: str, tokens: List[int], is_finished: bool):
|
||||
await self.token_queues[request_id].put((tokens, is_finished))
|
||||
|
||||
async def run(self, host: str = "0.0.0.0", port: int = 52415):
|
||||
runner = web.AppRunner(self.app)
|
||||
|
||||
Reference in New Issue
Block a user