"""Streamlit support-chat UI from the docker/genai-stack demo.

Mirrored from https://github.com/docker/genai-stack.git.
"""
import os

import streamlit as st
from streamlit.logger import get_logger
from langchain.callbacks.base import BaseCallbackHandler
from langchain.graphs import Neo4jGraph
from dotenv import load_dotenv

from utils import (
    create_vector_index,
)
from chains import (
    load_embedding_model,
    load_llm,
    configure_llm_only_chain,
    configure_qa_rag_chain,
    generate_ticket,
)
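# Connection and model settings are read from .env (or the process
# environment): NEO4J_URI, NEO4J_USERNAME, NEO4J_PASSWORD, OLLAMA_BASE_URL,
# EMBEDDING_MODEL, LLM.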
load_dotenv(".env")
url = os.getenv("NEO4J_URI")
username = os.getenv("NEO4J_USERNAME")
password = os.getenv("NEO4J_PASSWORD")
ollama_base_url = os.getenv("OLLAMA_BASE_URL")
embedding_model_name = os.getenv("EMBEDDING_MODEL")
llm_name = os.getenv("LLM")
# Remapping for Langchain Neo4j integration
os.environ["NEO4J_URL"] = url

logger = get_logger(__name__)

# if Neo4j is local, you can go to http://localhost:7474/ to browse the database
neo4j_graph = Neo4jGraph(url=url, username=username, password=password)
embeddings, dimension = load_embedding_model(
    embedding_model_name, config={"ollama_base_url": ollama_base_url}, logger=logger
)
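# Create the vector index in Neo4j for the embedding dimension reported above
# (assumed to be a no-op if the index already exists).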
create_vector_index(neo4j_graph, dimension)
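# LangChain callback handler that streams tokens into a Streamlit placeholder
# as the LLM emits them, so the answer renders incrementally.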
class StreamHandler(BaseCallbackHandler):
    def __init__(self, container, initial_text=""):
        self.container = container
        self.text = initial_text

    def on_llm_new_token(self, token: str, **kwargs) -> None:
        self.text += token
        self.container.markdown(self.text)
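# Instantiate the chat model named by the LLM env var; for Ollama-served
# models the base URL is passed through the config dict.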
llm = load_llm(llm_name, logger=logger, config={"ollama_base_url": ollama_base_url})
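# Two answer pipelines: a plain LLM chain and a RAG chain grounded in the
# Neo4j vector store.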
llm_chain = configure_llm_only_chain(llm)
rag_chain = configure_qa_rag_chain(
    llm, embeddings, embeddings_store_url=url, username=username, password=password
)
# Streamlit UI
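# Raw CSS workaround: pin the "Select RAG mode" radio just above the chat
# input and size the ticket-draft description box.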
styl = f"""
<style>
    /* not great support for :has yet (hello FireFox), but using it for now */
    .element-container:has([aria-label="Select RAG mode"]) {{
        position: fixed;
        bottom: 33px;
        background: white;
        z-index: 101;
    }}
    .stChatFloatingInputContainer {{
        bottom: 20px;
    }}

    /* Generate ticket text area */
    textarea[aria-label="Description"] {{
        height: 200px;
    }}
</style>
"""
st.markdown(styl, unsafe_allow_html=True)
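# Read one chat message, answer it with the currently selected chain while
# streaming tokens, and append the exchange to the session history.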
def chat_input():
    user_input = st.chat_input("What coding issue can I help you resolve today?")

    if user_input:
        with st.chat_message("user"):
            st.write(user_input)
        with st.chat_message("assistant"):
            st.caption(f"RAG: {name}")
            stream_handler = StreamHandler(st.empty())
            output = output_function(
                {"question": user_input, "chat_history": []}, callbacks=[stream_handler]
            )["answer"]
            st.session_state["user_input"].append(user_input)
            st.session_state["generated"].append(output)
            st.session_state["rag_mode"].append(name)
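# Replay recent exchanges from session state and offer to draft a support
# ticket when the user has not found an answer.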
def display_chat():
    # Session state
    if "generated" not in st.session_state:
        st.session_state["generated"] = []

    if "user_input" not in st.session_state:
        st.session_state["user_input"] = []

    if "rag_mode" not in st.session_state:
        st.session_state["rag_mode"] = []

    if st.session_state["generated"]:
        size = len(st.session_state["generated"])
        # Display only the last three exchanges
        for i in range(max(size - 3, 0), size):
            with st.chat_message("user"):
                st.write(st.session_state["user_input"][i])

            with st.chat_message("assistant"):
                st.caption(f"RAG: {st.session_state['rag_mode'][i]}")
                st.write(st.session_state["generated"][i])

        with st.expander("Not finding what you're looking for?"):
            st.write(
                "Automatically generate a draft for an internal ticket to our support team."
            )
            st.button(
                "Generate ticket",
                type="primary",
                key="show_ticket",
                on_click=open_sidebar,
            )
        with st.container():
            st.write(" ")
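# Radio control that toggles RAG mode for subsequent questions.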
def mode_select() -> str:
    options = ["Disabled", "Enabled"]
    return st.radio("Select RAG mode", options, horizontal=True)
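# Pick the pipeline for the selected mode; the extra comparisons appear to
# cover legacy option labels ("LLM only", "Vector + Graph").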
name = mode_select()
if name == "LLM only" or name == "Disabled":
    output_function = llm_chain
elif name == "Vector + Graph" or name == "Enabled":
    output_function = rag_chain
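# Button callbacks toggling the ticket-draft sidebar.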
def open_sidebar():
    st.session_state.open_sidebar = True
def close_sidebar():
    st.session_state.open_sidebar = False
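# While the sidebar is open, generate a ticket draft (title and description)
# from the most recent question and present it for review before submission.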
if "open_sidebar" not in st.session_state:
    st.session_state.open_sidebar = False
if st.session_state.open_sidebar:
    new_title, new_question = generate_ticket(
        neo4j_graph=neo4j_graph,
        llm_chain=llm_chain,
        input_question=st.session_state["user_input"][-1],
    )
    with st.sidebar:
        st.title("Ticket draft")
        st.write("Auto generated draft ticket")
        st.text_input("Title", new_title)
        st.text_area("Description", new_question)
        st.button(
            "Submit to support team",
            type="primary",
            key="submit_ticket",
            on_click=close_sidebar,
        )
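# Entry point: render the chat history first, then the pinned input box.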
display_chat()
chat_input()