Refactor logic into separate files for easier reading

Oskar Hane
2023-10-02 12:57:45 +02:00
parent 325a026829
commit 48be0ca062
6 changed files with 205 additions and 215 deletions

Dockerfile (bot service)

@@ -13,9 +13,9 @@ COPY requirements.txt .
RUN pip install --upgrade -r requirements.txt
# COPY .env .
COPY bot.py .
COPY utils.py .
COPY chains.py .
EXPOSE 8501

bot.py

@@ -1,14 +1,8 @@
import os
from typing import List, Any
import streamlit as st
from streamlit.logger import get_logger
from langchain.callbacks.base import BaseCallbackHandler
from langchain.vectorstores.neo4j_vector import Neo4jVector
from langchain.chat_models import ChatOpenAI, ChatOllama
from langchain.chains import RetrievalQAWithSourcesChain
from langchain.chains.qa_with_sources import load_qa_with_sources_chain
from langchain.prompts.chat import (
ChatPromptTemplate,
SystemMessagePromptTemplate,
@@ -16,7 +10,16 @@ from langchain.prompts.chat import (
)
from langchain.graphs import Neo4jGraph
from dotenv import load_dotenv
from utils import extract_title_and_question, load_embedding_model
from utils import (
extract_title_and_question,
create_vector_index,
)
from chains import (
load_embedding_model,
load_llm,
configure_llm_only_chain,
configure_qa_rag_chain,
)
load_dotenv(".env")
@@ -33,19 +36,10 @@ logger = get_logger(__name__)
# if Neo4j is local, you can go to http://localhost:7474/ to browse the database
neo4j_graph = Neo4jGraph(url=url, username=username, password=password)
def create_vector_index(dimension: int) -> None:
index_query = "CALL db.index.vector.createNodeIndex('stackoverflow', 'Question', 'embedding', $dimension, 'cosine')"
try:
neo4j_graph.query(index_query, {"dimension": dimension})
except: # Already exists
pass
index_query = "CALL db.index.vector.createNodeIndex('top_answers', 'Answer', 'embedding', $dimension, 'cosine')"
try:
neo4j_graph.query(index_query, {"dimension": dimension})
except: # Already exists
pass
embeddings, dimension = load_embedding_model(
embedding_model_name, config={"ollama_base_url": ollama_base_url}, logger=logger
)
create_vector_index(neo4j_graph, dimension)
class StreamHandler(BaseCallbackHandler):
@@ -58,142 +52,11 @@ class StreamHandler(BaseCallbackHandler):
self.container.markdown(self.text)
embeddings, dimension = load_embedding_model(
embedding_model_name, config={ollama_base_url: ollama_base_url}, logger=logger
)
llm = load_llm(llm_name, logger=logger, config={"ollama_base_url": ollama_base_url})
create_vector_index(dimension)
if llm_name == "gpt-4":
llm = ChatOpenAI(temperature=0, model_name="gpt-4", streaming=True)
logger.info("LLM: Using GPT-4")
elif llm_name == "gpt-3.5":
llm = ChatOpenAI(temperature=0, model_name="gpt-3.5-turbo", streaming=True)
logger.info("LLM: Using GPT-3.5 Turbo")
elif len(llm_name):
llm = ChatOllama(
temperature=0,
base_url=ollama_base_url,
model=llm_name,
streaming=True,
top_k=10, # A higher value (100) will give more diverse answers, while a lower value (10) will be more conservative.
top_p=0.3, # Higher value (0.95) will lead to more diverse text, while a lower value (0.5) will generate more focused text.
num_ctx=3072, # Sets the size of the context window used to generate the next token.
)
logger.info(f"LLM: Using Ollama ({llm_name})")
else:
llm = ChatOpenAI(temperature=0, model_name="gpt-3.5-turbo", streaming=True)
logger.info("LLM: Using GPT-3.5 Turbo")
# LLM only response
template = """
You are a helpful assistant that helps a support agent with answering programming questions.
If you don't know the answer, just say that you don't know, don't try to make up an answer.
"""
system_message_prompt = SystemMessagePromptTemplate.from_template(template)
human_template = "{text}"
human_message_prompt = HumanMessagePromptTemplate.from_template(human_template)
chat_prompt = ChatPromptTemplate.from_messages(
[system_message_prompt, human_message_prompt]
)
def generate_llm_output(
user_input: str, callbacks: List[Any], prompt=chat_prompt
) -> str:
answer = llm(
prompt.format_prompt(
text=user_input,
).to_messages(),
callbacks=callbacks,
).content
return {"answer": answer}
# Vector response
neo4j_db = Neo4jVector.from_existing_index(
embedding=embeddings,
url=url,
username=username,
password=password,
database="neo4j", # neo4j by default
index_name="top_answers", # vector by default
text_node_property="body", # text by default
retrieval_query="""
OPTIONAL MATCH (node)-[:ANSWERS]->(question)
RETURN 'Question: ' + question.title + '\n' + question.body + '\nAnswer: ' +
coalesce(node.body,"") AS text, score, {source:question.link} AS metadata
ORDER BY score ASC // so that best answers are the last
""",
)
general_system_template = """
Use the following pieces of context to answer the question at the end.
The context contains question-answer pairs and their links from Stackoverflow.
You should prefer information from accepted or more upvoted answers.
Make sure to rely on information from the answers and not on questions to provide accurate responses.
When you find a particular answer in the context useful, make sure to cite it in the answer using the link.
If you don't know the answer, just say that you don't know, don't try to make up an answer.
----
{summaries}
----
Each answer you generate should end with a section of links to the
Stackoverflow questions and answers you found useful, which are given under the Source value.
You can only use links to StackOverflow questions that are present in the context, and always
add the links at the end of the answer in the style of citations.
Generate concise answers, with a references section of links to
relevant StackOverflow questions only at the end of the answer.
"""
general_user_template = "Question:```{question}```"
messages = [
SystemMessagePromptTemplate.from_template(general_system_template),
HumanMessagePromptTemplate.from_template(general_user_template),
]
qa_prompt = ChatPromptTemplate.from_messages(messages)
qa_chain = load_qa_with_sources_chain(
llm,
chain_type="stuff",
prompt=qa_prompt,
)
qa = RetrievalQAWithSourcesChain(
combine_documents_chain=qa_chain,
retriever=neo4j_db.as_retriever(search_kwargs={"k": 2}),
reduce_k_below_max_tokens=True,
max_tokens_limit=3375,
)
# Vector + Knowledge Graph response
kg = Neo4jVector.from_existing_index(
embedding=embeddings,
url=url,
username=username,
password=password,
database="neo4j", # neo4j by default
index_name="stackoverflow", # vector by default
text_node_property="body", # text by default
retrieval_query="""
WITH node AS question, score AS similarity
CALL { with question
MATCH (question)<-[:ANSWERS]-(answer)
WITH answer
ORDER BY answer.is_accepted DESC, answer.score DESC
WITH collect(answer)[..2] as answers
RETURN reduce(str='', answer IN answers | str +
'\n### Answer (Accepted: '+ answer.is_accepted +
' Score: ' + answer.score+ '): '+ answer.body + '\n') as answerTexts
}
RETURN '##Question: ' + question.title + '\n' + question.body + '\n'
+ answerTexts AS text, similarity as score, {source: question.link} AS metadata
ORDER BY similarity ASC // so that best answers are the last
""",
)
kg_qa = RetrievalQAWithSourcesChain(
combine_documents_chain=qa_chain,
retriever=kg.as_retriever(search_kwargs={"k": 2}),
reduce_k_below_max_tokens=False,
max_tokens_limit=3375,
llm_chain = configure_llm_only_chain(llm)
rag_chain = configure_qa_rag_chain(
llm, embeddings, embeddings_store_url=url, username=username, password=password
)
# Streamlit UI
@@ -280,11 +143,9 @@ def mode_select() -> str:
name = mode_select()
if name == "LLM only" or name == "Disabled":
output_function = generate_llm_output
elif name == "Vector":
output_function = qa
output_function = llm_chain
elif name == "Vector + Graph" or name == "Enabled":
output_function = kg_qa
output_function = rag_chain
def generate_ticket():
@@ -337,7 +198,7 @@ def generate_ticket():
HumanMessagePromptTemplate.from_template("{text}"),
]
)
llm_response = generate_llm_output(
llm_response = llm_chain(
f"Here's the question to rewrite in the expected format: ```{q_prompt}```",
[],
chat_prompt,
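The StreamHandler callback is only partially visible in the bot.py hunk above. For orientation, here is a minimal sketch of such a streaming handler, assuming the standard LangChain on_llm_new_token hook and a Streamlit st.empty() container; the __init__ and on_llm_new_token bodies below are illustrative, not copied from this commit.

from langchain.callbacks.base import BaseCallbackHandler

class StreamHandler(BaseCallbackHandler):
    """Streams LLM tokens into a Streamlit container as they arrive (sketch)."""

    def __init__(self, container, initial_text: str = ""):
        self.container = container  # e.g. st.empty()
        self.text = initial_text

    def on_llm_new_token(self, token: str, **kwargs) -> None:
        # Accumulate tokens and re-render the markdown placeholder on each token.
        self.text += token
        self.container.markdown(self.text)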

chains.py (new file)

@@ -0,0 +1,147 @@
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.embeddings import OllamaEmbeddings, SentenceTransformerEmbeddings
from langchain.chat_models import ChatOpenAI, ChatOllama
from langchain.vectorstores.neo4j_vector import Neo4jVector
from langchain.chains import RetrievalQAWithSourcesChain
from langchain.chains.qa_with_sources import load_qa_with_sources_chain
from langchain.prompts.chat import (
ChatPromptTemplate,
SystemMessagePromptTemplate,
HumanMessagePromptTemplate,
)
from typing import List, Any
from utils import BaseLogger
def load_embedding_model(embedding_model_name: str, logger=BaseLogger(), config={}):
if embedding_model_name == "ollama":
embeddings = OllamaEmbeddings(base_url=config["ollama_base_url"], model="llama2")
dimension = 4096
logger.info("Embedding: Using Ollama")
elif embedding_model_name == "openai":
embeddings = OpenAIEmbeddings()
dimension = 1536
logger.info("Embedding: Using OpenAI")
else:
embeddings = SentenceTransformerEmbeddings(
model_name="all-MiniLM-L6-v2", cache_folder="/embedding_model"
)
dimension = 384
logger.info("Embedding: Using SentenceTransformer")
return embeddings, dimension
def load_llm(llm_name: str, logger=BaseLogger(), config={}):
if llm_name == "gpt-4":
logger.info("LLM: Using GPT-4")
return ChatOpenAI(temperature=0, model_name="gpt-4", streaming=True)
elif llm_name == "gpt-3.5":
logger.info("LLM: Using GPT-3.5")
return ChatOpenAI(temperature=0, model_name="gpt-3.5-turbo", streaming=True)
elif len(llm_name):
logger.info(f"LLM: Using Ollama: {llm_name}")
return ChatOllama(
temperature=0,
base_url=config["ollama_base_url"],
model=llm_name,
streaming=True,
top_k=10, # A higher value (100) will give more diverse answers, while a lower value (10) will be more conservative.
top_p=0.3, # Higher value (0.95) will lead to more diverse text, while a lower value (0.5) will generate more focused text.
num_ctx=3072, # Sets the size of the context window used to generate the next token.
)
logger.info("LLM: Using GPT-3.5")
return ChatOpenAI(temperature=0, model_name="gpt-3.5-turbo", streaming=True)
def configure_llm_only_chain(llm):
# LLM only response
template = """
You are a helpful assistant that helps a support agent with answering programming questions.
If you don't know the answer, just say that you don't know, don't try to make up an answer.
"""
system_message_prompt = SystemMessagePromptTemplate.from_template(template)
human_template = "{text}"
human_message_prompt = HumanMessagePromptTemplate.from_template(human_template)
chat_prompt = ChatPromptTemplate.from_messages(
[system_message_prompt, human_message_prompt]
)
def generate_llm_output(
user_input: str, callbacks: List[Any], prompt=chat_prompt
) -> str:
answer = llm(
prompt.format_prompt(
text=user_input,
).to_messages(),
callbacks=callbacks,
).content
return {"answer": answer}
return generate_llm_output
def configure_qa_rag_chain(llm, embeddings, embeddings_store_url, username, password):
# RAG response
general_system_template = """
Use the following pieces of context to answer the question at the end.
The context contains question-answer pairs and their links from Stackoverflow.
You should prefer information from accepted or more upvoted answers.
Make sure to rely on information from the answers and not on questions to provide accurate responses.
When you find a particular answer in the context useful, make sure to cite it in the answer using the link.
If you don't know the answer, just say that you don't know, don't try to make up an answer.
----
{summaries}
----
Each answer you generate should end with a section of links to the
Stackoverflow questions and answers you found useful, which are given under the Source value.
You can only use links to StackOverflow questions that are present in the context, and always
add the links at the end of the answer in the style of citations.
Generate concise answers, with a references section of links to
relevant StackOverflow questions only at the end of the answer.
"""
general_user_template = "Question:```{question}```"
messages = [
SystemMessagePromptTemplate.from_template(general_system_template),
HumanMessagePromptTemplate.from_template(general_user_template),
]
qa_prompt = ChatPromptTemplate.from_messages(messages)
qa_chain = load_qa_with_sources_chain(
llm,
chain_type="stuff",
prompt=qa_prompt,
)
# Vector + Knowledge Graph response
kg = Neo4jVector.from_existing_index(
embedding=embeddings,
url=embeddings_store_url,
username=username,
password=password,
database="neo4j", # neo4j by default
index_name="stackoverflow", # vector by default
text_node_property="body", # text by default
retrieval_query="""
WITH node AS question, score AS similarity
CALL { with question
MATCH (question)<-[:ANSWERS]-(answer)
WITH answer
ORDER BY answer.is_accepted DESC, answer.score DESC
WITH collect(answer)[..2] as answers
RETURN reduce(str='', answer IN answers | str +
'\n### Answer (Accepted: '+ answer.is_accepted +
' Score: ' + answer.score+ '): '+ answer.body + '\n') as answerTexts
}
RETURN '##Question: ' + question.title + '\n' + question.body + '\n'
+ answerTexts AS text, similarity as score, {source: question.link} AS metadata
ORDER BY similarity ASC // so that best answers are the last
""",
)
kg_qa = RetrievalQAWithSourcesChain(
combine_documents_chain=qa_chain,
retriever=kg.as_retriever(search_kwargs={"k": 2}),
reduce_k_below_max_tokens=False,
max_tokens_limit=3375,
)
return kg_qa
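Read together with the bot.py hunk above, the new module turns chain construction into four factory calls. A minimal usage sketch follows, assuming the stack's usual NEO4J_URI / NEO4J_USERNAME / NEO4J_PASSWORD / OLLAMA_BASE_URL environment variables (the exact variable names are not shown in this diff).

# usage sketch (not part of the commit): wiring the chains.py helpers together
import os
from streamlit.logger import get_logger
from chains import (
    load_embedding_model,
    load_llm,
    configure_llm_only_chain,
    configure_qa_rag_chain,
)

logger = get_logger(__name__)
url = os.environ["NEO4J_URI"]            # assumed env var names
username = os.environ["NEO4J_USERNAME"]
password = os.environ["NEO4J_PASSWORD"]
ollama_base_url = os.environ.get("OLLAMA_BASE_URL")

# "openai"/"gpt-3.5" avoid the Ollama config; any other non-empty LLM name is
# treated as an Ollama model and requires ollama_base_url.
embeddings, dimension = load_embedding_model(
    "openai", logger=logger, config={"ollama_base_url": ollama_base_url}
)
llm = load_llm("gpt-3.5", logger=logger, config={"ollama_base_url": ollama_base_url})

llm_chain = configure_llm_only_chain(llm)  # plain-LLM answers
rag_chain = configure_qa_rag_chain(        # vector + knowledge-graph grounded answers
    llm, embeddings, embeddings_store_url=url, username=username, password=password
)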

Dockerfile (loader service)

@@ -14,6 +14,7 @@ RUN pip install --upgrade -r requirements.txt
COPY loader.py .
COPY utils.py .
COPY chains.py .
COPY images ./images
EXPOSE 8502

loader.py

@@ -4,7 +4,8 @@ from dotenv import load_dotenv
from langchain.graphs import Neo4jGraph
import streamlit as st
from streamlit.logger import get_logger
from utils import load_embedding_model
from chains import load_embedding_model
from utils import create_constraints, create_vector_index
from PIL import Image
load_dotenv(".env")
@@ -28,39 +29,8 @@ embeddings, dimension = load_embedding_model(
# if Neo4j is local, you can go to http://localhost:7474/ to browse the database
neo4j_graph = Neo4jGraph(url=url, username=username, password=password)
def create_constraints():
neo4j_graph.query(
"CREATE CONSTRAINT question_id IF NOT EXISTS FOR (q:Question) REQUIRE (q.id) IS UNIQUE"
)
neo4j_graph.query(
"CREATE CONSTRAINT answer_id IF NOT EXISTS FOR (a:Answer) REQUIRE (a.id) IS UNIQUE"
)
neo4j_graph.query(
"CREATE CONSTRAINT user_id IF NOT EXISTS FOR (u:User) REQUIRE (u.id) IS UNIQUE"
)
neo4j_graph.query(
"CREATE CONSTRAINT tag_name IF NOT EXISTS FOR (t:Tag) REQUIRE (t.name) IS UNIQUE"
)
create_constraints()
def create_vector_index(dimension):
index_query = "CALL db.index.vector.createNodeIndex('stackoverflow', 'Question', 'embedding', $dimension, 'cosine')"
try:
neo4j_graph.query(index_query, {"dimension": dimension})
except: # Already exists
pass
index_query = "CALL db.index.vector.createNodeIndex('top_answers', 'Answer', 'embedding', $dimension, 'cosine')"
try:
neo4j_graph.query(index_query, {"dimension": dimension})
except: # Already exists
pass
create_vector_index(dimension)
create_constraints(neo4j_graph)
create_vector_index(neo4j_graph, dimension)
def load_so_data(tag: str = "neo4j", page: int = 1) -> None:

utils.py

@@ -1,23 +1,6 @@
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.embeddings import OllamaEmbeddings, SentenceTransformerEmbeddings
def load_embedding_model(embedding_model_name: str, config={}, logger=print):
if embedding_model_name == "ollama":
embeddings = OllamaEmbeddings(base_url=config.ollama_base_url, model="llama2")
dimension = 4096
logger.info("Embedding: Using Ollama")
elif embedding_model_name == "openai":
embeddings = OpenAIEmbeddings()
dimension = 1536
logger.info("Embedding: Using OpenAI")
else:
embeddings = SentenceTransformerEmbeddings(
model_name="all-MiniLM-L6-v2", cache_folder="/embedding_model"
)
dimension = 384
logger.info("Embedding: Using SentenceTransformer")
return embeddings, dimension
class BaseLogger:
def __init__(self) -> None:
self.info = print
def extract_title_and_question(input_string):
@@ -41,3 +24,31 @@ def extract_title_and_question(input_string):
question += "\n" + line.strip()
return title, question
def create_vector_index(driver, dimension: int) -> None:
index_query = "CALL db.index.vector.createNodeIndex('stackoverflow', 'Question', 'embedding', $dimension, 'cosine')"
try:
driver.query(index_query, {"dimension": dimension})
except: # Already exists
pass
index_query = "CALL db.index.vector.createNodeIndex('top_answers', 'Answer', 'embedding', $dimension, 'cosine')"
try:
driver.query(index_query, {"dimension": dimension})
except: # Already exists
pass
def create_constraints(driver):
driver.query(
"CREATE CONSTRAINT question_id IF NOT EXISTS FOR (q:Question) REQUIRE (q.id) IS UNIQUE"
)
driver.query(
"CREATE CONSTRAINT answer_id IF NOT EXISTS FOR (a:Answer) REQUIRE (a.id) IS UNIQUE"
)
driver.query(
"CREATE CONSTRAINT user_id IF NOT EXISTS FOR (u:User) REQUIRE (u.id) IS UNIQUE"
)
driver.query(
"CREATE CONSTRAINT tag_name IF NOT EXISTS FOR (t:Tag) REQUIRE (t.name) IS UNIQUE"
)
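Because the moved helpers now take the graph driver as an argument, they can be reused from any script that owns a Neo4jGraph instance. A minimal sketch, with placeholder connection values:

# usage sketch (not part of the commit): reusing the driver-parameterized helpers
from langchain.graphs import Neo4jGraph
from utils import create_constraints, create_vector_index

graph = Neo4jGraph(url="bolt://localhost:7687", username="neo4j", password="password")

create_constraints(graph)        # uniqueness constraints for Question, Answer, User, Tag
create_vector_index(graph, 384)  # 384 matches the SentenceTransformer embedding dimension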