Mirror of https://github.com/BaranziniLab/KG_RAG.git, synced 2024-06-08 14:12:54 +03:00
added llama method to interactive fn
@@ -45,7 +45,7 @@ def main():
         print("")
         output = llm_chain.run(context=context, question=question)
     else:
-        interactive(question, vectorstore, node_context_df, embedding_function_for_context_retrieval, "llama")
+        interactive(question, vectorstore, node_context_df, embedding_function_for_context_retrieval, "llama", llama_method=METHOD)
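
Note: METHOD is not defined inside this hunk. Based on the '-m method-2' instructions added later in this commit, it presumably comes from a command-line flag. A minimal sketch of how such a flag could be parsed (illustrative only; the flag name follows the printed instructions, everything else is an assumption):

    import argparse

    # Parse the '-m' flag described in the setup message below;
    # defaults to "method-1" so the previous behavior is preserved.
    parser = argparse.ArgumentParser()
    parser.add_argument("-m", "--method", default="method-1")
    METHOD = parser.parse_args().method
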
@@ -53,7 +53,7 @@ if user_input == "Y":
     print("""
     IMPORTANT :
     Llama model was downloaded using 'LlamaTokenizer' instead of 'AutoTokenizer' method.
-    So, when you run text generation script, please provide an extra command line argument 'method-2'.
+    So, when you run text generation script, please provide an extra command line argument '-m method-2'.
     For example:
        python -m kg_rag.rag_based_generation.Llama.text_generation -m method-2
     """)
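
Note: the message above says 'method-2' corresponds to a model downloaded with LlamaTokenizer rather than AutoTokenizer. A minimal sketch of how the two method values could map to tokenizer classes (illustrative; load_tokenizer is a hypothetical helper, not part of the repo):

    from transformers import AutoTokenizer, LlamaTokenizer

    def load_tokenizer(model_name, method="method-1"):
        # "method-2": the model was downloaded via LlamaTokenizer,
        # otherwise fall back to the generic AutoTokenizer.
        if method == "method-2":
            return LlamaTokenizer.from_pretrained(model_name)
        return AutoTokenizer.from_pretrained(model_name)
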
@@ -296,7 +296,7 @@ def retrieve_context(question, vectorstore, embedding_function, node_context_df,
     return node_context_extracted


-def interactive(question, vectorstore, node_context_df, embedding_function_for_context_retrieval, llm_type, api=True):
+def interactive(question, vectorstore, node_context_df, embedding_function_for_context_retrieval, llm_type, api=True, llama_method="method-1"):
     input("Press enter for Step 1 - Disease entity extraction using GPT-3.5-Turbo")
     print("Processing ...")
     entities = disease_entity_extractor_v2(question)
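
Note: the new llama_method parameter defaults to "method-1", so existing callers of interactive() that omit it keep the previous behavior; only callers that pass "method-2" (for example via METHOD from the '-m' flag) opt into the LlamaTokenizer path. For example, with argument names as in the diff:

    # Unchanged call sites still work because of the default value.
    interactive(question, vectorstore, node_context_df,
                embedding_function_for_context_retrieval, "llama")

    # New call site that forwards the command-line choice.
    interactive(question, vectorstore, node_context_df,
                embedding_function_for_context_retrieval, "llama",
                llama_method="method-2")
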
@@ -353,7 +353,7 @@ def interactive(question, vectorstore, node_context_df, embedding_function_for_c
         from langchain import PromptTemplate, LLMChain
         template = get_prompt("Context:\n\n{context} \n\nQuestion: {question}", system_prompts["KG_RAG_BASED_TEXT_GENERATION"])
         prompt = PromptTemplate(template=template, input_variables=["context", "question"])
-        llm = llama_model(config_data["LLAMA_MODEL_NAME"], config_data["LLAMA_MODEL_BRANCH"], config_data["LLM_CACHE_DIR"], stream=True)
+        llm = llama_model(config_data["LLAMA_MODEL_NAME"], config_data["LLAMA_MODEL_BRANCH"], config_data["LLM_CACHE_DIR"], stream=True, method=llama_method)
         llm_chain = LLMChain(prompt=prompt, llm=llm)
         output = llm_chain.run(context=node_context_extracted, question=question)
     elif "gpt" in llm_type:
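
Note: llama_model() itself is not shown in this diff, only its new 'method' keyword. A minimal sketch of a loader that honors such an argument, assuming the transformers/LangChain stack already imported in this file (names and structure are illustrative, not the repo's actual implementation):

    from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer, pipeline
    from langchain.llms import HuggingFacePipeline

    def llama_model(model_name, branch_name, cache_dir, stream=True, method="method-1"):
        # 'stream' is accepted for signature parity; streaming is omitted in this sketch.
        # Pick the tokenizer class according to how the model was downloaded.
        tokenizer_cls = LlamaTokenizer if method == "method-2" else AutoTokenizer
        tokenizer = tokenizer_cls.from_pretrained(model_name, revision=branch_name, cache_dir=cache_dir)
        model = AutoModelForCausalLM.from_pretrained(model_name, revision=branch_name, cache_dir=cache_dir)
        # Wrap a text-generation pipeline so it can be used as a LangChain LLM.
        pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=512)
        return HuggingFacePipeline(pipeline=pipe)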