Mirror of https://github.com/HKUDS/VideoRAG.git (synced 2025-05-11 03:54:36 +03:00)
.gitignore (vendored, new file, 4 lines)
@@ -0,0 +1,4 @@
# Byte-compiled / optimized / DLL files
videorag/__pycache__/
*.py[cod]
*$py.class
README.md (46 lines changed)
@@ -114,7 +114,7 @@ logging.getLogger("httpx").setLevel(logging.WARNING)
# Please enter your openai key
os.environ["OPENAI_API_KEY"] = ""

from videorag._llm import *
from videorag._llm import openai_4o_mini_config
from videorag import VideoRAG, QueryParam

@@ -127,7 +127,7 @@ if __name__ == '__main__':
'movies/Iron-Man.mp4',
'movies/Spider-Man.mkv',
]
videorag = VideoRAG(cheap_model_func=gpt_4o_mini_complete, best_model_func=gpt_4o_mini_complete, working_dir=f"./videorag-workdir")
videorag = VideoRAG(llm=openai_4o_mini_config, working_dir=f"./videorag-workdir")
videorag.insert_video(video_path_list=video_paths)
```

@@ -156,7 +156,7 @@ if __name__ == '__main__':
# if param.wo_reference = False, VideoRAG will add reference to video clips in the response
param.wo_reference = True

videorag = VideoRAG(cheap_model_func=gpt_4o_mini_complete, best_model_func=gpt_4o_mini_complete, working_dir=f"./videorag-workdir")
videorag = VideoRAG(llm=openai_4o_mini_config, working_dir=f"./videorag-workdir")
videorag.load_caption_model(debug=False)
response = videorag.query(query=query, param=param)
print(response)
@@ -187,7 +187,7 @@ sh download.sh # downloading videos
Then, you can run the following example command to process and answer queries for LongerVideos with VideoRAG:

```shell
# Please enter your openai_key in line 18 at first
# Please enter your openai_key in line 22 at first
python videorag_experiment.py --collection 4-rag-lecture --cuda 0
```

@@ -244,6 +244,44 @@ python batch_winrate_quant_download.py
python batch_winrate_quant_calculate.py
```

### Ollama Support

This project also supports Ollama. To use it, edit the `ollama_config` in [_llm.py](VideoRAG/videorag/_llm.py) and adjust the parameters of the models being used:

```python
ollama_config = LLMConfig(
    embedding_func_raw = ollama_embedding,
    embedding_model_name = "nomic-embed-text",
    embedding_dim = 768,
    embedding_max_token_size = 8192,
    embedding_batch_num = 1,
    embedding_func_max_async = 1,
    query_better_than_threshold = 0.2,

    best_model_func_raw = ollama_complete,
    best_model_name = "gemma2:latest", # need to be a solid instruct model
    best_model_max_token_size = 32768,
    best_model_max_async = 1,

    cheap_model_func_raw = ollama_mini_complete,
    cheap_model_name = "olmo2",
    cheap_model_max_token_size = 32768,
    cheap_model_max_async = 1
)
```
Then specify this config when creating your VideoRAG instance.
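For example, a minimal sketch (the video path and working directory are placeholders; `ollama_config` is the configuration shown above from `_llm.py`):

```python
from videorag._llm import ollama_config
from videorag import VideoRAG

# Create a VideoRAG instance backed by the Ollama config, then index a video
videorag = VideoRAG(llm=ollama_config, working_dir="./videorag-workdir")
videorag.insert_video(video_path_list=["movies/Iron-Man.mp4"])
```
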
### Jupyter Notebook
To test the solution on a single video, load the notebook in the [notebook folder](VideoRAG/nodebooks) and update the parameters to fit your situation.

A YouTube video, for example, can be downloaded as follows:

```shell
yt-dlp -o "%(id)s.%(ext)s" -S "res:720" https://www.youtube.com/live/DPa2iRgzadM?si=8cf8WbYtqiglrwtN -P .
```

## Citation
If you find this work helpful to your research, please consider citing our paper:

```bibtex
@@ -3,41 +3,66 @@ import json
import logging
import warnings
import multiprocessing
import sys

warnings.filterwarnings("ignore")
logging.getLogger("httpx").setLevel(logging.WARNING)

# Add the parent directory of 'videorag' to sys.path
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

import argparse
parser = argparse.ArgumentParser(description="Set sub-category and CUDA device.")
parser.add_argument('--collection', type=str, default='4-rag-lecture')
parser.add_argument('--cuda', type=str, default='0')
args = parser.parse_args()
sub_category = args.sub_category
sub_category = args.collection

os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda
os.environ["OPENAI_API_KEY"] = ""

from videorag._llm import *
from videorag import VideoRAG, QueryParam
from videorag.videorag import VideoRAG, QueryParam

longervideos_llm_config = LLMConfig(
embedding_func_raw = openai_embedding,
embedding_model_name = "text-embedding-3-small",
embedding_dim = 1536,
embedding_max_token_size = 8192,
embedding_batch_num = 32,
embedding_func_max_async = 16,
query_better_than_threshold = 0.2,

# LLM (we utilize gpt-4o-mini for all experiments)
best_model_func_raw = gpt_4o_mini_complete,
best_model_name = "gpt-4o-mini",
best_model_max_token_size = 32768,
best_model_max_async = 16,

cheap_model_func_raw = gpt_4o_mini_complete,
cheap_model_name = "gpt-4o-mini",
cheap_model_max_token_size = 32768,
cheap_model_max_async = 16
)

if __name__ == '__main__':
multiprocessing.set_start_method('spawn')

## learn
video_base_path = f'./{sub_category}/videos/'
video_base_path = f'./longervideos/{sub_category}/videos/'
video_files = sorted(os.listdir(video_base_path))
video_paths = [os.path.join(video_base_path, f) for f in video_files]
videorag = VideoRAG(cheap_model_func=gpt_4o_mini_complete, best_model_func=gpt_4o_mini_complete, working_dir=f"./videorag-workdir/{sub_category}")
videorag = VideoRAG(llm=longervideos_llm_config, working_dir=f"./longervideos/videorag-workdir/{sub_category}")
videorag.insert_video(video_path_list=video_paths)

## inference
with open(f'./dataset.json', 'r') as f:
with open(f'./longervideos/dataset.json', 'r') as f:
longervideos = json.load(f)

videorag = VideoRAG(cheap_model_func=gpt_4o_mini_complete, best_model_func=gpt_4o_mini_complete, working_dir=f"./videorag-workdir/{sub_category}")
videorag = VideoRAG(llm=longervideos_llm_config, working_dir=f"./longervideos/videorag-workdir/{sub_category}")
videorag.load_caption_model(debug=False)

answer_folder = f'./videorag-answers/{sub_category}'
answer_folder = f'./longervideos/videorag-answers/{sub_category}'
os.makedirs(answer_folder, exist_ok=True)

collection_id = sub_category.split('-')[0]

@@ -51,5 +76,5 @@ if __name__ == '__main__':

response = videorag.query(query=query, param=param)
print(response)
with open(os.path.join(answer_folder, f'/answer_{query_id}.md'), 'w') as f:
f.write(response)
with open(os.path.join(answer_folder, f'answer_{query_id}.md'), 'w') as f:
f.write(response)

notesbooks/videorag.ipynb (new file, 124 lines)
@@ -0,0 +1,124 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "c1e12bf6-a8ae-4fec-9fbd-99ea74bcc563",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import logging\n",
"import warnings\n",
"import multiprocessing\n",
"import nest_asyncio\n",
" \n",
"nest_asyncio.apply()\n",
"\n",
"warnings.filterwarnings(\"ignore\")\n",
"logging.getLogger(\"httpx\").setLevel(logging.WARNING)\n",
"os.environ[\"CUDA_VISIBLE_DEVICES\"] = '0'\n",
"\n",
"from videorag._llm import openai_config, openai_4o_mini_config, azure_openai_config, ollama_config\n",
"from videorag import VideoRAG, QueryParam\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f04cee24-fd8e-41dc-93ea-888229d2a9af",
"metadata": {},
"outputs": [],
"source": [
"video_paths = [\n",
" '/mnt/data3/AI/software/VideoRAG/Lexington/GMT20241112-164602_Recording_gallery_1280x720.mp4',\n",
"]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1a8b5f05-12e8-4b53-b84c-e4555fc99022",
"metadata": {},
"outputs": [],
"source": [
"multiprocessing.set_start_method('spawn')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "62f20b71-3ae4-4db9-9023-cdba818ef5e8",
"metadata": {},
"outputs": [],
"source": [
"videorag = VideoRAG(llm=ollama_config, working_dir=f\"./videorag-workdir/lexington\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7539e57b-11c2-42a9-b8f1-2363e0f561df",
"metadata": {},
"outputs": [],
"source": [
"# To build\n",
"videorag.insert_video(video_path_list=video_paths)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6a938da7-5c33-44a3-a66e-a2e07636ce70",
"metadata": {},
"outputs": [],
"source": [
"# To query\n",
"videorag.load_caption_model(debug=False)\n",
"param = QueryParam(mode=\"videorag\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "15d51fb9-06e9-44e0-83ef-f3269c50480f",
"metadata": {},
"outputs": [],
"source": [
"query = \"What are the Lexington school construction options\"\n",
"param.wo_reference = False\n",
"response = videorag.query(query=query, param=param)\n",
"print(response)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4d517d22-a482-47a1-a4ac-89d0042b016f",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python [conda env:videorag]",
"language": "python",
"name": "conda-env-videorag-py"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.11"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
videorag/_llm.py (243 lines changed)
@@ -1,6 +1,8 @@
import numpy as np

from openai import AsyncOpenAI, AsyncAzureOpenAI, APIConnectionError, RateLimitError
from ollama import AsyncClient
from dataclasses import asdict, dataclass, field

from tenacity import (
retry,
@@ -12,10 +14,11 @@ import os
from ._utils import compute_args_hash, wrap_embedding_func_with_attrs
from .base import BaseKVStorage
from ._utils import EmbeddingFunc

global_openai_async_client = None
global_azure_openai_async_client = None

global_ollama_client = None

def get_openai_async_client_instance():
global global_openai_async_client
@@ -30,12 +33,62 @@ def get_azure_openai_async_client_instance():
global_azure_openai_async_client = AsyncAzureOpenAI()
return global_azure_openai_async_client

def get_ollama_async_client_instance():
global global_ollama_client
if global_ollama_client is None:
# set OLLAMA_HOST or pass in host="http://127.0.0.1:11434"
global_ollama_client = AsyncClient() # Adjust base URL if necessary
return global_ollama_client

@retry(
stop=stop_after_attempt(5),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type((RateLimitError, APIConnectionError)),
)

# Setup LLM Configuration.
@dataclass
class LLMConfig:
# To be set
embedding_func_raw: callable
embedding_model_name: str
embedding_dim: int
embedding_max_token_size: int
embedding_batch_num: int
embedding_func_max_async: int
query_better_than_threshold: float

best_model_func_raw: callable
best_model_name: str
best_model_max_token_size: int
best_model_max_async: int

cheap_model_func_raw: callable
cheap_model_name: str
cheap_model_max_token_size: int
cheap_model_max_async: int

# Assigned in post init
embedding_func: EmbeddingFunc = None
best_model_func: callable = None
cheap_model_func: callable = None


def __post_init__(self):
embedding_wrapper = wrap_embedding_func_with_attrs(
embedding_dim = self.embedding_dim,
max_token_size = self.embedding_max_token_size,
model_name = self.embedding_model_name)
self.embedding_func = embedding_wrapper(self.embedding_func_raw)
self.best_model_func = lambda prompt, *args, **kwargs: self.best_model_func_raw(
self.best_model_name, prompt, *args, **kwargs
)

self.cheap_model_func = lambda prompt, *args, **kwargs: self.cheap_model_func_raw(
self.cheap_model_name, prompt, *args, **kwargs
)

##### OpenAI Configuration
async def openai_complete_if_cache(
model, prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
@@ -65,43 +118,82 @@ async def openai_complete_if_cache(
async def gpt_4o_complete(
prompt, system_prompt=None, history_messages=[], **kwargs
model_name, prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
return await openai_complete_if_cache(
"gpt-4o",
model_name,
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
**kwargs,
)


async def gpt_4o_mini_complete(
prompt, system_prompt=None, history_messages=[], **kwargs
model_name, prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
return await openai_complete_if_cache(
"gpt-4o-mini",
model_name,
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
**kwargs,
)


@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
@retry(
stop=stop_after_attempt(5),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type((RateLimitError, APIConnectionError)),
)
async def openai_embedding(texts: list[str]) -> np.ndarray:
async def openai_embedding(model_name: str, texts: list[str]) -> np.ndarray:
openai_async_client = get_openai_async_client_instance()
response = await openai_async_client.embeddings.create(
model="text-embedding-3-small", input=texts, encoding_format="float"
model=model_name, input=texts, encoding_format="float"
)
return np.array([dp.embedding for dp in response.data])

openai_config = LLMConfig(
embedding_func_raw = openai_embedding,
embedding_model_name = "text-embedding-3-small",
embedding_dim = 1536,
embedding_max_token_size = 8192,
embedding_batch_num = 32,
embedding_func_max_async = 16,
query_better_than_threshold = 0.2,

# LLM
best_model_func_raw = gpt_4o_complete,
best_model_name = "gpt-4o",
best_model_max_token_size = 32768,
best_model_max_async = 16,

cheap_model_func_raw = gpt_4o_mini_complete,
cheap_model_name = "gpt-4o-mini",
cheap_model_max_token_size = 32768,
cheap_model_max_async = 16
)

openai_4o_mini_config = LLMConfig(
embedding_func_raw = openai_embedding,
embedding_model_name = "text-embedding-3-small",
embedding_dim = 1536,
embedding_max_token_size = 8192,
embedding_batch_num = 32,
embedding_func_max_async = 16,
query_better_than_threshold = 0.2,

# LLM
best_model_func_raw = gpt_4o_mini_complete,
best_model_name = "gpt-4o-mini",
best_model_max_token_size = 32768,
best_model_max_async = 16,

cheap_model_func_raw = gpt_4o_mini_complete,
cheap_model_name = "gpt-4o-mini",
cheap_model_max_token_size = 32768,
cheap_model_max_async = 16
)

###### Azure OpenAI Configuration
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
@@ -141,10 +233,10 @@ async def azure_openai_complete_if_cache(
async def azure_gpt_4o_complete(
prompt, system_prompt=None, history_messages=[], **kwargs
model_name, prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
return await azure_openai_complete_if_cache(
"gpt-4o",
model_name,
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
@@ -153,10 +245,10 @@ async def azure_gpt_4o_complete(
async def azure_gpt_4o_mini_complete(
prompt, system_prompt=None, history_messages=[], **kwargs
model_name, prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
return await azure_openai_complete_if_cache(
"gpt-4o-mini",
model_name,
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
@@ -164,15 +256,132 @@ async def azure_gpt_4o_mini_complete(
)


@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)

@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type((RateLimitError, APIConnectionError)),
)
async def azure_openai_embedding(texts: list[str]) -> np.ndarray:
async def azure_openai_embedding(model_name: str, texts: list[str]) -> np.ndarray:
azure_openai_client = get_azure_openai_async_client_instance()
response = await azure_openai_client.embeddings.create(
model="text-embedding-3-small", input=texts, encoding_format="float"
model=model_name, input=texts, encoding_format="float"
)
return np.array([dp.embedding for dp in response.data])


azure_openai_config = LLMConfig(
embedding_func_raw = azure_openai_embedding,
embedding_model_name = "text-embedding-3-small",
embedding_dim = 1536,
embedding_max_token_size = 8192,
embedding_batch_num = 32,
embedding_func_max_async = 16,
query_better_than_threshold = 0.2,

best_model_func_raw = azure_gpt_4o_complete,
best_model_name = "gpt-4o",
best_model_max_token_size = 32768,
best_model_max_async = 16,

cheap_model_func_raw = azure_gpt_4o_mini_complete,
cheap_model_name = "gpt-4o-mini",
cheap_model_max_token_size = 32768,
cheap_model_max_async = 16
)


###### Ollama configuration

async def ollama_complete_if_cache(
model, prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
# Initialize the Ollama client
ollama_client = get_ollama_async_client_instance()

hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
messages = []

if system_prompt:
messages.append({"role": "system", "content": system_prompt})
messages.extend(history_messages)
messages.append({"role": "user", "content": prompt})

if hashing_kv is not None:
args_hash = compute_args_hash(model, messages)
if_cache_return = await hashing_kv.get_by_id(args_hash)
if if_cache_return is not None:
return if_cache_return["return"]

# Send the request to Ollama
response = await ollama_client.chat(
model=model,
messages=messages
)
# print(messages)
# print(response['message']['content'])

if hashing_kv is not None:
await hashing_kv.upsert(
{args_hash: {"return": response['message']['content'], "model": model}}
)
await hashing_kv.index_done_callback()

return response['message']['content']


async def ollama_complete(model_name, prompt, system_prompt=None, history_messages=[], **kwargs) -> str:
return await ollama_complete_if_cache(
model_name,
prompt,
system_prompt=system_prompt,
history_messages=history_messages
)

async def ollama_mini_complete(model_name, prompt, system_prompt=None, history_messages=[], **kwargs) -> str:
return await ollama_complete_if_cache(
# "deepseek-r1:latest", # For now select your model
model_name,
prompt,
system_prompt=system_prompt,
history_messages=history_messages
)

@retry(
stop=stop_after_attempt(5),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type((RateLimitError, APIConnectionError)),
)
async def ollama_embedding(model_name: str, texts: list[str]) -> np.ndarray:
# Initialize the Ollama client
ollama_client = get_ollama_async_client_instance()

# Send the request to Ollama for embeddings
response = await ollama_client.embed(
model=model_name,
input=texts
)

# Extract embeddings from the response
embeddings = response['embeddings']

return np.array(embeddings)

ollama_config = LLMConfig(
embedding_func_raw = ollama_embedding,
embedding_model_name = "nomic-embed-text",
embedding_dim = 768,
embedding_max_token_size = 8192,
embedding_batch_num = 1,
embedding_func_max_async = 1,
query_better_than_threshold = 0.2,
best_model_func_raw = ollama_complete,
best_model_name = "gemma2:latest", # need to be a solid instruct model
best_model_max_token_size = 32768,
best_model_max_async = 1,
cheap_model_func_raw = ollama_mini_complete,
cheap_model_name = "olmo2",
cheap_model_max_token_size = 32768,
cheap_model_max_async = 1
)
@@ -183,8 +183,8 @@ async def _handle_entity_relation_summary(
description: str,
global_config: dict,
) -> str:
use_llm_func: callable = global_config["cheap_model_func"]
llm_max_tokens = global_config["cheap_model_max_token_size"]
use_llm_func: callable = global_config["llm"]["cheap_model_func"]
llm_max_tokens = global_config["llm"]["cheap_model_max_token_size"]
tiktoken_model_name = global_config["tiktoken_model_name"]
summary_max_tokens = global_config["entity_summary_to_max_tokens"]
@@ -359,7 +359,7 @@ async def extract_entities(
entity_vdb: BaseVectorStorage,
global_config: dict,
) -> Union[BaseGraphStorage, None]:
use_llm_func: callable = global_config["best_model_func"]
use_llm_func: callable = global_config["llm"]["best_model_func"]
entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"]

ordered_chunks = list(chunks.items())
@@ -549,7 +549,7 @@ async def _refine_entity_retrieval_query(
query_param: QueryParam,
global_config: dict,
):
use_llm_func: callable = global_config["cheap_model_func"]
use_llm_func: callable = global_config["llm"]["cheap_model_func"]
query_rewrite_prompt = PROMPTS["query_rewrite_for_entity_retrieval"]
query_rewrite_prompt = query_rewrite_prompt.format(input_text=query)
final_result = await use_llm_func(query_rewrite_prompt)
@@ -560,7 +560,7 @@ async def _refine_visual_retrieval_query(
query_param: QueryParam,
global_config: dict,
):
use_llm_func: callable = global_config["cheap_model_func"]
use_llm_func: callable = global_config["llm"]["cheap_model_func"]
query_rewrite_prompt = PROMPTS["query_rewrite_for_visual_retrieval"]
query_rewrite_prompt = query_rewrite_prompt.format(input_text=query)
final_result = await use_llm_func(query_rewrite_prompt)
@@ -571,7 +571,7 @@ async def _extract_keywords_query(
query_param: QueryParam,
global_config: dict,
):
use_llm_func: callable = global_config["cheap_model_func"]
use_llm_func: callable = global_config["llm"]["cheap_model_func"]
keywords_prompt = PROMPTS["keywords_extraction"]
keywords_prompt = keywords_prompt.format(input_text=query)
final_result = await use_llm_func(keywords_prompt)
@@ -591,7 +591,7 @@ async def videorag_query(
query_param: QueryParam,
global_config: dict,
) -> str:
use_model_func = global_config["best_model_func"]
use_model_func = global_config["llm"]["best_model_func"]
query = query

# naive chunks
@@ -21,7 +21,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
self._client_file_name = os.path.join(
self.global_config["working_dir"], f"vdb_{self.namespace}.json"
)
self._max_batch_size = self.global_config["embedding_batch_num"]
self._max_batch_size = self.global_config["llm"]["embedding_batch_num"]
self._client = NanoVectorDB(
self.embedding_func.embedding_dim, storage_file=self._client_file_name
)
@@ -142,4 +142,4 @@ class NanoVectorDBVideoSegmentStorage(BaseVectorStorage):
return results

async def index_done_callback(self):
self._client.save()
self._client.save()
@@ -154,10 +154,23 @@ def clean_str(input: Any) -> str:
class EmbeddingFunc:
embedding_dim: int
max_token_size: int
model_name: str
func: callable

async def __call__(self, *args, **kwargs) -> np.ndarray:
return await self.func(*args, **kwargs)
# Had to fix this: the embedding function took only one named argument but it's passed in
# positionally; now we need to pass both
kwargs['model_name'] = self.model_name

# If there are positional arguments, convert them to keyword arguments
if args:
# Assuming the first positional argument is always 'texts'
if len(args) == 1 and isinstance(args[0], list):
kwargs['texts'] = args[0]
else:
raise ValueError("Unexpected positional arguments. Expected a single list of texts")
# Call the function with the updated keyword arguments
return await self.func(**kwargs)


# Decorators ------------------------------------------------------------------------
@@ -13,12 +13,10 @@ import tiktoken
from ._llm import (
gpt_4o_complete,
gpt_4o_mini_complete,
openai_embedding,
azure_gpt_4o_complete,
azure_openai_embedding,
azure_gpt_4o_mini_complete,
LLMConfig,
openai_config,
azure_openai_config,
ollama_config
)
from ._op import (
chunking_by_video_segments,
@@ -36,6 +34,7 @@ from ._utils import (
EmbeddingFunc,
compute_mdhash_id,
limit_async_func_call,
wrap_embedding_func_with_attrs,
convert_response_to_json,
always_get_an_event_loop,
logger,
@@ -98,21 +97,9 @@ class VideoRAG:
entity_extract_max_gleaning: int = 1
entity_summary_to_max_tokens: int = 500

# text embedding
embedding_func: EmbeddingFunc = field(default_factory=lambda: openai_embedding)
embedding_batch_num: int = 32
embedding_func_max_async: int = 16
query_better_than_threshold: float = 0.2

# LLM
using_azure_openai: bool = False
best_model_func: callable = gpt_4o_mini_complete
best_model_max_token_size: int = 32768
best_model_max_async: int = 16
cheap_model_func: callable = gpt_4o_mini_complete
cheap_model_max_token_size: int = 32768
cheap_model_max_async: int = 16

# Change to your LLM provider
llm: LLMConfig = field(default_factory=openai_config)

# entity extraction
entity_extraction_func: callable = extract_entities
@@ -143,18 +130,6 @@ class VideoRAG:
_print_config = ",\n ".join([f"{k} = {v}" for k, v in asdict(self).items()])
logger.debug(f"VideoRAG init with param:\n\n {_print_config}\n")

if self.using_azure_openai:
# If there's no OpenAI API key, use Azure OpenAI
if self.best_model_func == gpt_4o_complete:
self.best_model_func = azure_gpt_4o_complete
if self.cheap_model_func == gpt_4o_mini_complete:
self.cheap_model_func = azure_gpt_4o_mini_complete
if self.embedding_func == openai_embedding:
self.embedding_func = azure_openai_embedding
logger.info(
"Switched the default openai funcs to Azure OpenAI if you didn't set any of it"
)

if not os.path.exists(self.working_dir) and self.always_create_working_dir:
logger.info(f"Creating working directory {self.working_dir}")
os.makedirs(self.working_dir)
@@ -183,9 +158,10 @@ class VideoRAG:
namespace="chunk_entity_relation", global_config=asdict(self)
)

self.embedding_func = limit_async_func_call(self.embedding_func_max_async)(
self.embedding_func
)
self.embedding_func = limit_async_func_call(self.llm.embedding_func_max_async)(wrap_embedding_func_with_attrs(
embedding_dim = self.llm.embedding_dim,
max_token_size = self.llm.embedding_max_token_size,
model_name = self.llm.embedding_model_name)(self.llm.embedding_func))
self.entities_vdb = (
self.vector_db_storage_cls(
namespace="entities",
@@ -214,11 +190,11 @@ class VideoRAG:
)
)

self.best_model_func = limit_async_func_call(self.best_model_max_async)(
partial(self.best_model_func, hashing_kv=self.llm_response_cache)
self.best_model_func = limit_async_func_call(self.llm.best_model_max_async)(
partial(self.llm.best_model_func, hashing_kv=self.llm_response_cache)
)
self.cheap_model_func = limit_async_func_call(self.cheap_model_max_async)(
partial(self.cheap_model_func, hashing_kv=self.llm_response_cache)
self.cheap_model_func = limit_async_func_call(self.llm.cheap_model_max_async)(
partial(self.llm.cheap_model_func, hashing_kv=self.llm_response_cache)
)

def insert_video(self, video_path_list=None):