added llama method to interactive fn

Karthik Soman
2023-12-05 23:55:53 -08:00
parent 47d652d95a
commit f32580d22a
3 changed files with 4 additions and 4 deletions

View File

@@ -45,7 +45,7 @@ def main():
print("")
output = llm_chain.run(context=context, question=question)
else:
-interactive(question, vectorstore, node_context_df, embedding_function_for_context_retrieval, "llama")
+interactive(question, vectorstore, node_context_df, embedding_function_for_context_retrieval, "llama", llama_method=METHOD)
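For reference, a minimal sketch of how the METHOD value forwarded above could be produced; the '-m' flag comes from the download script's message in the next file, while the '--method' long name, the allowed choices, and the 'method-1' default are assumptions rather than code from this commit.

# Hypothetical sketch (not part of this commit): parsing '-m' before main() runs,
# so that METHOD can be forwarded as llama_method=METHOD.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("-m", "--method", default="method-1",
                    choices=["method-1", "method-2"])  # assumed long name, default, and choices
METHOD = parser.parse_args().method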

View File

@@ -53,7 +53,7 @@ if user_input == "Y":
print("""
IMPORTANT :
Llama model was downloaded using 'LlamaTokenizer' instead of 'AutoTokenizer' method.
-So, when you run text generation script, please provide an extra command line argument 'method-2'.
+So, when you run text generation script, please provide an extra command line argument '-m method-2'.
For example:
python -m kg_rag.rag_based_generation.Llama.text_generation -m method-2
""")

View File

@@ -296,7 +296,7 @@ def retrieve_context(question, vectorstore, embedding_function, node_context_df,
return node_context_extracted
-def interactive(question, vectorstore, node_context_df, embedding_function_for_context_retrieval, llm_type, api=True):
+def interactive(question, vectorstore, node_context_df, embedding_function_for_context_retrieval, llm_type, api=True, llama_method="method-1"):
input("Press enter for Step 1 - Disease entity extraction using GPT-3.5-Turbo")
print("Processing ...")
entities = disease_entity_extractor_v2(question)
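Because the new keyword argument defaults to "method-1", call sites written before this change keep working unchanged; only users whose model was downloaded via LlamaTokenizer need to pass the new value. An illustration only, with the meaning of the two values inferred from the download script's message above:

# Illustration (not from this commit): both call styles are valid after the change.
# Omitting llama_method falls back to "method-1" (the original AutoTokenizer path).
interactive(question, vectorstore, node_context_df,
            embedding_function_for_context_retrieval, "llama")
# Pass "method-2" when the weights were downloaded with LlamaTokenizer.
interactive(question, vectorstore, node_context_df,
            embedding_function_for_context_retrieval, "llama",
            llama_method="method-2")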
@@ -353,7 +353,7 @@ def interactive(question, vectorstore, node_context_df, embedding_function_for_c
from langchain import PromptTemplate, LLMChain
template = get_prompt("Context:\n\n{context} \n\nQuestion: {question}", system_prompts["KG_RAG_BASED_TEXT_GENERATION"])
prompt = PromptTemplate(template=template, input_variables=["context", "question"])
-llm = llama_model(config_data["LLAMA_MODEL_NAME"], config_data["LLAMA_MODEL_BRANCH"], config_data["LLM_CACHE_DIR"], stream=True)
+llm = llama_model(config_data["LLAMA_MODEL_NAME"], config_data["LLAMA_MODEL_BRANCH"], config_data["LLM_CACHE_DIR"], stream=True, method=llama_method)
llm_chain = LLMChain(prompt=prompt, llm=llm)
output = llm_chain.run(context=node_context_extracted, question=question)
elif "gpt" in llm_type: