mirror of https://github.com/HKUDS/VideoRAG.git
synced 2025-05-11 03:54:36 +03:00
Initial changes to support ollama
@@ -15,7 +15,7 @@ from .base import BaseKVStorage
global_openai_async_client = None
global_azure_openai_async_client = None
global_ollama_client = None


def get_openai_async_client_instance():
    global global_openai_async_client
@@ -30,6 +30,12 @@ def get_azure_openai_async_client_instance():
        global_azure_openai_async_client = AsyncAzureOpenAI()
    return global_azure_openai_async_client


def get_ollama_async_client_instance():
    # chat() is awaited below, so this needs ollama's AsyncClient (not the sync
    # Client), and the ollama client takes the server address via `host`.
    global global_ollama_client
    if global_ollama_client is None:
        # global_ollama_client = AsyncClient(host="http://localhost:11434")  # Adjust host if necessary
        global_ollama_client = AsyncClient(host="http://10.0.1.12:11434")  # Adjust host if necessary
    return global_ollama_client
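As a hedged alternative sketch, the hard-coded IP could instead come from the OLLAMA_HOST environment variable (the same variable the ollama CLI honors), falling back to the local default:

import os
from ollama import AsyncClient

def get_ollama_async_client_instance():
    global global_ollama_client
    if global_ollama_client is None:
        # Fall back to ollama's default local endpoint when OLLAMA_HOST is unset.
        host = os.environ.get("OLLAMA_HOST", "http://localhost:11434")
        global_ollama_client = AsyncClient(host=host)
    return global_ollama_client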
@retry(
    stop=stop_after_attempt(5),
@@ -164,6 +170,7 @@ async def azure_gpt_4o_mini_complete(
    )


@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
@retry(
    stop=stop_after_attempt(3),
@@ -176,3 +183,73 @@ async def azure_openai_embedding(texts: list[str]) -> np.ndarray:
        model="text-embedding-3-small", input=texts, encoding_format="float"
    )
    return np.array([dp.embedding for dp in response.data])

async def ollama_complete_if_cache(
    model, prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
    # Initialize the Ollama client
    ollama_client = get_ollama_async_client_instance()

    hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
    messages = []

    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    messages.extend(history_messages)
    messages.append({"role": "user", "content": prompt})

    # Serve an identical earlier request from the KV cache when possible
    if hashing_kv is not None:
        args_hash = compute_args_hash(model, messages)
        if_cache_return = await hashing_kv.get_by_id(args_hash)
        if if_cache_return is not None:
            return if_cache_return["return"]

    # Send the request to Ollama. chat() returns the generated text under
    # message.content; a .response attribute only exists on generate() results.
    response = await ollama_client.chat(
        model=model,
        messages=messages,
        **kwargs,
    )
    result = response["message"]["content"]

    if hashing_kv is not None:
        await hashing_kv.upsert(
            {args_hash: {"return": result, "model": model}}
        )
        await hashing_kv.index_done_callback()

    return result
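For reference, a minimal sketch of what the cache-key computation can look like; the project's actual compute_args_hash lives in its utils module and may differ in detail:

from hashlib import md5

def compute_args_hash(*args):
    # Identical (model, messages) pairs hash to the same id, so a repeated
    # request is answered from the KV store instead of calling the model.
    return md5(str(args).encode()).hexdigest()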

async def ollama_complete(prompt, system_prompt=None, history_messages=[], **kwargs) -> str:
    return await ollama_complete_if_cache(
        "deepseek-r1:32b",  # For now select your model
        prompt,
        system_prompt=system_prompt,
        history_messages=history_messages,
        **kwargs,
    )


async def ollama_mini_complete(prompt, system_prompt=None, history_messages=[], **kwargs) -> str:
    return await ollama_complete_if_cache(
        "deepseek-r1:latest",  # For now select your model
        prompt,
        system_prompt=system_prompt,
        history_messages=history_messages,
        **kwargs,
    )
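A minimal usage sketch, assuming a reachable Ollama server with the deepseek-r1 models above already pulled:

import asyncio

async def main():
    answer = await ollama_complete(
        "Summarize retrieval-augmented generation in one sentence.",
        system_prompt="You are a concise assistant.",
    )
    print(answer)

asyncio.run(main())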

@wrap_embedding_func_with_attrs(embedding_dim=768, max_token_size=8192)  # nomic-embed-text emits 768-dim vectors
async def ollama_embedding(texts: list[str]) -> np.ndarray:
    # Initialize the Ollama client
    ollama_client = get_ollama_async_client_instance()

    # Send the request to Ollama for embeddings. Unlike the OpenAI API,
    # ollama's embed() takes the whole batch via `input` and returns the
    # vectors directly under `embeddings` (there is no .data/.embedding).
    response = await ollama_client.embed(
        model="nomic-embed-text",  # Replace with the appropriate Ollama embedding model
        input=texts,
    )

    return np.array(response["embeddings"])

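A quick shape check of the helper, assuming nomic-embed-text is pulled locally:

import asyncio

async def main():
    vecs = await ollama_embedding(["first text", "second text"])
    print(vecs.shape)  # expected: (2, 768) for nomic-embed-text

asyncio.run(main())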
@@ -19,6 +19,8 @@ from ._llm import (
    azure_gpt_4o_complete,
    azure_openai_embedding,
    azure_gpt_4o_mini_complete,
    ollama_complete,
    ollama_mini_complete,
    ollama_embedding,
)
from ._op import (
    chunking_by_video_segments,
@@ -98,21 +100,60 @@ class VideoRAG:
    entity_extract_max_gleaning: int = 1
    entity_summary_to_max_tokens: int = 500

    # text embedding
    embedding_func: EmbeddingFunc = field(default_factory=lambda: openai_embedding)
    embedding_batch_num: int = 32
    embedding_func_max_async: int = 16
    query_better_than_threshold: float = 0.2

    # LLM
    using_azure_openai: bool = False
    best_model_func: callable = gpt_4o_mini_complete
    best_model_max_token_size: int = 32768
    best_model_max_async: int = 16
    cheap_model_func: callable = gpt_4o_mini_complete
    cheap_model_max_token_size: int = 32768
    cheap_model_max_async: int = 16

    # Set llm_provider to "openai", "azure_openai", or "ollama"; the matching
    # branch below overrides the defaults above at class-definition time.
    llm_provider = "ollama"

    if llm_provider == "openai":
        # text embedding
        embedding_func: EmbeddingFunc = field(default_factory=lambda: openai_embedding)
        embedding_batch_num: int = 32
        embedding_func_max_async: int = 16
        query_better_than_threshold: float = 0.2

        # LLM
        best_model_func: callable = gpt_4o_mini_complete
        best_model_max_token_size: int = 32768
        best_model_max_async: int = 16
        cheap_model_func: callable = gpt_4o_mini_complete
        cheap_model_max_token_size: int = 32768
        cheap_model_max_async: int = 16

    if llm_provider == "azure_openai":
        # text embedding
        embedding_func: EmbeddingFunc = field(default_factory=lambda: azure_openai_embedding)
        embedding_batch_num: int = 32
        embedding_func_max_async: int = 16
        query_better_than_threshold: float = 0.2

        # LLM
        best_model_func: callable = azure_gpt_4o_complete
        best_model_max_token_size: int = 32768
        best_model_max_async: int = 16
        cheap_model_func: callable = azure_gpt_4o_mini_complete
        cheap_model_max_token_size: int = 32768
        cheap_model_max_async: int = 16

    if llm_provider == "ollama":
        # text embedding
        embedding_func: EmbeddingFunc = field(default_factory=lambda: ollama_embedding)
        embedding_batch_num: int = 32
        embedding_func_max_async: int = 1
        query_better_than_threshold: float = 0.2

        # LLM
        best_model_func: callable = ollama_complete
        best_model_max_token_size: int = 32768
        best_model_max_async: int = 1
        cheap_model_func: callable = ollama_mini_complete
        cheap_model_max_token_size: int = 32768
        cheap_model_max_async: int = 1

    # entity extraction
    entity_extraction_func: callable = extract_entities
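With llm_provider = "ollama" above, instantiating the class needs no OpenAI credentials; a minimal sketch (the working_dir path is an arbitrary example):

rag = VideoRAG(working_dir="./videorag_workdir")
# best_model_func, cheap_model_func, and embedding_func now point at the
# ollama helpers selected by the class-level branches above.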
@@ -143,18 +184,6 @@ class VideoRAG:
        _print_config = ",\n  ".join([f"{k} = {v}" for k, v in asdict(self).items()])
        logger.debug(f"VideoRAG init with param:\n\n  {_print_config}\n")

        if self.using_azure_openai:
            # If there's no OpenAI API key, use Azure OpenAI
            if self.best_model_func == gpt_4o_complete:
                self.best_model_func = azure_gpt_4o_complete
            if self.cheap_model_func == gpt_4o_mini_complete:
                self.cheap_model_func = azure_gpt_4o_mini_complete
            if self.embedding_func == openai_embedding:
                self.embedding_func = azure_openai_embedding
            logger.info(
                "Switched the default OpenAI funcs to Azure OpenAI where you didn't override them"
            )

        if not os.path.exists(self.working_dir) and self.always_create_working_dir:
            logger.info(f"Creating working directory {self.working_dir}")
            os.makedirs(self.working_dir)