Fixed issues due to refactoring of configuration.

Q&A works; still need to test building.
Gerald Hewes
2025-02-18 06:21:49 -05:00
parent 91929202ee
commit a712ce52c0
5 changed files with 34 additions and 35 deletions

View File

@@ -21,7 +21,7 @@ sub_category = args.collection
 os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda
 os.environ["OPENAI_API_KEY"] = ""
-from videorag._llm import *
+from videorag._llm import openai_config, azure_openai_config, ollama_config
 from videorag.videorag import VideoRAG, QueryParam
 if __name__ == '__main__':
@@ -31,8 +31,7 @@ if __name__ == '__main__':
     video_base_path = f'longervideos/{sub_category}/videos/'
     video_files = sorted(os.listdir(video_base_path))
     video_paths = [os.path.join(video_base_path, f) for f in video_files]
-    #videorag = VideoRAG(cheap_model_func=gpt_4o_mini_complete, best_model_func=gpt_4o_mini_complete, working_dir=f"./videorag-workdir/{sub_category}")
-    videorag = VideoRAG(cheap_model_func=ollama_mini_complete, best_model_func=ollama_complete, working_dir=f"./videorag-workdir/{sub_category}")
+    videorag = VideoRAG(llm=ollama_config, working_dir=f"./videorag-workdir/{sub_category}")
     videorag.insert_video(video_path_list=video_paths)
     ## inference
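
After this change a provider swap is a single constructor argument rather than a pair of model functions. A minimal sketch of how a driver script could pick a config at runtime; the --provider flag is hypothetical, but VideoRAG(llm=...) and the three config objects come straight from this diff:

    import argparse
    from videorag._llm import openai_config, azure_openai_config, ollama_config
    from videorag.videorag import VideoRAG

    # Hypothetical flag; the script above simply hard-codes ollama_config.
    parser = argparse.ArgumentParser()
    parser.add_argument("--provider", choices=["openai", "azure", "ollama"], default="ollama")
    args = parser.parse_args()

    configs = {"openai": openai_config, "azure": azure_openai_config, "ollama": ollama_config}
    videorag = VideoRAG(llm=configs[args.provider], working_dir="./videorag-workdir/demo")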

View File

@@ -2,6 +2,7 @@ import numpy as np
 from openai import AsyncOpenAI, AsyncAzureOpenAI, APIConnectionError, RateLimitError
+from ollama import AsyncClient
 from dataclasses import asdict, dataclass, field
 from tenacity import (
     retry,
@@ -13,6 +14,7 @@ import os
 from ._utils import compute_args_hash, wrap_embedding_func_with_attrs
 from .base import BaseKVStorage
+from ._utils import EmbeddingFunc
 global_openai_async_client = None
 global_azure_openai_async_client = None
@@ -130,19 +132,19 @@ async def openai_embedding(texts: list[str]) -> np.ndarray:
 openai_config = LLMConfig(
-    embedding_func = field(default_factory=lambda: openai_embedding)
-    embedding_batch_num = 32
-    embedding_func_max_async = 16
-    query_better_than_threshold = 0.2
+    embedding_func = openai_embedding,
+    embedding_batch_num = 32,
+    embedding_func_max_async = 16,
+    query_better_than_threshold = 0.2,
     # LLM
-    best_model_func = gpt_4o_mini_complete
-    best_model_max_token_size = 32768
-    best_model_max_async = 16
+    best_model_func = gpt_4o_mini_complete,
+    best_model_max_token_size = 32768,
+    best_model_max_async = 16,
-    cheap_model_func = gpt_4o_mini_complete
-    cheap_model_max_token_size = 32768
-    cheap_model_max_async = 16
+    cheap_model_func = gpt_4o_mini_complete,
+    cheap_model_max_token_size = 32768,
+    cheap_model_max_async = 16)
 ###### Azure OpenAI Configuration
 @retry(
@@ -223,18 +225,18 @@ async def azure_openai_embedding(texts: list[str]) -> np.ndarray:
 azure_openai_config = LLMConfig(
-    embedding_func = field(default_factory=lambda: azure_openai_embedding),
-    embedding_batch_num = 32
-    embedding_func_max_async = 16
-    query_better_than_threshold = 0.2
+    embedding_func = azure_openai_embedding,
+    embedding_batch_num = 32,
+    embedding_func_max_async = 16,
+    query_better_than_threshold = 0.2,
-    best_model_func: callable = azure_gpt_4o_complete
-    best_model_max_token_size = 32768
-    best_model_max_async = 16
+    best_model_func = azure_gpt_4o_complete,
+    best_model_max_token_size = 32768,
+    best_model_max_async = 16,
-    cheap_model_func: callable = azure_gpt_4o_mini_complete
-    cheap_model_max_token_size = 32768
-    cheap_model_max_async = 16
+    cheap_model_func = azure_gpt_4o_mini_complete,
+    cheap_model_max_token_size = 32768,
+    cheap_model_max_async = 16)
 ###### Ollama configuration
@@ -317,12 +319,12 @@ async def ollama_embedding(texts: list[str]) -> np.ndarray:
     return np.array(embeddings)
 ollama_config = LLMConfig(
-    embedding_func= EmbeddingFunc = field(default_factory=lambda: ollama_embedding),
+    embedding_func = ollama_embedding,
     embedding_batch_num = 1,
     embedding_func_max_async = 1,
     query_better_than_threshold = 0.2,
     best_model_func = ollama_complete,
-    best_model_max_token_size: int = 32768,
+    best_model_max_token_size = 32768,
     best_model_max_async = 1,
     cheap_model_func = ollama_mini_complete,
     cheap_model_max_token_size = 32768,
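
The parent refactor had pasted dataclass-style field definitions into the LLMConfig(...) calls: missing commas and stray annotations like best_model_func: callable = ... are syntax errors inside a call, and field(default_factory=...) belongs in a class body, not a constructor. The fix rewrites each config as plain keyword arguments. LLMConfig itself is not shown in this diff; a plausible sketch inferred from the keyword arguments above, not the repo's actual definition:

    from dataclasses import dataclass
    from typing import Callable, Optional

    @dataclass
    class LLMConfig:
        # Field names and defaults inferred from the LLMConfig(...) calls above;
        # the real definition in the repo may differ.
        embedding_func: Optional[Callable] = None
        embedding_batch_num: int = 32
        embedding_func_max_async: int = 16
        query_better_than_threshold: float = 0.2
        best_model_func: Optional[Callable] = None
        best_model_max_token_size: int = 32768
        best_model_max_async: int = 16
        cheap_model_func: Optional[Callable] = None
        cheap_model_max_token_size: int = 32768
        cheap_model_max_async: int = 16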

View File

@@ -552,7 +552,7 @@ async def _refine_entity_retrieval_query(
     query_param: QueryParam,
     global_config: dict,
 ):
-    use_llm_func: callable = global_config["cheap_model_func"]
+    use_llm_func: callable = global_config["llm"]["cheap_model_func"]
     query_rewrite_prompt = PROMPTS["query_rewrite_for_entity_retrieval"]
     query_rewrite_prompt = query_rewrite_prompt.format(input_text=query)
     final_result = await use_llm_func(query_rewrite_prompt)
@@ -563,7 +563,7 @@ async def _refine_visual_retrieval_query(
     query_param: QueryParam,
     global_config: dict,
 ):
-    use_llm_func: callable = global_config["cheap_model_func"]
+    use_llm_func: callable = global_config["llm"]["cheap_model_func"]
     query_rewrite_prompt = PROMPTS["query_rewrite_for_visual_retrieval"]
     query_rewrite_prompt = query_rewrite_prompt.format(input_text=query)
     final_result = await use_llm_func(query_rewrite_prompt)
@@ -574,7 +574,7 @@ async def _extract_keywords_query(
     query_param: QueryParam,
     global_config: dict,
 ):
-    use_llm_func: callable = global_config["cheap_model_func"]
+    use_llm_func: callable = global_config["llm"]["cheap_model_func"]
     keywords_prompt = PROMPTS["keywords_extraction"]
     keywords_prompt = keywords_prompt.format(input_text=query)
     final_result = await use_llm_func(keywords_prompt)
@@ -594,7 +594,7 @@ async def videorag_query(
     query_param: QueryParam,
     global_config: dict,
 ) -> str:
-    use_model_func = global_config["best_model_func"]
+    use_model_func = global_config["llm"]["best_model_func"]
     query = query
     # naive chunks
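
All four call sites change the same way: the model functions are no longer top-level keys of global_config but live one level down under "llm". That layout is what you would get if VideoRAG is a dataclass whose llm field holds an LLMConfig and the config dict is produced with dataclasses.asdict (imported in _llm.py above); that construction is an assumption, but the access path matches the diff. A tiny runnable sketch of the before/after shapes:

    async def dummy_complete(prompt: str) -> str:
        # Stand-in for a real model function such as ollama_complete.
        return "ok"

    # Old layout: model functions sat at the top level of global_config.
    old_config = {"cheap_model_func": dummy_complete, "best_model_func": dummy_complete}

    # New layout: everything LLM-related is nested under the "llm" key,
    # so callers index one level deeper.
    new_config = {"llm": {"cheap_model_func": dummy_complete, "best_model_func": dummy_complete}}

    use_llm_func = new_config["llm"]["cheap_model_func"]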

View File

@@ -21,7 +21,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
         self._client_file_name = os.path.join(
             self.global_config["working_dir"], f"vdb_{self.namespace}.json"
         )
-        self._max_batch_size = self.global_config["embedding_batch_num"]
+        self._max_batch_size = self.global_config["llm"]["embedding_batch_num"]
         self._client = NanoVectorDB(
             self.embedding_func.embedding_dim, storage_file=self._client_file_name
         )
@@ -142,4 +142,4 @@ class NanoVectorDBVideoSegmentStorage(BaseVectorStorage):
         return results
     async def index_done_callback(self):
-        self._client.save()
\ No newline at end of file
+        self._client.save()
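
embedding_batch_num now travels inside the "llm" sub-dict and caps the vector storage's batch size when embedding. A hedged sketch of the batching pattern such a storage class typically applies; the helper is illustrative, not the repo's code:

    import asyncio
    import numpy as np

    async def embed_in_batches(texts, embedding_func, max_batch_size):
        # Split texts into chunks of max_batch_size (i.e. the value read from
        # global_config["llm"]["embedding_batch_num"]) and embed the batches
        # concurrently, then stack the per-batch arrays.
        batches = [texts[i:i + max_batch_size] for i in range(0, len(texts), max_batch_size)]
        results = await asyncio.gather(*(embedding_func(b) for b in batches))
        return np.concatenate(results)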

View File

@@ -13,7 +13,7 @@ import tiktoken
 from ._llm import (
-    LLMConfig
+    LLMConfig,
     openai_config,
     azure_openai_config,
     ollama_config
@@ -96,10 +96,8 @@ class VideoRAG:
     entity_extract_max_gleaning: int = 1
     entity_summary_to_max_tokens: int = 500
-    # Uncomment as appropriate depending on whether you use openai, azure_openai or ollama
-    llm: LLMConfig = ollama_config
+    # Change to your LLM provider
+    llm: LLMConfig = field(default_factory=openai_config)
     # entity extraction
     entity_extraction_func: callable = extract_entities
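
One detail worth checking in the new default: dataclasses.field(default_factory=...) expects a zero-argument callable, and the assignments in _llm.py above make openai_config an LLMConfig instance rather than a factory. If so, default_factory=openai_config would raise TypeError ("'LLMConfig' object is not callable") the first time VideoRAG() is instantiated, and the default would need a wrapper. A minimal self-contained sketch of the distinction; the classes here are stand-ins, not the repo's definitions:

    from dataclasses import dataclass, field

    @dataclass
    class LLMConfig:  # stand-in for the real class
        name: str = "openai"

    openai_config = LLMConfig()  # an instance, as in _llm.py

    @dataclass
    class VideoRAG:  # stand-in; only the llm field is shown
        # default_factory must be callable with no arguments, so the shared
        # instance is wrapped in a lambda (a deepcopy here would avoid sharing
        # one config object across VideoRAG instances).
        llm: LLMConfig = field(default_factory=lambda: openai_config)

    print(VideoRAG().llm.name)  # -> "openai"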