fix stream: false completion

This commit is contained in:
Alex Cheema
2025-01-22 22:46:04 +00:00
parent 55d1846f5e
commit 87d1271d33

View File

@@ -204,7 +204,7 @@ class ChatGPTAPI:
# Get the callback system and register our handler
self.token_callback = node.on_token.register("chatgpt-api-token-handler")
self.token_callback.on_next(lambda _request_id, token, is_finished: asyncio.create_task(self.handle_token(_request_id, token, is_finished)))
self.token_callback.on_next(lambda _request_id, tokens, is_finished: asyncio.create_task(self.handle_tokens(_request_id, tokens, is_finished)))
self.system_prompt = system_prompt
cors = aiohttp_cors.setup(self.app)
@@ -463,17 +463,17 @@ class ChatGPTAPI:
else:
tokens = []
while True:
token, is_finished = await asyncio.wait_for(self.token_queues[request_id].get(), timeout=self.response_timeout)
tokens.append(token)
_tokens, is_finished = await asyncio.wait_for(self.token_queues[request_id].get(), timeout=self.response_timeout)
tokens.extend(_tokens)
if is_finished:
break
finish_reason = "length"
eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if isinstance(getattr(tokenizer, "_tokenizer", None), AutoTokenizer) else tokenizer.eos_token_id
if DEBUG >= 2: print(f"Checking if end of tokens result {token=} is {eos_token_id=}")
if token == eos_token_id:
if DEBUG >= 2: print(f"Checking if end of tokens result {tokens[-1]=} is {eos_token_id=}")
if tokens[-1] == eos_token_id:
finish_reason = "stop"
return web.json_response(generate_completion(chat_request, tokenizer, prompt, request_id, [token], stream, finish_reason, "chat.completion"))
return web.json_response(generate_completion(chat_request, tokenizer, prompt, request_id, tokens, stream, finish_reason, "chat.completion"))
except asyncio.TimeoutError:
return web.json_response({"detail": "Response generation timed out"}, status=408)
except Exception as e:
@@ -678,8 +678,8 @@ class ChatGPTAPI:
if DEBUG >= 2: traceback.print_exc()
return web.json_response({"detail": f"Error getting topology: {str(e)}"}, status=500)
async def handle_token(self, request_id: str, token: int, is_finished: bool):
await self.token_queues[request_id].put((token, is_finished))
async def handle_tokens(self, request_id: str, tokens: List[int], is_finished: bool):
await self.token_queues[request_id].put((tokens, is_finished))
async def run(self, host: str = "0.0.0.0", port: int = 52415):
runner = web.AppRunner(self.app)