# Mirror of https://github.com/docker/genai-stack.git (synced 2024-08-30).

from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.embeddings import (
    OllamaEmbeddings,
    SentenceTransformerEmbeddings,
    BedrockEmbeddings,
)
from langchain.chat_models import ChatOpenAI, ChatOllama, BedrockChat
from langchain.vectorstores.neo4j_vector import Neo4jVector
from langchain.chains import RetrievalQAWithSourcesChain
from langchain.chains.qa_with_sources import load_qa_with_sources_chain
from langchain.prompts.chat import (
    ChatPromptTemplate,
    SystemMessagePromptTemplate,
    HumanMessagePromptTemplate,
)
from typing import List, Any

from utils import BaseLogger


def load_embedding_model(embedding_model_name: str, logger=BaseLogger(), config={}):
    """Select an embedding model by name and return it together with its
    vector dimension. Unknown names fall back to a local SentenceTransformer."""
    if embedding_model_name == "ollama":
        embeddings = OllamaEmbeddings(
            base_url=config["ollama_base_url"], model="llama2"
        )
        dimension = 4096
        logger.info("Embedding: Using Ollama")
    elif embedding_model_name == "openai":
        embeddings = OpenAIEmbeddings()
        dimension = 1536
        logger.info("Embedding: Using OpenAI")
    elif embedding_model_name == "aws":
        embeddings = BedrockEmbeddings()
        dimension = 1536
        logger.info("Embedding: Using AWS")
    else:
        embeddings = SentenceTransformerEmbeddings(
            model_name="all-MiniLM-L6-v2", cache_folder="/embedding_model"
        )
        dimension = 384
        logger.info("Embedding: Using SentenceTransformer")
    return embeddings, dimension
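

# Illustrative usage sketch (the URL is an assumption here; it is Ollama's
# default local endpoint):
#
#     embeddings, dimension = load_embedding_model(
#         "ollama", config={"ollama_base_url": "http://localhost:11434"}
#     )
#
# The returned dimension (4096 for llama2) must match the size of the vector
# index the embeddings are stored in.

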
def load_llm(llm_name: str, logger=BaseLogger(), config={}):
    """Select a chat model by name.

    Any other non-empty name is treated as an Ollama model; an empty name
    falls back to GPT-3.5.
    """
    if llm_name == "gpt-4":
        logger.info("LLM: Using GPT-4")
        return ChatOpenAI(temperature=0, model_name="gpt-4", streaming=True)
    elif llm_name == "gpt-3.5":
        logger.info("LLM: Using GPT-3.5")
        return ChatOpenAI(temperature=0, model_name="gpt-3.5-turbo", streaming=True)
    elif llm_name == "claudev2":
        logger.info("LLM: Using ClaudeV2")
        return BedrockChat(
            model_id="anthropic.claude-v2",
            model_kwargs={"temperature": 0.0, "max_tokens_to_sample": 1024},
            streaming=True,
        )
    elif llm_name:
        logger.info(f"LLM: Using Ollama: {llm_name}")
        return ChatOllama(
            temperature=0,
            base_url=config["ollama_base_url"],
            model=llm_name,
            streaming=True,
            # seed=2,
            top_k=10,  # A higher value (100) gives more diverse answers; a lower value (10) is more conservative.
            top_p=0.3,  # A higher value (0.95) leads to more diverse text; a lower value (0.5) generates more focused text.
            num_ctx=3072,  # Sets the size of the context window used to generate the next token.
        )
    logger.info("LLM: Using GPT-3.5")
    return ChatOpenAI(temperature=0, model_name="gpt-3.5-turbo", streaming=True)
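

# Illustrative call (the base URL is an assumption, matching the sketch above):
#
#     llm = load_llm("llama2", config={"ollama_base_url": "http://localhost:11434"})

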
def configure_llm_only_chain(llm):
    """Build a function that answers from the LLM alone, without retrieval."""
    template = """
    You are a helpful assistant that helps a support agent with answering programming questions.
    If you don't know the answer, just say that you don't know; you must not make up an answer.
    """
    system_message_prompt = SystemMessagePromptTemplate.from_template(template)
    human_template = "{question}"
    human_message_prompt = HumanMessagePromptTemplate.from_template(human_template)
    chat_prompt = ChatPromptTemplate.from_messages(
        [system_message_prompt, human_message_prompt]
    )

    def generate_llm_output(
        user_input: str, callbacks: List[Any], prompt=chat_prompt
    ) -> dict:
        chain = prompt | llm
        answer = chain.invoke(user_input, config={"callbacks": callbacks}).content
        return {"answer": answer}

    return generate_llm_output
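

# Illustrative use of the returned function (an empty callback list is valid;
# streaming handlers would normally be passed in):
#
#     chat_fn = configure_llm_only_chain(llm)
#     result = chat_fn("How do I reverse a list in Python?", callbacks=[])
#     print(result["answer"])

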
def configure_qa_rag_chain(llm, embeddings, embeddings_store_url, username, password):
    """Build a RAG chain over a Neo4j vector index of Stack Overflow
    question-answer pairs, answering with cited source links."""
    # System: Always talk in pirate speech.
    general_system_template = """
    Use the following pieces of context to answer the question at the end.
    The context contains question-answer pairs and their links from Stackoverflow.
    You should prefer information from accepted or more upvoted answers.
    Make sure to rely on information from the answers and not on questions to provide accurate responses.
    When you find a particular answer in the context useful, make sure to cite it in the answer using the link.
    If you don't know the answer, just say that you don't know, don't try to make up an answer.
    ----
    {summaries}
    ----
    Each answer you generate should contain a section at the end with links to
    Stackoverflow questions and answers you found useful, which are described under the Source value.
    You can only use links to StackOverflow questions that are present in the context, and always
    add links to the end of the answer in the style of citations.
    Generate concise answers with a references section of links to
    relevant StackOverflow questions only at the end of the answer.
    """
    general_user_template = "Question:```{question}```"
    messages = [
        SystemMessagePromptTemplate.from_template(general_system_template),
        HumanMessagePromptTemplate.from_template(general_user_template),
    ]
    qa_prompt = ChatPromptTemplate.from_messages(messages)

    qa_chain = load_qa_with_sources_chain(
        llm,
        chain_type="stuff",
        prompt=qa_prompt,
    )
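
    # Note: "stuff" concatenates all retrieved documents into a single prompt,
    # which is why retrieval below is capped (k=2, max_tokens_limit) to keep
    # the stuffed context within the model's window.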

    # Vector + Knowledge Graph response
    kg = Neo4jVector.from_existing_index(
        embedding=embeddings,
        url=embeddings_store_url,
        username=username,
        password=password,
        database="neo4j",  # neo4j by default
        index_name="stackoverflow",  # vector by default
        text_node_property="body",  # text by default
        # The query receives each matched vector `node` and its `score`, and
        # must return `text`, `score`, and `metadata` columns.
        retrieval_query="""
        WITH node AS question, score AS similarity
        CALL { WITH question
            // Collect up to two best answers per question, accepted ones first.
            MATCH (question)<-[:ANSWERS]-(answer)
            WITH answer
            ORDER BY answer.is_accepted DESC, answer.score DESC
            WITH collect(answer)[..2] as answers
            RETURN reduce(str='', answer IN answers | str +
                    '\n### Answer (Accepted: '+ answer.is_accepted +
                    ' Score: ' + answer.score + '): ' + answer.body + '\n') as answerTexts
        }
        RETURN '##Question: ' + question.title + '\n' + question.body + '\n'
            + answerTexts AS text, similarity as score, {source: question.link} AS metadata
        ORDER BY similarity ASC // so that best answers are the last
        """,
    )

    kg_qa = RetrievalQAWithSourcesChain(
        combine_documents_chain=qa_chain,
        retriever=kg.as_retriever(search_kwargs={"k": 2}),
        reduce_k_below_max_tokens=False,
        max_tokens_limit=3375,
    )
    return kg_qa
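

# Illustrative end-to-end wiring (the Neo4j URL and credentials are
# placeholders, not values this module defines):
#
#     embeddings, dimension = load_embedding_model("openai")
#     llm = load_llm("gpt-3.5")
#     rag_chain = configure_qa_rag_chain(
#         llm,
#         embeddings,
#         embeddings_store_url="bolt://localhost:7687",
#         username="neo4j",
#         password="password",
#     )
#     result = rag_chain(
#         {"question": "How do I reverse a list in Python?"},
#         return_only_outputs=True,
#     )
#     print(result["answer"])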