llm-examples/learn/generation/langchain/handbook/09-langchain-streaming/main.py

import os
import asyncio
from typing import Any
import uvicorn
from fastapi import FastAPI, Body
from fastapi.responses import StreamingResponse
from queue import Queue
from pydantic import BaseModel
from langchain.agents import AgentType, initialize_agent
from langchain.chat_models import ChatOpenAI
from langchain.memory import ConversationBufferWindowMemory
from langchain.callbacks.streaming_aiter import AsyncIteratorCallbackHandler
from langchain.callbacks.streaming_stdout_final_only import FinalStreamingStdOutCallbackHandler
from langchain.schema import LLMResult
app = FastAPI()
# initialize the agent (we need to do this for the callbacks)
llm = ChatOpenAI(
    openai_api_key=os.getenv("OPENAI_API_KEY"),
    temperature=0.0,
    model_name="gpt-3.5-turbo",
    streaming=True,  # ! important
    callbacks=[]  # ! important (but we will add them later)
)
memory = ConversationBufferWindowMemory(
    memory_key="chat_history",
    k=5,
    return_messages=True,
    output_key="output"
)
agent = initialize_agent(
    agent=AgentType.CHAT_CONVERSATIONAL_REACT_DESCRIPTION,
    tools=[],
    llm=llm,
    verbose=True,
    max_iterations=3,
    early_stopping_method="generate",
    memory=memory,
    return_intermediate_steps=False
)
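# The handler below extends LangChain's AsyncIteratorCallbackHandler so that
# only tokens belonging to the agent's final answer (the "action_input" field
# of the last JSON action) are pushed onto the async queue and streamed back
# to the client; intermediate reasoning tokens are accumulated but dropped.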
class AsyncCallbackHandler(AsyncIteratorCallbackHandler):
    content: str = ""
    final_answer: bool = False

    async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        self.content += token
        # if we passed the final answer, we put tokens in queue
        if self.final_answer:
            if '"action_input": "' in self.content:
                if token not in ['"', "}"]:
                    self.queue.put_nowait(token)
        elif "Final Answer" in self.content:
            self.final_answer = True
            self.content = ""

    async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        if self.final_answer:
            self.content = ""
            self.final_answer = False
            self.done.set()
        else:
            self.content = ""
async def run_call(query: str, stream_it: AsyncCallbackHandler):
    # assign callback handler
    agent.agent.llm_chain.llm.callbacks = [stream_it]
    # now query
    await agent.acall(inputs={"input": query})
# request input format
class Query(BaseModel):
    text: str
async def create_gen(query: str, stream_it: AsyncCallbackHandler):
    task = asyncio.create_task(run_call(query, stream_it))
    async for token in stream_it.aiter():
        yield token
    await task
@app.post("/chat")
async def chat(
    query: Query = Body(...),
):
    stream_it = AsyncCallbackHandler()
    gen = create_gen(query.text, stream_it)
    return StreamingResponse(gen, media_type="text/event-stream")
@app.get("/health")
async def health():
"""Check the api is running"""
return {"status": "🤙"}
if __name__ == "__main__":
    uvicorn.run(
        "main:app",
        host="localhost",
        port=8000,
        reload=True
    )
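
# Example of consuming the stream from a client — a minimal sketch, not part of
# the original example. It assumes the server above is running on
# localhost:8000 and that the `requests` package is installed:
#
#   import requests
#
#   response = requests.post(
#       "http://localhost:8000/chat",
#       json={"text": "hello there"},
#       stream=True,
#   )
#   # tokens arrive incrementally as the agent generates its final answer
#   for chunk in response.iter_content(chunk_size=None, decode_unicode=True):
#       print(chunk, end="", flush=True)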