mirror of https://github.com/HKUDS/VideoRAG.git
synced 2025-05-11 03:54:36 +03:00
Fix execution errors
@@ -3,38 +3,44 @@ import json
 import logging
 import warnings
 import multiprocessing
+import sys
 
 warnings.filterwarnings("ignore")
 logging.getLogger("httpx").setLevel(logging.WARNING)
 
+# Add the parent directory of 'videorag' to sys.path
+sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
+
 import argparse
 parser = argparse.ArgumentParser(description="Set sub-category and CUDA device.")
 parser.add_argument('--collection', type=str, default='4-rag-lecture')
 parser.add_argument('--cuda', type=str, default='0')
 args = parser.parse_args()
-sub_category = args.sub_category
+sub_category = args.collection
 
 os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda
 os.environ["OPENAI_API_KEY"] = ""
 
 from videorag._llm import *
-from videorag import VideoRAG, QueryParam
+from videorag.videorag import VideoRAG, QueryParam
 
 if __name__ == '__main__':
     multiprocessing.set_start_method('spawn')
 
     ## learn
-    video_base_path = f'./{sub_category}/videos/'
+    video_base_path = f'longervideos/{sub_category}/videos/'
     video_files = sorted(os.listdir(video_base_path))
     video_paths = [os.path.join(video_base_path, f) for f in video_files]
-    videorag = VideoRAG(cheap_model_func=gpt_4o_mini_complete, best_model_func=gpt_4o_mini_complete, working_dir=f"./videorag-workdir/{sub_category}")
+    #videorag = VideoRAG(cheap_model_func=gpt_4o_mini_complete, best_model_func=gpt_4o_mini_complete, working_dir=f"./videorag-workdir/{sub_category}")
+    videorag = VideoRAG(cheap_model_func=ollama_mini_complete, best_model_func=ollama_complete, working_dir=f"./videorag-workdir/{sub_category}")
     videorag.insert_video(video_path_list=video_paths)
 
     ## inference
-    with open(f'./dataset.json', 'r') as f:
+    with open(f'longervideos/dataset.json', 'r') as f:
         longervideos = json.load(f)
 
-    videorag = VideoRAG(cheap_model_func=gpt_4o_mini_complete, best_model_func=gpt_4o_mini_complete, working_dir=f"./videorag-workdir/{sub_category}")
+    #videorag = VideoRAG(cheap_model_func=gpt_4o_mini_complete, best_model_func=gpt_4o_mini_complete, working_dir=f"./videorag-workdir/{sub_category}")
+    videorag = VideoRAG(cheap_model_func=ollama_mini_complete, best_model_func=ollama_complete, working_dir=f"./videorag-workdir/{sub_category}")
     videorag.load_caption_model(debug=False)
 
     answer_folder = f'./videorag-answers/{sub_category}'
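The change from `args.sub_category` to `args.collection` fixes an `AttributeError`: the parser only registers `--collection`, so the namespace returned by `parse_args()` has no `sub_category` attribute. A standalone repro of the failure mode:

```python
import argparse

parser = argparse.ArgumentParser(description="Set sub-category and CUDA device.")
parser.add_argument('--collection', type=str, default='4-rag-lecture')
args = parser.parse_args([])

print(args.collection)     # '4-rag-lecture'
# print(args.sub_category) # AttributeError: 'Namespace' object has no attribute 'sub_category'
```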
@@ -51,5 +57,5 @@ if __name__ == '__main__':
 
         response = videorag.query(query=query, param=param)
         print(response)
-        with open(os.path.join(answer_folder, f'/answer_{query_id}.md'), 'w') as f:
-            f.write(response)
+        with open(os.path.join(answer_folder, f'answer_{query_id}.md'), 'w') as f:
+            f.write(response)
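This hunk fixes a subtle path bug: `os.path.join` discards all earlier components once it meets an absolute one, so the leading slash in `f'/answer_{query_id}.md'` sent answers to the filesystem root instead of `answer_folder`. A quick demonstration of the behavior:

```python
import os

# An absolute second component wins and answer_folder is dropped:
print(os.path.join('videorag-answers/demo', '/answer_1.md'))  # /answer_1.md
# Without the leading slash the file lands inside the folder:
print(os.path.join('videorag-answers/demo', 'answer_1.md'))   # videorag-answers/demo/answer_1.md
```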
@@ -1,6 +1,7 @@
 import numpy as np
 
 from openai import AsyncOpenAI, AsyncAzureOpenAI, APIConnectionError, RateLimitError
+from ollama import AsyncClient
 
 from tenacity import (
     retry,
@@ -33,8 +34,8 @@ def get_azure_openai_async_client_instance():
 def get_ollama_async_client_instance():
     global global_ollama_client
     if global_ollama_client is None:
-        #global_ollama_client = Client(base_url="http://localhost:11434") # Adjust base URL if necessary
-        global_ollama_client = Client(base_url="http://10.0.1.12:11434") # Adjust base URL if necessary
+        #global_ollama_client = AsyncClient(host="http://localhost:11434") # Adjust base URL if necessary
+        global_ollama_client = AsyncClient(host="http://10.0.1.12:11434") # Adjust base URL if necessary
     return global_ollama_client
 
 @retry(
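Replacing `Client(base_url=...)` with `AsyncClient(host=...)` fixes two errors at once: the `ollama` package's client constructors take a `host` keyword rather than `base_url`, and the call sites `await` the client's methods, which only `AsyncClient` provides. A minimal sketch of the corrected usage (the host URL and model name are placeholders):

```python
import asyncio
from ollama import AsyncClient

async def main() -> None:
    # Only AsyncClient's methods are awaitable; awaiting the synchronous
    # Client's return values would raise a TypeError.
    client = AsyncClient(host="http://localhost:11434")  # placeholder host
    response = await client.chat(
        model="llama3",  # assumes this model has been pulled locally
        messages=[{"role": "user", "content": "Say hi"}],
    )
    print(response["message"]["content"])

asyncio.run(main())
```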
@@ -238,18 +239,23 @@ async def ollama_mini_complete(prompt, system_prompt=None, history_messages=[],
         **kwargs,
     )
 
+@wrap_embedding_func_with_attrs(embedding_dim=768, max_token_size=8192)
+@retry(
+    stop=stop_after_attempt(5),
+    wait=wait_exponential(multiplier=1, min=4, max=10),
+    retry=retry_if_exception_type((RateLimitError, APIConnectionError)),
+)
 async def ollama_embedding(texts: list[str]) -> np.ndarray:
     # Initialize the Ollama client
     ollama_client = get_ollama_async_client_instance()
 
     # Send the request to Ollama for embeddings
-    response = await ollama_client.embeddings(
+    response = await ollama_client.embed(
         model="nomic-embed-text", # Replace with the appropriate Ollama embedding model
-        input=texts,
-        encoding_format="float"
+        input=texts
     )
 
     # Extract embeddings from the response
-    embeddings = [dp.embedding for dp in response.data]
+    embeddings = response['embeddings']
 
     return np.array(embeddings)
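The rename from `embeddings` to `embed` tracks the `ollama` client API: the older `embeddings()` endpoint embeds a single string passed as `prompt=` and returns an `'embedding'` key, while `embed()` accepts a list of strings via `input=` and returns one vector per input under `'embeddings'`; the OpenAI-style `encoding_format` argument has no Ollama counterpart and is dropped. A hedged sketch of the new call shape (host URL is a placeholder):

```python
import asyncio
import numpy as np
from ollama import AsyncClient

async def embed_texts(texts: list[str]) -> np.ndarray:
    client = AsyncClient(host="http://localhost:11434")  # placeholder host
    # One request embeds the whole batch; response['embeddings'] is a
    # list of float vectors, one per input string.
    response = await client.embed(model="nomic-embed-text", input=texts)
    return np.array(response["embeddings"])

vectors = asyncio.run(embed_texts(["hello", "world"]))
print(vectors.shape)  # (2, 768): nomic-embed-text produces 768-dim vectors
```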
@@ -20,6 +20,7 @@ from ._llm import (
     azure_openai_embedding,
     azure_gpt_4o_mini_complete,
     ollama_complete,
     ollama_mini_complete,
+    ollama_embedding
 )
 from ._op import (
@@ -121,7 +122,7 @@ class VideoRAG:
     cheap_model_max_async: int = 16
     if llm_provider == "azur_openai":
         # text embedding
-        embedding_func = : EmbeddingFunc = field(default_factory=lambda: azure_openai_embedding)
+        embedding_func: EmbeddingFunc = field(default_factory=lambda: azure_openai_embedding)
         embedding_batch_num: int = 32
         embedding_func_max_async: int = 16
         query_better_than_threshold: float = 0.2
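The removed line, `embedding_func = : EmbeddingFunc = field(...)`, was a plain `SyntaxError`; the fix restores the normal dataclass field declaration, where a callable default is supplied through `field(default_factory=...)`. A standalone illustration of the pattern (the names are hypothetical, not VideoRAG's):

```python
from dataclasses import dataclass, field
from typing import Callable

def my_embedding(texts):
    ...  # hypothetical stand-in for an embedding function

@dataclass
class Config:
    # default_factory is invoked at instance creation; the lambda returns
    # the callable itself, so every Config shares the same default function.
    embedding_func: Callable = field(default_factory=lambda: my_embedding)

print(Config().embedding_func is my_embedding)  # True
```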
@@ -138,7 +139,7 @@ class VideoRAG:
     if llm_provider == "ollama":
         # text embedding
         embedding_func: EmbeddingFunc = field(default_factory=lambda: ollama_embedding)
-        embedding_batch_num: int = 32
+        embedding_batch_num: int = 1
         embedding_func_max_async: int = 1
         query_better_than_threshold: float = 0.2
 
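Lowering `embedding_batch_num` from 32 to 1 presumably works around the local Ollama server struggling with larger batches: if the storage layer slices texts into chunks of `embedding_batch_num` and issues at most `embedding_func_max_async` concurrent calls, a value of 1 sends one text per request, trading throughput for reliability. A sketch of how such batching is typically wired (this helper is an assumption for illustration, not VideoRAG's actual code):

```python
import asyncio
import numpy as np

async def embed_in_batches(texts, embedding_func, batch_num=1, max_async=1):
    # Hypothetical helper: slice texts into batches of batch_num, then bound
    # concurrency so at most max_async requests are in flight at once.
    sem = asyncio.Semaphore(max_async)
    batches = [texts[i:i + batch_num] for i in range(0, len(texts), batch_num)]

    async def run(batch):
        async with sem:
            return await embedding_func(batch)  # each call returns an (n, d) array

    results = await asyncio.gather(*(run(b) for b in batches))
    return np.concatenate(results)
```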