Merge pull request #4 from geraldthewes/ollama_support

Ollama support
Xubin Ren
2025-02-25 21:37:39 +08:00
committed by GitHub
30 changed files with 469 additions and 80 deletions

.gitignore (new file)

@@ -0,0 +1,4 @@
# Byte-compiled / optimized / DLL files
videorag/__pycache__/
*.py[cod]
*$py.class

README.md

@@ -114,7 +114,7 @@ logging.getLogger("httpx").setLevel(logging.WARNING)
# Please enter your openai key
os.environ["OPENAI_API_KEY"] = ""
from videorag._llm import *
from videorag._llm import openai_4o_mini_config
from videorag import VideoRAG, QueryParam
@@ -127,7 +127,7 @@ if __name__ == '__main__':
'movies/Iron-Man.mp4',
'movies/Spider-Man.mkv',
]
videorag = VideoRAG(cheap_model_func=gpt_4o_mini_complete, best_model_func=gpt_4o_mini_complete, working_dir=f"./videorag-workdir")
videorag = VideoRAG(llm=openai_4o_mini_config, working_dir=f"./videorag-workdir")
videorag.insert_video(video_path_list=video_paths)
```
@@ -156,7 +156,7 @@ if __name__ == '__main__':
# if param.wo_reference = False, VideoRAG will add reference to video clips in the response
param.wo_reference = True
videorag = VideoRAG(cheap_model_func=gpt_4o_mini_complete, best_model_func=gpt_4o_mini_complete, working_dir=f"./videorag-workdir")
videorag = VideoRAG(llm=openai_4o_mini_config, working_dir=f"./videorag-workdir")
videorag.load_caption_model(debug=False)
response = videorag.query(query=query, param=param)
print(response)
@@ -187,7 +187,7 @@ sh download.sh # downloading videos
Then, you can run the following example command to process and answer queries for LongerVideos with VideoRAG:
```shell
# Please enter your openai_key in line 18 at first
# Please enter your openai_key in line 22 at first
python videorag_experiment.py --collection 4-rag-lecture --cuda 0
```
@@ -244,6 +244,44 @@ python batch_winrate_quant_download.py
python batch_winrate_quant_calculate.py
```
### Ollama Support
This project also supports Ollama. To use it, edit `ollama_config` in [_llm.py](VideoRAG/videorag/_llm.py).
Adjust the parameters of the models being used:
```python
ollama_config = LLMConfig(
embedding_func_raw = ollama_embedding,
embedding_model_name = "nomic-embed-text",
embedding_dim = 768,
embedding_max_token_size=8192,
embedding_batch_num = 1,
embedding_func_max_async = 1,
query_better_than_threshold = 0.2,
best_model_func_raw = ollama_complete,
best_model_name = "gemma2:latest", # needs to be a solid instruct model
best_model_max_token_size = 32768,
best_model_max_async = 1,
cheap_model_func_raw = ollama_mini_complete,
cheap_model_name = "olmo2",
cheap_model_max_token_size = 32768,
cheap_model_max_async = 1
)
```
Then specify the config when creating your VideoRAG instance, as in the sketch below.
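A minimal sketch of end-to-end use with Ollama, assuming the Ollama server is reachable locally and the models named in `ollama_config` (e.g. `nomic-embed-text` and `gemma2`) have already been pulled; the video path and query string are placeholders:

```python
import multiprocessing

from videorag._llm import ollama_config
from videorag import VideoRAG, QueryParam

if __name__ == '__main__':
    multiprocessing.set_start_method('spawn')

    # Index a video with the locally served Ollama models
    videorag = VideoRAG(llm=ollama_config, working_dir="./videorag-workdir")
    videorag.insert_video(video_path_list=['movies/Iron-Man.mp4'])

    # Query the indexed video
    videorag.load_caption_model(debug=False)
    param = QueryParam(mode="videorag")
    param.wo_reference = True
    response = videorag.query(query="What is this video about?", param=param)
    print(response)
```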
### Jupyter Notebook
To test the solution on a single video, load the notebook from the [notebook folder](VideoRAG/notesbooks) and
update the parameters to fit your situation.
A YouTube video, for example, can be downloaded as follows:
```shell
yt-dlp -o "%(id)s.%(ext)s" -S "res:720" https://www.youtube.com/live/DPa2iRgzadM?si=8cf8WbYtqiglrwtN -P .
```
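The `-o "%(id)s.%(ext)s"` template names the downloaded file after the video id, so it can then be listed in the notebook's `video_paths` cell. A hedged sketch, since the actual extension depends on the format yt-dlp selects:

```python
# Hypothetical filename: derived from the video id in the URL above;
# the extension may be .mp4, .webm, etc. depending on the downloaded format.
video_paths = [
    './DPa2iRgzadM.mp4',
]
```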
## Citation
If you find this work is helpful to your research, please consider citing our paper:
```bibtex

videorag_experiment.py

@@ -3,41 +3,66 @@ import json
import logging
import warnings
import multiprocessing
import sys
warnings.filterwarnings("ignore")
logging.getLogger("httpx").setLevel(logging.WARNING)
# Add the parent directory of 'videorag' to sys.path
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
import argparse
parser = argparse.ArgumentParser(description="Set sub-category and CUDA device.")
parser.add_argument('--collection', type=str, default='4-rag-lecture')
parser.add_argument('--cuda', type=str, default='0')
args = parser.parse_args()
sub_category = args.sub_category
sub_category = args.collection
os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda
os.environ["OPENAI_API_KEY"] = ""
from videorag._llm import *
from videorag import VideoRAG, QueryParam
from videorag.videorag import VideoRAG, QueryParam
longervideos_llm_config = LLMConfig(
embedding_func_raw = openai_embedding,
embedding_model_name = "text-embedding-3-small",
embedding_dim = 1536,
embedding_max_token_size = 8192,
embedding_batch_num = 32,
embedding_func_max_async = 16,
query_better_than_threshold = 0.2,
# LLM (we utilize gpt-4o-mini for all experiments)
best_model_func_raw = gpt_4o_mini_complete,
best_model_name = "gpt-4o-mini",
best_model_max_token_size = 32768,
best_model_max_async = 16,
cheap_model_func_raw = gpt_4o_mini_complete,
cheap_model_name = "gpt-4o-mini",
cheap_model_max_token_size = 32768,
cheap_model_max_async = 16
)
if __name__ == '__main__':
multiprocessing.set_start_method('spawn')
## learn
video_base_path = f'./{sub_category}/videos/'
video_base_path = f'./longervideos/{sub_category}/videos/'
video_files = sorted(os.listdir(video_base_path))
video_paths = [os.path.join(video_base_path, f) for f in video_files]
videorag = VideoRAG(cheap_model_func=gpt_4o_mini_complete, best_model_func=gpt_4o_mini_complete, working_dir=f"./videorag-workdir/{sub_category}")
videorag = VideoRAG(llm=longervideos_llm_config, working_dir=f"./longervideos/videorag-workdir/{sub_category}")
videorag.insert_video(video_path_list=video_paths)
## inference
with open(f'./dataset.json', 'r') as f:
with open(f'./longervideos/dataset.json', 'r') as f:
longervideos = json.load(f)
videorag = VideoRAG(cheap_model_func=gpt_4o_mini_complete, best_model_func=gpt_4o_mini_complete, working_dir=f"./videorag-workdir/{sub_category}")
videorag = VideoRAG(llm=longervideos_llm_config, working_dir=f"./longervideos/videorag-workdir/{sub_category}")
videorag.load_caption_model(debug=False)
answer_folder = f'./videorag-answers/{sub_category}'
answer_folder = f'./longervideos/videorag-answers/{sub_category}'
os.makedirs(answer_folder, exist_ok=True)
collection_id = sub_category.split('-')[0]
@@ -51,5 +76,5 @@ if __name__ == '__main__':
response = videorag.query(query=query, param=param)
print(response)
with open(os.path.join(answer_folder, f'/answer_{query_id}.md'), 'w') as f:
f.write(response)
with open(os.path.join(answer_folder, f'answer_{query_id}.md'), 'w') as f:
f.write(response)

notesbooks/videorag.ipynb (new file, 124 lines)

@@ -0,0 +1,124 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "c1e12bf6-a8ae-4fec-9fbd-99ea74bcc563",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import logging\n",
"import warnings\n",
"import multiprocessing\n",
"import nest_asyncio\n",
" \n",
"nest_asyncio.apply()\n",
"\n",
"warnings.filterwarnings(\"ignore\")\n",
"logging.getLogger(\"httpx\").setLevel(logging.WARNING)\n",
"os.environ[\"CUDA_VISIBLE_DEVICES\"] = '0'\n",
"\n",
"from videorag._llm import openai_config, openai_4o_mini_config, azure_openai_config, ollama_config\n",
"from videorag import VideoRAG, QueryParam\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f04cee24-fd8e-41dc-93ea-888229d2a9af",
"metadata": {},
"outputs": [],
"source": [
"video_paths = [\n",
" '/mnt/data3/AI/software/VideoRAG/Lexington/GMT20241112-164602_Recording_gallery_1280x720.mp4',\n",
"]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1a8b5f05-12e8-4b53-b84c-e4555fc99022",
"metadata": {},
"outputs": [],
"source": [
"multiprocessing.set_start_method('spawn')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "62f20b71-3ae4-4db9-9023-cdba818ef5e8",
"metadata": {},
"outputs": [],
"source": [
"videorag = VideoRAG(llm=ollama_config, working_dir=f\"./videorag-workdir/lexington\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7539e57b-11c2-42a9-b8f1-2363e0f561df",
"metadata": {},
"outputs": [],
"source": [
"# To build\n",
"videorag.insert_video(video_path_list=video_paths)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6a938da7-5c33-44a3-a66e-a2e07636ce70",
"metadata": {},
"outputs": [],
"source": [
"# To query\n",
"videorag.load_caption_model(debug=False)\n",
"param = QueryParam(mode=\"videorag\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "15d51fb9-06e9-44e0-83ef-f3269c50480f",
"metadata": {},
"outputs": [],
"source": [
"query = \"What are the Lexington school construction options\"\n",
"param.wo_reference = False\n",
"response = videorag.query(query=query, param=param)\n",
"print(response)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4d517d22-a482-47a1-a4ac-89d0042b016f",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python [conda env:videorag]",
"language": "python",
"name": "conda-env-videorag-py"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.11"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

videorag/_llm.py

@@ -1,6 +1,8 @@
import numpy as np
from openai import AsyncOpenAI, AsyncAzureOpenAI, APIConnectionError, RateLimitError
from ollama import AsyncClient
from dataclasses import asdict, dataclass, field
from tenacity import (
retry,
@@ -12,10 +14,11 @@ import os
from ._utils import compute_args_hash, wrap_embedding_func_with_attrs
from .base import BaseKVStorage
from ._utils import EmbeddingFunc
global_openai_async_client = None
global_azure_openai_async_client = None
global_ollama_client = None
def get_openai_async_client_instance():
global global_openai_async_client
@@ -30,12 +33,62 @@ def get_azure_openai_async_client_instance():
global_azure_openai_async_client = AsyncAzureOpenAI()
return global_azure_openai_async_client
def get_ollama_async_client_instance():
global global_ollama_client
if global_ollama_client is None:
# set OLLAMA_HOST or pass in host="http://127.0.0.1:11434"
global_ollama_client = AsyncClient() # Adjust base URL if necessary
return global_ollama_client
@retry(
stop=stop_after_attempt(5),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type((RateLimitError, APIConnectionError)),
)
# Setup LLM Configuration.
@dataclass
class LLMConfig:
# To be set
embedding_func_raw: callable
embedding_model_name: str
embedding_dim: int
embedding_max_token_size: int
embedding_batch_num: int
embedding_func_max_async: int
query_better_than_threshold: float
best_model_func_raw: callable
best_model_name: str
best_model_max_token_size: int
best_model_max_async: int
cheap_model_func_raw: callable
cheap_model_name: str
cheap_model_max_token_size: int
cheap_model_max_async: int
# Assigned in post init
embedding_func: EmbeddingFunc = None
best_model_func: callable = None
cheap_model_func: callable = None
def __post_init__(self):
embedding_wrapper = wrap_embedding_func_with_attrs(
embedding_dim = self.embedding_dim,
max_token_size = self.embedding_max_token_size,
model_name = self.embedding_model_name)
self.embedding_func = embedding_wrapper(self.embedding_func_raw)
self.best_model_func = lambda prompt, *args, **kwargs: self.best_model_func_raw(
self.best_model_name, prompt, *args, **kwargs
)
self.cheap_model_func = lambda prompt, *args, **kwargs: self.cheap_model_func_raw(
self.cheap_model_name, prompt, *args, **kwargs
)
##### OpenAI Configuration
async def openai_complete_if_cache(
model, prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
@@ -65,43 +118,82 @@ async def openai_complete_if_cache(
async def gpt_4o_complete(
prompt, system_prompt=None, history_messages=[], **kwargs
model_name, prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
return await openai_complete_if_cache(
"gpt-4o",
model_name,
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
**kwargs,
)
async def gpt_4o_mini_complete(
prompt, system_prompt=None, history_messages=[], **kwargs
model_name, prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
return await openai_complete_if_cache(
"gpt-4o-mini",
model_name,
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
**kwargs,
)
@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
@retry(
stop=stop_after_attempt(5),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type((RateLimitError, APIConnectionError)),
)
async def openai_embedding(texts: list[str]) -> np.ndarray:
async def openai_embedding(model_name: str, texts: list[str]) -> np.ndarray:
openai_async_client = get_openai_async_client_instance()
response = await openai_async_client.embeddings.create(
model="text-embedding-3-small", input=texts, encoding_format="float"
model=model_name, input=texts, encoding_format="float"
)
return np.array([dp.embedding for dp in response.data])
openai_config = LLMConfig(
embedding_func_raw = openai_embedding,
embedding_model_name = "text-embedding-3-small",
embedding_dim = 1536,
embedding_max_token_size = 8192,
embedding_batch_num = 32,
embedding_func_max_async = 16,
query_better_than_threshold = 0.2,
# LLM
best_model_func_raw = gpt_4o_complete,
best_model_name = "gpt-4o",
best_model_max_token_size = 32768,
best_model_max_async = 16,
cheap_model_func_raw = gpt_4o_mini_complete,
cheap_model_name = "gpt-4o-mini",
cheap_model_max_token_size = 32768,
cheap_model_max_async = 16
)
openai_4o_mini_config = LLMConfig(
embedding_func_raw = openai_embedding,
embedding_model_name = "text-embedding-3-small",
embedding_dim = 1536,
embedding_max_token_size = 8192,
embedding_batch_num = 32,
embedding_func_max_async = 16,
query_better_than_threshold = 0.2,
# LLM
best_model_func_raw = gpt_4o_mini_complete,
best_model_name = "gpt-4o-mini",
best_model_max_token_size = 32768,
best_model_max_async = 16,
cheap_model_func_raw = gpt_4o_mini_complete,
cheap_model_name = "gpt-4o-mini",
cheap_model_max_token_size = 32768,
cheap_model_max_async = 16
)
###### Azure OpenAI Configuration
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
@@ -141,10 +233,10 @@ async def azure_openai_complete_if_cache(
async def azure_gpt_4o_complete(
prompt, system_prompt=None, history_messages=[], **kwargs
model_name, prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
return await azure_openai_complete_if_cache(
"gpt-4o",
model_name,
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
@@ -153,10 +245,10 @@ async def azure_gpt_4o_complete(
async def azure_gpt_4o_mini_complete(
prompt, system_prompt=None, history_messages=[], **kwargs
model_name, prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
return await azure_openai_complete_if_cache(
"gpt-4o-mini",
model_name,
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
@@ -164,15 +256,132 @@ async def azure_gpt_4o_mini_complete(
)
@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type((RateLimitError, APIConnectionError)),
)
async def azure_openai_embedding(texts: list[str]) -> np.ndarray:
async def azure_openai_embedding(model_name: str, texts: list[str]) -> np.ndarray:
azure_openai_client = get_azure_openai_async_client_instance()
response = await azure_openai_client.embeddings.create(
model="text-embedding-3-small", input=texts, encoding_format="float"
model=model_name, input=texts, encoding_format="float"
)
return np.array([dp.embedding for dp in response.data])
azure_openai_config = LLMConfig(
embedding_func_raw = azure_openai_embedding,
embedding_model_name = "text-embedding-3-small",
embedding_dim = 1536,
embedding_max_token_size = 8192,
embedding_batch_num = 32,
embedding_func_max_async = 16,
query_better_than_threshold = 0.2,
best_model_func_raw = azure_gpt_4o_complete,
best_model_name = "gpt-4o",
best_model_max_token_size = 32768,
best_model_max_async = 16,
cheap_model_func_raw = azure_gpt_4o_mini_complete,
cheap_model_name = "gpt-4o-mini",
cheap_model_max_token_size = 32768,
cheap_model_max_async = 16
)
###### Ollama configuration
async def ollama_complete_if_cache(
model, prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
# Initialize the Ollama client
ollama_client = get_ollama_async_client_instance()
hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
messages = []
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
messages.extend(history_messages)
messages.append({"role": "user", "content": prompt})
if hashing_kv is not None:
args_hash = compute_args_hash(model, messages)
if_cache_return = await hashing_kv.get_by_id(args_hash)
if if_cache_return is not None:
return if_cache_return["return"]
# Send the request to Ollama
response = await ollama_client.chat(
model=model,
messages=messages
)
# print(messages)
# print(response['message']['content'])
if hashing_kv is not None:
await hashing_kv.upsert(
{args_hash: {"return": response['message']['content'], "model": model}}
)
await hashing_kv.index_done_callback()
return response['message']['content']
async def ollama_complete(model_name, prompt, system_prompt=None, history_messages=[], **kwargs) -> str:
return await ollama_complete_if_cache(
model_name,
prompt,
system_prompt=system_prompt,
history_messages=history_messages
)
async def ollama_mini_complete(model_name, prompt, system_prompt=None, history_messages=[], **kwargs) -> str:
return await ollama_complete_if_cache(
# "deepseek-r1:latest", # For now select your model
model_name,
prompt,
system_prompt=system_prompt,
history_messages=history_messages
)
@retry(
stop=stop_after_attempt(5),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type((RateLimitError, APIConnectionError)),
)
async def ollama_embedding(model_name: str, texts: list[str]) -> np.ndarray:
# Initialize the Ollama client
ollama_client = get_ollama_async_client_instance()
# Send the request to Ollama for embeddings
response = await ollama_client.embed(
model=model_name,
input=texts
)
# Extract embeddings from the response
embeddings = response['embeddings']
return np.array(embeddings)
ollama_config = LLMConfig(
embedding_func_raw = ollama_embedding,
embedding_model_name = "nomic-embed-text",
embedding_dim = 768,
embedding_max_token_size=8192,
embedding_batch_num = 1,
embedding_func_max_async = 1,
query_better_than_threshold = 0.2,
best_model_func_raw = ollama_complete,
best_model_name = "gemma2:latest", # needs to be a solid instruct model
best_model_max_token_size = 32768,
best_model_max_async = 1,
cheap_model_func_raw = ollama_mini_complete,
cheap_model_name = "olmo2",
cheap_model_max_token_size = 32768,
cheap_model_max_async = 1
)

videorag/_op.py

@@ -183,8 +183,8 @@ async def _handle_entity_relation_summary(
description: str,
global_config: dict,
) -> str:
use_llm_func: callable = global_config["cheap_model_func"]
llm_max_tokens = global_config["cheap_model_max_token_size"]
use_llm_func: callable = global_config["llm"]["cheap_model_func"]
llm_max_tokens = global_config["llm"]["cheap_model_max_token_size"]
tiktoken_model_name = global_config["tiktoken_model_name"]
summary_max_tokens = global_config["entity_summary_to_max_tokens"]
@@ -359,7 +359,7 @@ async def extract_entities(
entity_vdb: BaseVectorStorage,
global_config: dict,
) -> Union[BaseGraphStorage, None]:
use_llm_func: callable = global_config["best_model_func"]
use_llm_func: callable = global_config["llm"]["best_model_func"]
entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"]
ordered_chunks = list(chunks.items())
@@ -549,7 +549,7 @@ async def _refine_entity_retrieval_query(
query_param: QueryParam,
global_config: dict,
):
use_llm_func: callable = global_config["cheap_model_func"]
use_llm_func: callable = global_config["llm"]["cheap_model_func"]
query_rewrite_prompt = PROMPTS["query_rewrite_for_entity_retrieval"]
query_rewrite_prompt = query_rewrite_prompt.format(input_text=query)
final_result = await use_llm_func(query_rewrite_prompt)
@@ -560,7 +560,7 @@ async def _refine_visual_retrieval_query(
query_param: QueryParam,
global_config: dict,
):
use_llm_func: callable = global_config["cheap_model_func"]
use_llm_func: callable = global_config["llm"]["cheap_model_func"]
query_rewrite_prompt = PROMPTS["query_rewrite_for_visual_retrieval"]
query_rewrite_prompt = query_rewrite_prompt.format(input_text=query)
final_result = await use_llm_func(query_rewrite_prompt)
@@ -571,7 +571,7 @@ async def _extract_keywords_query(
query_param: QueryParam,
global_config: dict,
):
use_llm_func: callable = global_config["cheap_model_func"]
use_llm_func: callable = global_config["llm"]["cheap_model_func"]
keywords_prompt = PROMPTS["keywords_extraction"]
keywords_prompt = keywords_prompt.format(input_text=query)
final_result = await use_llm_func(keywords_prompt)
@@ -591,7 +591,7 @@ async def videorag_query(
query_param: QueryParam,
global_config: dict,
) -> str:
use_model_func = global_config["best_model_func"]
use_model_func = global_config["llm"]["best_model_func"]
query = query
# naive chunks

videorag/_storage.py

@@ -21,7 +21,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
self._client_file_name = os.path.join(
self.global_config["working_dir"], f"vdb_{self.namespace}.json"
)
self._max_batch_size = self.global_config["embedding_batch_num"]
self._max_batch_size = self.global_config["llm"]["embedding_batch_num"]
self._client = NanoVectorDB(
self.embedding_func.embedding_dim, storage_file=self._client_file_name
)
@@ -142,4 +142,4 @@ class NanoVectorDBVideoSegmentStorage(BaseVectorStorage):
return results
async def index_done_callback(self):
self._client.save()
self._client.save()

videorag/_utils.py

@@ -154,10 +154,23 @@ def clean_str(input: Any) -> str:
class EmbeddingFunc:
embedding_dim: int
max_token_size: int
model_name: str
func: callable
async def __call__(self, *args, **kwargs) -> np.ndarray:
return await self.func(*args, **kwargs)
# Had to fix this as the embedding function took only one named argument but it's passed in
# positionally; now we need to pass both
kwargs['model_name'] = self.model_name
# If there are positional arguments, convert them to keyword arguments
if args:
# Assuming the first positional argument is always 'texts'
if len(args) == 1 and isinstance(args[0], list):
kwargs['texts'] = args[0]
else:
raise ValueError("Unexpected positional arguments. Expected a single list of texts")
# Call the function with the updated keyword arguments
return await self.func(**kwargs)
# Decorators ------------------------------------------------------------------------

videorag/videorag.py

@@ -13,12 +13,10 @@ import tiktoken
from ._llm import (
gpt_4o_complete,
gpt_4o_mini_complete,
openai_embedding,
azure_gpt_4o_complete,
azure_openai_embedding,
azure_gpt_4o_mini_complete,
LLMConfig,
openai_config,
azure_openai_config,
ollama_config
)
from ._op import (
chunking_by_video_segments,
@@ -36,6 +34,7 @@ from ._utils import (
EmbeddingFunc,
compute_mdhash_id,
limit_async_func_call,
wrap_embedding_func_with_attrs,
convert_response_to_json,
always_get_an_event_loop,
logger,
@@ -98,21 +97,9 @@ class VideoRAG:
entity_extract_max_gleaning: int = 1
entity_summary_to_max_tokens: int = 500
# text embedding
embedding_func: EmbeddingFunc = field(default_factory=lambda: openai_embedding)
embedding_batch_num: int = 32
embedding_func_max_async: int = 16
query_better_than_threshold: float = 0.2
# LLM
using_azure_openai: bool = False
best_model_func: callable = gpt_4o_mini_complete
best_model_max_token_size: int = 32768
best_model_max_async: int = 16
cheap_model_func: callable = gpt_4o_mini_complete
cheap_model_max_token_size: int = 32768
cheap_model_max_async: int = 16
# Change to your LLM provider
llm: LLMConfig = field(default_factory=openai_config)
# entity extraction
entity_extraction_func: callable = extract_entities
@@ -143,18 +130,6 @@ class VideoRAG:
_print_config = ",\n ".join([f"{k} = {v}" for k, v in asdict(self).items()])
logger.debug(f"VideoRAG init with param:\n\n {_print_config}\n")
if self.using_azure_openai:
# If there's no OpenAI API key, use Azure OpenAI
if self.best_model_func == gpt_4o_complete:
self.best_model_func = azure_gpt_4o_complete
if self.cheap_model_func == gpt_4o_mini_complete:
self.cheap_model_func = azure_gpt_4o_mini_complete
if self.embedding_func == openai_embedding:
self.embedding_func = azure_openai_embedding
logger.info(
"Switched the default openai funcs to Azure OpenAI if you didn't set any of it"
)
if not os.path.exists(self.working_dir) and self.always_create_working_dir:
logger.info(f"Creating working directory {self.working_dir}")
os.makedirs(self.working_dir)
@@ -183,9 +158,10 @@ class VideoRAG:
namespace="chunk_entity_relation", global_config=asdict(self)
)
self.embedding_func = limit_async_func_call(self.embedding_func_max_async)(
self.embedding_func
)
self.embedding_func = limit_async_func_call(self.llm.embedding_func_max_async)(wrap_embedding_func_with_attrs(
embedding_dim = self.llm.embedding_dim,
max_token_size = self.llm.embedding_max_token_size,
model_name = self.llm.embedding_model_name)(self.llm.embedding_func))
self.entities_vdb = (
self.vector_db_storage_cls(
namespace="entities",
@@ -214,11 +190,11 @@ class VideoRAG:
)
)
self.best_model_func = limit_async_func_call(self.best_model_max_async)(
partial(self.best_model_func, hashing_kv=self.llm_response_cache)
self.best_model_func = limit_async_func_call(self.llm.best_model_max_async)(
partial(self.llm.best_model_func, hashing_kv=self.llm_response_cache)
)
self.cheap_model_func = limit_async_func_call(self.cheap_model_max_async)(
partial(self.cheap_model_func, hashing_kv=self.llm_response_cache)
self.cheap_model_func = limit_async_func_call(self.llm.cheap_model_max_async)(
partial(self.llm.cheap_model_func, hashing_kv=self.llm_response_cache)
)
def insert_video(self, video_path_list=None):