Mirror of https://github.com/docker/genai-stack.git (synced 2024-08-30 16:49:54 +03:00)
Refactor logic into separate files for easier reading
This commit is contained in:
@@ -13,9 +13,9 @@ COPY requirements.txt .

RUN pip install --upgrade -r requirements.txt

# COPY .env .
COPY bot.py .
COPY utils.py .
COPY chains.py .

EXPOSE 8501
bot.py (181 lines changed)
@@ -1,14 +1,8 @@
import os
from typing import List, Any

import streamlit as st
from streamlit.logger import get_logger
from langchain.callbacks.base import BaseCallbackHandler
from langchain.vectorstores.neo4j_vector import Neo4jVector

from langchain.chat_models import ChatOpenAI, ChatOllama
from langchain.chains import RetrievalQAWithSourcesChain
from langchain.chains.qa_with_sources import load_qa_with_sources_chain
from langchain.prompts.chat import (
    ChatPromptTemplate,
    SystemMessagePromptTemplate,
@@ -16,7 +10,16 @@ from langchain.prompts.chat import (
)
from langchain.graphs import Neo4jGraph
from dotenv import load_dotenv
from utils import extract_title_and_question, load_embedding_model
from utils import (
    extract_title_and_question,
    create_vector_index,
)
from chains import (
    load_embedding_model,
    load_llm,
    configure_llm_only_chain,
    configure_qa_rag_chain,
)

load_dotenv(".env")
@@ -33,19 +36,10 @@ logger = get_logger(__name__)

# if Neo4j is local, you can go to http://localhost:7474/ to browse the database
neo4j_graph = Neo4jGraph(url=url, username=username, password=password)


def create_vector_index(dimension: int) -> None:
    index_query = "CALL db.index.vector.createNodeIndex('stackoverflow', 'Question', 'embedding', $dimension, 'cosine')"
    try:
        neo4j_graph.query(index_query, {"dimension": dimension})
    except:  # Already exists
        pass
    index_query = "CALL db.index.vector.createNodeIndex('top_answers', 'Answer', 'embedding', $dimension, 'cosine')"
    try:
        neo4j_graph.query(index_query, {"dimension": dimension})
    except:  # Already exists
        pass
embeddings, dimension = load_embedding_model(
    embedding_model_name, config={"ollama_base_url": ollama_base_url}, logger=logger
)
create_vector_index(neo4j_graph, dimension)


class StreamHandler(BaseCallbackHandler):
@@ -58,142 +52,11 @@ class StreamHandler(BaseCallbackHandler):
        self.container.markdown(self.text)


embeddings, dimension = load_embedding_model(
    embedding_model_name, config={"ollama_base_url": ollama_base_url}, logger=logger
)
llm = load_llm(llm_name, logger=logger, config={"ollama_base_url": ollama_base_url})

create_vector_index(dimension)

if llm_name == "gpt-4":
    llm = ChatOpenAI(temperature=0, model_name="gpt-4", streaming=True)
    logger.info("LLM: Using GPT-4")
elif llm_name == "gpt-3.5":
    llm = ChatOpenAI(temperature=0, model_name="gpt-3.5-turbo", streaming=True)
    logger.info("LLM: Using GPT-3.5 Turbo")
elif len(llm_name):
    llm = ChatOllama(
        temperature=0,
        base_url=ollama_base_url,
        model=llm_name,
        streaming=True,
        top_k=10,  # A higher value (100) will give more diverse answers, while a lower value (10) will be more conservative.
        top_p=0.3,  # Higher value (0.95) will lead to more diverse text, while a lower value (0.5) will generate more focused text.
        num_ctx=3072,  # Sets the size of the context window used to generate the next token.
    )
    logger.info(f"LLM: Using Ollama ({llm_name})")
else:
    llm = ChatOpenAI(temperature=0, model_name="gpt-3.5-turbo", streaming=True)
    logger.info("LLM: Using GPT-3.5 Turbo")

# LLM only response
template = """
You are a helpful assistant that helps a support agent with answering programming questions.
If you don't know the answer, just say that you don't know, don't try to make up an answer.
"""
system_message_prompt = SystemMessagePromptTemplate.from_template(template)
human_template = "{text}"
human_message_prompt = HumanMessagePromptTemplate.from_template(human_template)
chat_prompt = ChatPromptTemplate.from_messages(
    [system_message_prompt, human_message_prompt]
)


def generate_llm_output(
    user_input: str, callbacks: List[Any], prompt=chat_prompt
) -> str:
    answer = llm(
        prompt.format_prompt(
            text=user_input,
        ).to_messages(),
        callbacks=callbacks,
    ).content
    return {"answer": answer}
# Vector response
neo4j_db = Neo4jVector.from_existing_index(
    embedding=embeddings,
    url=url,
    username=username,
    password=password,
    database="neo4j",  # neo4j by default
    index_name="top_answers",  # vector by default
    text_node_property="body",  # text by default
    retrieval_query="""
    OPTIONAL MATCH (node)-[:ANSWERS]->(question)
    RETURN 'Question: ' + question.title + '\n' + question.body + '\nAnswer: ' +
        coalesce(node.body,"") AS text, score, {source:question.link} AS metadata
    ORDER BY score ASC // so that best answers are the last
    """,
)

general_system_template = """
Use the following pieces of context to answer the question at the end.
The context contains question-answer pairs and their links from Stackoverflow.
You should prefer information from accepted or more upvoted answers.
Make sure to rely on information from the answers and not on questions to provide accurate responses.
When you find a particular answer in the context useful, make sure to cite it in the answer using the link.
If you don't know the answer, just say that you don't know, don't try to make up an answer.
----
{summaries}
----
Each answer you generate should contain a section at the end with links to
Stackoverflow questions and answers you found useful, which are described under Source value.
You can only use links to StackOverflow questions that are present in the context and always
add links to the end of the answer in the style of citations.
Generate concise answers with a references section of links to
relevant StackOverflow questions only at the end of the answer.
"""
general_user_template = "Question:```{question}```"
messages = [
    SystemMessagePromptTemplate.from_template(general_system_template),
    HumanMessagePromptTemplate.from_template(general_user_template),
]
qa_prompt = ChatPromptTemplate.from_messages(messages)

qa_chain = load_qa_with_sources_chain(
    llm,
    chain_type="stuff",
    prompt=qa_prompt,
)
qa = RetrievalQAWithSourcesChain(
    combine_documents_chain=qa_chain,
    retriever=neo4j_db.as_retriever(search_kwargs={"k": 2}),
    reduce_k_below_max_tokens=True,
    max_tokens_limit=3375,
)
# Vector + Knowledge Graph response
kg = Neo4jVector.from_existing_index(
    embedding=embeddings,
    url=url,
    username=username,
    password=password,
    database="neo4j",  # neo4j by default
    index_name="stackoverflow",  # vector by default
    text_node_property="body",  # text by default
    retrieval_query="""
    WITH node AS question, score AS similarity
    CALL { with question
        MATCH (question)<-[:ANSWERS]-(answer)
        WITH answer
        ORDER BY answer.is_accepted DESC, answer.score DESC
        WITH collect(answer)[..2] as answers
        RETURN reduce(str='', answer IN answers | str +
            '\n### Answer (Accepted: '+ answer.is_accepted +
            ' Score: ' + answer.score+ '): '+ answer.body + '\n') as answerTexts
    }
    RETURN '##Question: ' + question.title + '\n' + question.body + '\n'
        + answerTexts AS text, similarity as score, {source: question.link} AS metadata
    ORDER BY similarity ASC // so that best answers are the last
    """,
)

kg_qa = RetrievalQAWithSourcesChain(
    combine_documents_chain=qa_chain,
    retriever=kg.as_retriever(search_kwargs={"k": 2}),
    reduce_k_below_max_tokens=False,
    max_tokens_limit=3375,
llm_chain = configure_llm_only_chain(llm)
rag_chain = configure_qa_rag_chain(
    llm, embeddings, embeddings_store_url=url, username=username, password=password
)

# Streamlit UI
@@ -280,11 +143,9 @@ def mode_select() -> str:

name = mode_select()
if name == "LLM only" or name == "Disabled":
    output_function = generate_llm_output
elif name == "Vector":
    output_function = qa
    output_function = llm_chain
elif name == "Vector + Graph" or name == "Enabled":
    output_function = kg_qa
    output_function = rag_chain


def generate_ticket():
@@ -337,7 +198,7 @@ def generate_ticket():
            HumanMessagePromptTemplate.from_template("{text}"),
        ]
    )
    llm_response = generate_llm_output(
    llm_response = llm_chain(
        f"Here's the question to rewrite in the expected format: ```{q_prompt}```",
        [],
        chat_prompt,
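Taken together, the bot.py hunks above swap the inline model, index, and chain setup for calls into the new chains.py and utils.py helpers. Below is a minimal sketch of the resulting wiring, assuming url, username, password, ollama_base_url, embedding_model_name, llm_name, logger, and neo4j_graph are defined earlier in bot.py exactly as in the original file:

from utils import create_vector_index
from chains import (
    load_embedding_model,
    load_llm,
    configure_llm_only_chain,
    configure_qa_rag_chain,
)

# Models are loaded once at startup through the shared helpers.
embeddings, dimension = load_embedding_model(
    embedding_model_name, config={"ollama_base_url": ollama_base_url}, logger=logger
)
llm = load_llm(llm_name, logger=logger, config={"ollama_base_url": ollama_base_url})

# Index creation moved to utils and now takes the graph driver explicitly.
create_vector_index(neo4j_graph, dimension)

# Two ready-made chains replace the inline prompt and retriever construction.
llm_chain = configure_llm_only_chain(llm)
rag_chain = configure_qa_rag_chain(
    llm, embeddings, embeddings_store_url=url, username=username, password=password
)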
chains.py (new file, 147 lines)
@@ -0,0 +1,147 @@
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.embeddings import OllamaEmbeddings, SentenceTransformerEmbeddings
from langchain.chat_models import ChatOpenAI, ChatOllama
from langchain.vectorstores.neo4j_vector import Neo4jVector
from langchain.chains import RetrievalQAWithSourcesChain
from langchain.chains.qa_with_sources import load_qa_with_sources_chain
from langchain.prompts.chat import (
    ChatPromptTemplate,
    SystemMessagePromptTemplate,
    HumanMessagePromptTemplate,
)
from typing import List, Any
from utils import BaseLogger


def load_embedding_model(embedding_model_name: str, logger=BaseLogger(), config={}):
    if embedding_model_name == "ollama":
        embeddings = OllamaEmbeddings(base_url=config["ollama_base_url"], model="llama2")
        dimension = 4096
        logger.info("Embedding: Using Ollama")
    elif embedding_model_name == "openai":
        embeddings = OpenAIEmbeddings()
        dimension = 1536
        logger.info("Embedding: Using OpenAI")
    else:
        embeddings = SentenceTransformerEmbeddings(
            model_name="all-MiniLM-L6-v2", cache_folder="/embedding_model"
        )
        dimension = 384
        logger.info("Embedding: Using SentenceTransformer")
    return embeddings, dimension
def load_llm(llm_name: str, logger=BaseLogger(), config={}):
    if llm_name == "gpt-4":
        logger.info("LLM: Using GPT-4")
        return ChatOpenAI(temperature=0, model_name="gpt-4", streaming=True)
    elif llm_name == "gpt-3.5":
        logger.info("LLM: Using GPT-3.5")
        return ChatOpenAI(temperature=0, model_name="gpt-3.5-turbo", streaming=True)
    elif len(llm_name):
        logger.info(f"LLM: Using Ollama: {llm_name}")
        return ChatOllama(
            temperature=0,
            base_url=config["ollama_base_url"],
            model=llm_name,
            streaming=True,
            top_k=10,  # A higher value (100) will give more diverse answers, while a lower value (10) will be more conservative.
            top_p=0.3,  # Higher value (0.95) will lead to more diverse text, while a lower value (0.5) will generate more focused text.
            num_ctx=3072,  # Sets the size of the context window used to generate the next token.
        )
    logger.info("LLM: Using GPT-3.5")
    return ChatOpenAI(temperature=0, model_name="gpt-3.5-turbo", streaming=True)


def configure_llm_only_chain(llm):
    # LLM only response
    template = """
    You are a helpful assistant that helps a support agent with answering programming questions.
    If you don't know the answer, just say that you don't know, don't try to make up an answer.
    """
    system_message_prompt = SystemMessagePromptTemplate.from_template(template)
    human_template = "{text}"
    human_message_prompt = HumanMessagePromptTemplate.from_template(human_template)
    chat_prompt = ChatPromptTemplate.from_messages(
        [system_message_prompt, human_message_prompt]
    )

    def generate_llm_output(
        user_input: str, callbacks: List[Any], prompt=chat_prompt
    ) -> str:
        answer = llm(
            prompt.format_prompt(
                text=user_input,
            ).to_messages(),
            callbacks=callbacks,
        ).content
        return {"answer": answer}

    return generate_llm_output
def configure_qa_rag_chain(llm, embeddings, embeddings_store_url, username, password):
    # RAG response
    general_system_template = """
    Use the following pieces of context to answer the question at the end.
    The context contains question-answer pairs and their links from Stackoverflow.
    You should prefer information from accepted or more upvoted answers.
    Make sure to rely on information from the answers and not on questions to provide accurate responses.
    When you find a particular answer in the context useful, make sure to cite it in the answer using the link.
    If you don't know the answer, just say that you don't know, don't try to make up an answer.
    ----
    {summaries}
    ----
    Each answer you generate should contain a section at the end with links to
    Stackoverflow questions and answers you found useful, which are described under Source value.
    You can only use links to StackOverflow questions that are present in the context and always
    add links to the end of the answer in the style of citations.
    Generate concise answers with a references section of links to
    relevant StackOverflow questions only at the end of the answer.
    """
    general_user_template = "Question:```{question}```"
    messages = [
        SystemMessagePromptTemplate.from_template(general_system_template),
        HumanMessagePromptTemplate.from_template(general_user_template),
    ]
    qa_prompt = ChatPromptTemplate.from_messages(messages)

    qa_chain = load_qa_with_sources_chain(
        llm,
        chain_type="stuff",
        prompt=qa_prompt,
    )

    # Vector + Knowledge Graph response
    kg = Neo4jVector.from_existing_index(
        embedding=embeddings,
        url=embeddings_store_url,
        username=username,
        password=password,
        database="neo4j",  # neo4j by default
        index_name="stackoverflow",  # vector by default
        text_node_property="body",  # text by default
        retrieval_query="""
        WITH node AS question, score AS similarity
        CALL { with question
            MATCH (question)<-[:ANSWERS]-(answer)
            WITH answer
            ORDER BY answer.is_accepted DESC, answer.score DESC
            WITH collect(answer)[..2] as answers
            RETURN reduce(str='', answer IN answers | str +
                '\n### Answer (Accepted: '+ answer.is_accepted +
                ' Score: ' + answer.score+ '): '+ answer.body + '\n') as answerTexts
        }
        RETURN '##Question: ' + question.title + '\n' + question.body + '\n'
            + answerTexts AS text, similarity as score, {source: question.link} AS metadata
        ORDER BY similarity ASC // so that best answers are the last
        """,
    )

    kg_qa = RetrievalQAWithSourcesChain(
        combine_documents_chain=qa_chain,
        retriever=kg.as_retriever(search_kwargs={"k": 2}),
        reduce_k_below_max_tokens=False,
        max_tokens_limit=3375,
    )
    return kg_qa
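The two factories return different kinds of callables: configure_llm_only_chain returns a plain generate_llm_output function that yields a {"answer": ...} dict, while configure_qa_rag_chain returns a LangChain RetrievalQAWithSourcesChain wired to the existing 'stackoverflow' index. Here is a minimal, standalone sketch of how they might be exercised; the connection details and model names below are placeholders, not values taken from this diff:

from chains import (
    load_embedding_model,
    load_llm,
    configure_llm_only_chain,
    configure_qa_rag_chain,
)

# Placeholder configuration; the real stack reads these from its .env file.
config = {"ollama_base_url": "http://localhost:11434"}

embeddings, dimension = load_embedding_model("openai", config=config)
llm = load_llm("gpt-3.5", config=config)

# Plain function: takes the user text plus a callback list, returns {"answer": ...}.
llm_chain = configure_llm_only_chain(llm)
print(llm_chain("How do I create a constraint in Neo4j?", callbacks=[])["answer"])

# RetrievalQAWithSourcesChain bound to the Neo4j vector index named 'stackoverflow'.
rag_chain = configure_qa_rag_chain(
    llm,
    embeddings,
    embeddings_store_url="bolt://localhost:7687",
    username="neo4j",
    password="secret",
)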
@@ -14,6 +14,7 @@ RUN pip install --upgrade -r requirements.txt

COPY loader.py .
COPY utils.py .
COPY chains.py .
COPY images ./images

EXPOSE 8502
loader.py (38 lines changed)
@@ -4,7 +4,8 @@ from dotenv import load_dotenv
from langchain.graphs import Neo4jGraph
import streamlit as st
from streamlit.logger import get_logger
from utils import load_embedding_model
from chains import load_embedding_model
from utils import create_constraints, create_vector_index
from PIL import Image

load_dotenv(".env")

@@ -28,39 +29,8 @@ embeddings, dimension = load_embedding_model(
# if Neo4j is local, you can go to http://localhost:7474/ to browse the database
neo4j_graph = Neo4jGraph(url=url, username=username, password=password)


def create_constraints():
    neo4j_graph.query(
        "CREATE CONSTRAINT question_id IF NOT EXISTS FOR (q:Question) REQUIRE (q.id) IS UNIQUE"
    )
    neo4j_graph.query(
        "CREATE CONSTRAINT answer_id IF NOT EXISTS FOR (a:Answer) REQUIRE (a.id) IS UNIQUE"
    )
    neo4j_graph.query(
        "CREATE CONSTRAINT user_id IF NOT EXISTS FOR (u:User) REQUIRE (u.id) IS UNIQUE"
    )
    neo4j_graph.query(
        "CREATE CONSTRAINT tag_name IF NOT EXISTS FOR (t:Tag) REQUIRE (t.name) IS UNIQUE"
    )


create_constraints()


def create_vector_index(dimension):
    index_query = "CALL db.index.vector.createNodeIndex('stackoverflow', 'Question', 'embedding', $dimension, 'cosine')"
    try:
        neo4j_graph.query(index_query, {"dimension": dimension})
    except:  # Already exists
        pass
    index_query = "CALL db.index.vector.createNodeIndex('top_answers', 'Answer', 'embedding', $dimension, 'cosine')"
    try:
        neo4j_graph.query(index_query, {"dimension": dimension})
    except:  # Already exists
        pass


create_vector_index(dimension)
create_constraints(neo4j_graph)
create_vector_index(neo4j_graph, dimension)


def load_so_data(tag: str = "neo4j", page: int = 1) -> None:
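With this change the loader no longer defines its own constraint and index helpers; it passes the Neo4jGraph driver into the shared utils.py functions. A minimal sketch of that setup path, assuming a local Neo4j instance with placeholder credentials:

from langchain.graphs import Neo4jGraph
from utils import create_constraints, create_vector_index
from chains import load_embedding_model

# Placeholder connection details; the stack normally reads these from .env.
neo4j_graph = Neo4jGraph(
    url="bolt://localhost:7687", username="neo4j", password="secret"
)

# Any name other than "ollama" or "openai" falls back to SentenceTransformer embeddings.
embeddings, dimension = load_embedding_model("sentence_transformer")

# Both helpers now take the driver as their first argument instead of a module-level global.
create_constraints(neo4j_graph)
create_vector_index(neo4j_graph, dimension)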
utils.py (51 lines changed)
@@ -1,23 +1,6 @@
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.embeddings import OllamaEmbeddings, SentenceTransformerEmbeddings


def load_embedding_model(embedding_model_name: str, config={}, logger=print):
    if embedding_model_name == "ollama":
        embeddings = OllamaEmbeddings(base_url=config["ollama_base_url"], model="llama2")
        dimension = 4096
        logger.info("Embedding: Using Ollama")
    elif embedding_model_name == "openai":
        embeddings = OpenAIEmbeddings()
        dimension = 1536
        logger.info("Embedding: Using OpenAI")
    else:
        embeddings = SentenceTransformerEmbeddings(
            model_name="all-MiniLM-L6-v2", cache_folder="/embedding_model"
        )
        dimension = 384
        logger.info("Embedding: Using SentenceTransformer")
    return embeddings, dimension
class BaseLogger:
    def __init__(self) -> None:
        self.info = print


def extract_title_and_question(input_string):
@@ -41,3 +24,31 @@ def extract_title_and_question(input_string):
            question += "\n" + line.strip()

    return title, question


def create_vector_index(driver, dimension: int) -> None:
    index_query = "CALL db.index.vector.createNodeIndex('stackoverflow', 'Question', 'embedding', $dimension, 'cosine')"
    try:
        driver.query(index_query, {"dimension": dimension})
    except:  # Already exists
        pass
    index_query = "CALL db.index.vector.createNodeIndex('top_answers', 'Answer', 'embedding', $dimension, 'cosine')"
    try:
        driver.query(index_query, {"dimension": dimension})
    except:  # Already exists
        pass


def create_constraints(driver):
    driver.query(
        "CREATE CONSTRAINT question_id IF NOT EXISTS FOR (q:Question) REQUIRE (q.id) IS UNIQUE"
    )
    driver.query(
        "CREATE CONSTRAINT answer_id IF NOT EXISTS FOR (a:Answer) REQUIRE (a.id) IS UNIQUE"
    )
    driver.query(
        "CREATE CONSTRAINT user_id IF NOT EXISTS FOR (u:User) REQUIRE (u.id) IS UNIQUE"
    )
    driver.query(
        "CREATE CONSTRAINT tag_name IF NOT EXISTS FOR (t:Tag) REQUIRE (t.name) IS UNIQUE"
    )