Files
genai-stack-neo4j/bot.py
Oskar Hane af43324d23 Add generation feature to Svelte front-end
Fix a few bugs and refactor generation back-end to chains.py so it can be reused.
2023-10-24 09:47:36 +02:00

177 lines
5.0 KiB
Python

import os
import streamlit as st
from streamlit.logger import get_logger
from langchain.callbacks.base import BaseCallbackHandler
from langchain.graphs import Neo4jGraph
from dotenv import load_dotenv
from utils import (
create_vector_index,
)
from chains import (
load_embedding_model,
load_llm,
configure_llm_only_chain,
configure_qa_rag_chain,
generate_ticket,
)
# Read settings from the local .env file, then pull the individual values out
# of the process environment.
load_dotenv(".env")

url = os.getenv("NEO4J_URI")
username = os.getenv("NEO4J_USERNAME")
password = os.getenv("NEO4J_PASSWORD")
ollama_base_url = os.getenv("OLLAMA_BASE_URL")
embedding_model_name = os.getenv("EMBEDDING_MODEL")
llm_name = os.getenv("LLM")

# os.environ only accepts string values, so an unset NEO4J_URI would otherwise
# fail on the assignment below with an opaque
# "TypeError: str expected, not NoneType". Fail early with a clear message.
if url is None:
    raise ValueError(
        "NEO4J_URI is not set; define it in .env or in the environment."
    )

# Remapping for Langchain Neo4j integration
os.environ["NEO4J_URL"] = url

logger = get_logger(__name__)

# if Neo4j is local, you can go to http://localhost:7474/ to browse the database
neo4j_graph = Neo4jGraph(url=url, username=username, password=password)

# load_embedding_model returns the embedding object together with its vector
# dimension; the dimension is what the vector index is created with.
embeddings, dimension = load_embedding_model(
    embedding_model_name, config={"ollama_base_url": ollama_base_url}, logger=logger
)
create_vector_index(neo4j_graph, dimension)
class StreamHandler(BaseCallbackHandler):
    """Callback handler that accumulates LLM tokens and redraws them as markdown.

    Every token received is appended to ``text`` and the full accumulated
    text is pushed to ``container.markdown``.
    """

    def __init__(self, container, initial_text=""):
        # ``container`` must expose a ``markdown(str)`` method; the call site
        # in this file passes ``st.empty()``.
        self.text = initial_text
        self.container = container

    def on_llm_new_token(self, token: str, **kwargs) -> None:
        """Append ``token`` to the running text and re-render the container."""
        accumulated = self.text + token
        self.text = accumulated
        self.container.markdown(accumulated)
# Build the language model once, then derive both answer chains from it:
# llm_chain is built from the model alone, rag_chain additionally receives the
# embeddings and the Neo4j connection details.
llm = load_llm(
    llm_name,
    logger=logger,
    config={"ollama_base_url": ollama_base_url},
)
llm_chain = configure_llm_only_chain(llm)
rag_chain = configure_qa_rag_chain(
    llm,
    embeddings,
    embeddings_store_url=url,
    username=username,
    password=password,
)
# Streamlit UI
# Custom CSS injected into the page: pins the "Select RAG mode" radio near the
# bottom of the viewport, offsets the floating chat input, and sets the height
# of the ticket "Description" text area.
styl = """
<style>
/* not great support for :has yet (hello FireFox), but using it for now */
.element-container:has([aria-label="Select RAG mode"]) {
position: fixed;
bottom: 33px;
background: white;
z-index: 101;
}
.stChatFloatingInputContainer {
bottom: 20px;
}
/* Generate ticket text area */
textarea[aria-label="Description"] {
height: 200px;
}
</style>
"""
# unsafe_allow_html is needed so the <style> tag is emitted as markup.
st.markdown(styl, unsafe_allow_html=True)
def chat_input():
    """Render the chat box and answer a submitted question with the active chain.

    Reads the module-level ``name`` (selected RAG mode label) and
    ``output_function`` (the chain chosen for that mode). When the user
    submits text, the question and the streamed answer are shown and the
    exchange is appended to the ``user_input``, ``generated`` and ``rag_mode``
    lists in ``st.session_state``. Those lists must already exist;
    ``display_chat`` creates them.
    """
    question = st.chat_input("What coding issue can I help you resolve today?")
    if not question:
        return

    with st.chat_message("user"):
        st.write(question)

    with st.chat_message("assistant"):
        st.caption(f"RAG: {name}")
        # Tokens are rendered incrementally into an empty placeholder.
        token_sink = StreamHandler(st.empty())
        payload = {"question": question, "chat_history": []}
        answer = output_function(payload, callbacks=[token_sink])["answer"]
        st.session_state["user_input"].append(question)
        st.session_state["generated"].append(answer)
        st.session_state["rag_mode"].append(name)
def display_chat():
    """Replay the most recent stored exchanges and offer the ticket-draft button.

    Ensures the ``generated``, ``user_input`` and ``rag_mode`` lists exist in
    ``st.session_state``. If at least one answer has been stored, the last
    three exchanges are rendered, followed by an expander whose button calls
    ``open_sidebar``.
    """
    # Session state
    for key in ("generated", "user_input", "rag_mode"):
        if key not in st.session_state:
            st.session_state[key] = []

    answers = st.session_state["generated"]
    if not answers:
        return

    total = len(answers)
    first_shown = max(total - 3, 0)
    # Display only the last three exchanges
    for idx in range(first_shown, total):
        with st.chat_message("user"):
            st.write(st.session_state["user_input"][idx])
        with st.chat_message("assistant"):
            st.caption(f"RAG: {st.session_state['rag_mode'][idx]}")
            st.write(st.session_state["generated"][idx])

    with st.expander("Not finding what you're looking for?"):
        st.write(
            "Automatically generate a draft for an internal ticket to our support team."
        )
        st.button(
            "Generate ticket",
            type="primary",
            key="show_ticket",
            on_click=open_sidebar,
        )

    # Trailing container holding a single non-breaking space.
    with st.container():
        st.write("&nbsp;")
def mode_select() -> str:
    """Render the horizontal "Select RAG mode" radio and return the chosen option."""
    return st.radio("Select RAG mode", ["Disabled", "Enabled"], horizontal=True)
# Map the selected mode label to the chain that will answer questions.
# The radio only offers "Disabled"/"Enabled"; the other two labels are also
# accepted by the comparison.
name = mode_select()
if name in ("LLM only", "Disabled"):
    output_function = llm_chain
elif name in ("Vector + Graph", "Enabled"):
    output_function = rag_chain
def open_sidebar() -> None:
    """Set the ``open_sidebar`` session flag so the ticket draft sidebar is shown."""
    st.session_state["open_sidebar"] = True
def close_sidebar() -> None:
    """Clear the ``open_sidebar`` session flag so the ticket draft sidebar is hidden."""
    st.session_state["open_sidebar"] = False
# The sidebar starts closed; the flag is flipped by open_sidebar/close_sidebar.
if "open_sidebar" not in st.session_state:
    st.session_state["open_sidebar"] = False

if st.session_state["open_sidebar"]:
    # Draft a ticket from the most recent question the user asked.
    # NOTE(review): generate_ticket is called every time this block executes
    # while the flag is set - presumably once per Streamlit rerun; confirm
    # whether the draft should be cached instead.
    new_title, new_question = generate_ticket(
        neo4j_graph=neo4j_graph,
        llm_chain=llm_chain,
        input_question=st.session_state["user_input"][-1],
    )
    with st.sidebar:
        st.title("Ticket draft")
        st.write("Auto generated draft ticket")
        st.text_input("Title", new_title)
        st.text_area("Description", new_question)
        st.button(
            "Submit to support team",
            type="primary",
            key="submit_ticket",
            on_click=close_sidebar,
        )
# Order matters: display_chat() creates the session-state lists that
# chat_input() appends to, so the stored history is rendered first.
display_chat()
chat_input()