Mirror of https://github.com/HKUDS/VideoRAG.git, synced 2025-05-11 03:54:36 +03:00.
Fixed issues due to refactoring of the configuration.
Q&A works; building still needs to be tested.
@@ -21,7 +21,7 @@ sub_category = args.collection
 os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda
 os.environ["OPENAI_API_KEY"] = ""
 
-from videorag._llm import *
+from videorag._llm import openai_config, azure_openai_config, ollama_config
 from videorag.videorag import VideoRAG, QueryParam
 
 if __name__ == '__main__':
@@ -31,8 +31,7 @@ if __name__ == '__main__':
     video_base_path = f'longervideos/{sub_category}/videos/'
     video_files = sorted(os.listdir(video_base_path))
     video_paths = [os.path.join(video_base_path, f) for f in video_files]
-    #videorag = VideoRAG(cheap_model_func=gpt_4o_mini_complete, best_model_func=gpt_4o_mini_complete, working_dir=f"./videorag-workdir/{sub_category}")
-    videorag = VideoRAG(cheap_model_func=ollama_mini_complete, best_model_func=ollama_complete, working_dir=f"./videorag-workdir/{sub_category}")
+    videorag = VideoRAG(llm=ollama_config, working_dir=f"./videorag-workdir/{sub_category}")
    videorag.insert_video(video_path_list=video_paths)
 
    ## inference
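With this change the example script selects a provider by handing the constructor one of the imported config presets instead of individual model functions. A minimal usage sketch under that assumption (the working directory and video paths below are illustrative, not taken from the repo):

    from videorag._llm import openai_config, azure_openai_config, ollama_config
    from videorag.videorag import VideoRAG

    # Pick whichever preset matches your provider; the script above uses ollama_config.
    videorag = VideoRAG(llm=ollama_config, working_dir="./videorag-workdir/example")
    videorag.insert_video(video_path_list=["videos/clip1.mp4"])  # hypothetical paths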
@@ -2,6 +2,7 @@ import numpy as np
 
 from openai import AsyncOpenAI, AsyncAzureOpenAI, APIConnectionError, RateLimitError
+from ollama import AsyncClient
 from dataclasses import asdict, dataclass, field
 
 from tenacity import (
     retry,
@@ -13,6 +14,7 @@ import os
 
 from ._utils import compute_args_hash, wrap_embedding_func_with_attrs
 from .base import BaseKVStorage
+from ._utils import EmbeddingFunc
 
 global_openai_async_client = None
 global_azure_openai_async_client = None
@@ -130,19 +132,19 @@ async def openai_embedding(texts: list[str]) -> np.ndarray:
 
 
 openai_config = LLMConfig(
-    embedding_func = field(default_factory=lambda: openai_embedding)
-    embedding_batch_num = 32
-    embedding_func_max_async = 16
-    query_better_than_threshold = 0.2
+    embedding_func = openai_embedding,
+    embedding_batch_num = 32,
+    embedding_func_max_async = 16,
+    query_better_than_threshold = 0.2,
 
     # LLM
-    best_model_func = gpt_4o_mini_complete
-    best_model_max_token_size = 32768
-    best_model_max_async = 16
+    best_model_func = gpt_4o_mini_complete,
+    best_model_max_token_size = 32768,
+    best_model_max_async = 16,
 
-    cheap_model_func = gpt_4o_mini_complete
-    cheap_model_max_token_size = 32768
-    cheap_model_max_async = 16
+    cheap_model_func = gpt_4o_mini_complete,
+    cheap_model_max_token_size = 32768,
+    cheap_model_max_async = 16)
 
 ###### Azure OpenAI Configuration
 @retry(
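These presets are plain LLMConfig instances, so values are now passed as ordinary keyword arguments rather than field(...) declarations. A minimal sketch of the LLMConfig dataclass these keyword arguments imply (field names are taken from the diff; the types and defaults here are assumptions):

    from dataclasses import dataclass
    from typing import Callable

    @dataclass
    class LLMConfig:
        # embedding settings
        embedding_func: Callable = None
        embedding_batch_num: int = 32
        embedding_func_max_async: int = 16
        query_better_than_threshold: float = 0.2
        # LLM settings
        best_model_func: Callable = None
        best_model_max_token_size: int = 32768
        best_model_max_async: int = 16
        cheap_model_func: Callable = None
        cheap_model_max_token_size: int = 32768
        cheap_model_max_async: int = 16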
@@ -223,18 +225,18 @@ async def azure_openai_embedding(texts: list[str]) -> np.ndarray:
 
 
 azure_openai_config = LLMConfig(
-    embedding_func = field(default_factory=lambda: azure_openai_embedding),
-    embedding_batch_num = 32
-    embedding_func_max_async = 16
-    query_better_than_threshold = 0.2
+    embedding_func = azure_openai_embedding,
+    embedding_batch_num = 32,
+    embedding_func_max_async = 16,
+    query_better_than_threshold = 0.2,
 
-    best_model_func: callable = azure_gpt_4o_complete
-    best_model_max_token_size = 32768
-    best_model_max_async = 16
+    best_model_func = azure_gpt_4o_complete,
+    best_model_max_token_size = 32768,
+    best_model_max_async = 16,
 
-    cheap_model_func: callable = azure_gpt_4o_mini_complete
-    cheap_model_max_token_size = 32768
-    cheap_model_max_async = 16
+    cheap_model_func = azure_gpt_4o_mini_complete,
+    cheap_model_max_token_size = 32768,
+    cheap_model_max_async = 16)
 
 
 ###### Ollama configuration
@@ -317,12 +319,12 @@ async def ollama_embedding(texts: list[str]) -> np.ndarray:
     return np.array(embeddings)
 
 ollama_config = LLMConfig(
-    embedding_func= EmbeddingFunc = field(default_factory=lambda: ollama_embedding),
+    embedding_func = ollama_embedding,
     embedding_batch_num = 1,
     embedding_func_max_async = 1,
     query_better_than_threshold = 0.2,
     best_model_func = ollama_complete ,
-    best_model_max_token_size: int = 32768,
+    best_model_max_token_size = 32768,
     best_model_max_async = 1,
     cheap_model_func = ollama_mini_complete,
     cheap_model_max_token_size = 32768,
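Because the presets are dataclass instances, a variant can be derived without editing _llm.py, for example with dataclasses.replace. A hedged sketch (the larger batch size is purely illustrative):

    from dataclasses import replace
    from videorag._llm import ollama_config

    # Derive a tweaked copy of the preset, e.g. to batch more texts per embedding call.
    my_ollama_config = replace(ollama_config, embedding_batch_num=4)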
@@ -552,7 +552,7 @@ async def _refine_entity_retrieval_query(
     query_param: QueryParam,
     global_config: dict,
 ):
-    use_llm_func: callable = global_config["cheap_model_func"]
+    use_llm_func: callable = global_config["llm"]["cheap_model_func"]
     query_rewrite_prompt = PROMPTS["query_rewrite_for_entity_retrieval"]
     query_rewrite_prompt = query_rewrite_prompt.format(input_text=query)
     final_result = await use_llm_func(query_rewrite_prompt)
@@ -563,7 +563,7 @@ async def _refine_visual_retrieval_query(
     query_param: QueryParam,
     global_config: dict,
 ):
-    use_llm_func: callable = global_config["cheap_model_func"]
+    use_llm_func: callable = global_config["llm"]["cheap_model_func"]
     query_rewrite_prompt = PROMPTS["query_rewrite_for_visual_retrieval"]
     query_rewrite_prompt = query_rewrite_prompt.format(input_text=query)
     final_result = await use_llm_func(query_rewrite_prompt)
@@ -574,7 +574,7 @@ async def _extract_keywords_query(
     query_param: QueryParam,
     global_config: dict,
 ):
-    use_llm_func: callable = global_config["cheap_model_func"]
+    use_llm_func: callable = global_config["llm"]["cheap_model_func"]
     keywords_prompt = PROMPTS["keywords_extraction"]
     keywords_prompt = keywords_prompt.format(input_text=query)
     final_result = await use_llm_func(keywords_prompt)
@@ -594,7 +594,7 @@ async def videorag_query(
     query_param: QueryParam,
     global_config: dict,
 ) -> str:
-    use_model_func = global_config["best_model_func"]
+    use_model_func = global_config["llm"]["best_model_func"]
     query = query
 
     # naive chunks
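The extra ["llm"] level in these lookups follows from the configuration refactoring: the model functions now live on the LLMConfig attached to VideoRAG. Assuming global_config is produced by dataclasses.asdict on the VideoRAG dataclass (an assumption suggested by the nested lookup, not shown in this diff), the nested dataclass becomes a nested dict:

    from dataclasses import dataclass, field, asdict
    from typing import Callable

    @dataclass
    class LLMConfig:
        cheap_model_func: Callable = None
        best_model_func: Callable = None

    @dataclass
    class VideoRAG:
        llm: LLMConfig = field(default_factory=LLMConfig)

    global_config = asdict(VideoRAG())                       # nested dataclass -> nested dict
    use_llm_func = global_config["llm"]["cheap_model_func"]  # hence the extra "llm" key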
@@ -21,7 +21,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
         self._client_file_name = os.path.join(
             self.global_config["working_dir"], f"vdb_{self.namespace}.json"
         )
-        self._max_batch_size = self.global_config["embedding_batch_num"]
+        self._max_batch_size = self.global_config["llm"]["embedding_batch_num"]
         self._client = NanoVectorDB(
             self.embedding_func.embedding_dim, storage_file=self._client_file_name
         )
@@ -142,4 +142,4 @@ class NanoVectorDBVideoSegmentStorage(BaseVectorStorage):
         return results
 
     async def index_done_callback(self):
-        self._client.save()
+        self._client.save()
@@ -13,7 +13,7 @@ import tiktoken
 
 
 from ._llm import (
-    LLMConfig
+    LLMConfig,
     openai_config,
     azure_openai_config,
     ollama_config
@@ -96,10 +96,8 @@ class VideoRAG:
     entity_extract_max_gleaning: int = 1
     entity_summary_to_max_tokens: int = 500
 
-    # Uncomment as appropriate depending on whether you use openai, azure_openai or ollama
-
-    llm: LLMConfig = ollama_config
+    # Change to your LLM provider
+    llm: LLMConfig = field(default_factory=openai_config)
 
     # entity extraction
     entity_extraction_func: callable = extract_entities
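The commit message notes that Q&A works. A hypothetical sketch of the "## inference" step from the example script above (the query() call and the QueryParam fields are assumptions based on the imports in this diff, not changes made by this commit):

    from videorag._llm import ollama_config
    from videorag.videorag import VideoRAG, QueryParam

    videorag = VideoRAG(llm=ollama_config, working_dir="./videorag-workdir/example")

    # Hypothetical Q&A call; the exact QueryParam fields and query() signature are assumptions.
    param = QueryParam(mode="videorag")
    response = videorag.query(query="What is discussed in these videos?", param=param)
    print(response)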