# mirror of https://github.com/pinecone-io/examples.git

import os
import asyncio
from typing import Any
from queue import Queue

import uvicorn
from fastapi import FastAPI, Body
from fastapi.responses import StreamingResponse
from pydantic import BaseModel

from langchain.agents import AgentType, initialize_agent
from langchain.chat_models import ChatOpenAI
from langchain.memory import ConversationBufferWindowMemory
from langchain.callbacks.streaming_aiter import AsyncIteratorCallbackHandler
from langchain.callbacks.streaming_stdout_final_only import FinalStreamingStdOutCallbackHandler
from langchain.schema import LLMResult

app = FastAPI()

# initialize the agent (we need to do this for the callbacks)
llm = ChatOpenAI(
    openai_api_key=os.getenv("OPENAI_API_KEY"),
    temperature=0.0,
    model_name="gpt-3.5-turbo",
    streaming=True,  # ! important
    callbacks=[]  # ! important (but we will add them later)
)
# conversational memory: keep the last 5 exchanges
memory = ConversationBufferWindowMemory(
    memory_key="chat_history",
    k=5,
    return_messages=True,
    output_key="output"
)
# conversational agent (no tools are attached in this example)
agent = initialize_agent(
    agent=AgentType.CHAT_CONVERSATIONAL_REACT_DESCRIPTION,
    tools=[],
    llm=llm,
    verbose=True,
    max_iterations=3,
    early_stopping_method="generate",
    memory=memory,
    return_intermediate_steps=False
)


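# The chat-conversational agent asks the LLM to reply with a JSON blob along the lines of
#   {"action": "Final Answer", "action_input": "<answer text>"}
# (a rough sketch of the format; the exact output depends on the agent's prompt). The
# handler below waits for the "Final Answer" marker and then forwards only the tokens
# that make up the `action_input` value, so the client receives the answer text without
# the surrounding JSON.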
class AsyncCallbackHandler(AsyncIteratorCallbackHandler):
    content: str = ""
    final_answer: bool = False

    async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        self.content += token
        # once we have passed the final answer marker, we put tokens in the queue
        if self.final_answer:
            if '"action_input": "' in self.content:
                # skip the closing JSON characters so only the answer text is streamed
                if token not in ['"', "}"]:
                    self.queue.put_nowait(token)
        elif "Final Answer" in self.content:
            self.final_answer = True
            self.content = ""

    async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        if self.final_answer:
            # the final answer has finished streaming: reset state and close the iterator
            self.content = ""
            self.final_answer = False
            self.done.set()
        else:
            # an intermediate LLM call finished: reset and keep waiting for the final answer
            self.content = ""


async def run_call(query: str, stream_it: AsyncCallbackHandler):
    # attach this request's streaming callback handler to the LLM
    agent.agent.llm_chain.llm.callbacks = [stream_it]
    # now run the query through the agent
    await agent.acall(inputs={"input": query})


# request input format
class Query(BaseModel):
    text: str


async def create_gen(query: str, stream_it: AsyncCallbackHandler):
    # run the agent call in the background while we consume tokens from the handler
    task = asyncio.create_task(run_call(query, stream_it))
    async for token in stream_it.aiter():
        yield token
    # make sure the agent call has completed (and re-raise any exception it produced)
    await task


@app.post("/chat")
async def chat(
    query: Query = Body(...),
):
    stream_it = AsyncCallbackHandler()
    gen = create_gen(query.text, stream_it)
    return StreamingResponse(gen, media_type="text/event-stream")


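# Example of calling the /chat endpoint above from a client (a minimal sketch, assuming
# the server is running on http://localhost:8000 and the `requests` package is installed):
#
#   import requests
#
#   with requests.post(
#       "http://localhost:8000/chat",
#       json={"text": "tell me a long story"},
#       stream=True,
#   ) as res:
#       for chunk in res.iter_content(chunk_size=None, decode_unicode=True):
#           print(chunk, end="", flush=True)

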
@app.get("/health")
async def health():
    """Check the API is running."""
    return {"status": "🤙"}


if __name__ == "__main__":
    uvicorn.run(
        "app:app",
        host="localhost",
        port=8000,
        reload=True
    )
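
# To launch the API (assumes this file is saved as app.py, which the "app:app"
# import string above requires):
#   python app.py
# or, equivalently, via the uvicorn CLI:
#   uvicorn app:app --host localhost --port 8000 --reload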