Mirror of https://github.com/HKUDS/VideoRAG.git (synced 2025-05-11 03:54:36 +03:00)

Commit: multiple choice question
@@ -101,7 +101,8 @@ async def openai_complete_if_cache(
     if hashing_kv is not None:
         args_hash = compute_args_hash(model, messages)
         if_cache_return = await hashing_kv.get_by_id(args_hash)
-        if if_cache_return is not None:
+        # NOTE: guard against a cached entry whose "return" payload is None
+        if if_cache_return is not None and if_cache_return["return"] is not None:
             return if_cache_return["return"]

     response = await openai_async_client.chat.completions.create(

@@ -211,7 +212,8 @@ async def azure_openai_complete_if_cache(
     if hashing_kv is not None:
         args_hash = compute_args_hash(deployment_name, messages)
         if_cache_return = await hashing_kv.get_by_id(args_hash)
-        if if_cache_return is not None:
+        # NOTE: guard against a cached entry whose "return" payload is None
+        if if_cache_return is not None and if_cache_return["return"] is not None:
             return if_cache_return["return"]

     response = await azure_openai_client.chat.completions.create(

@@ -309,7 +311,8 @@ async def ollama_complete_if_cache(
     if hashing_kv is not None:
         args_hash = compute_args_hash(model, messages)
         if_cache_return = await hashing_kv.get_by_id(args_hash)
-        if if_cache_return is not None:
+        # NOTE: guard against a cached entry whose "return" payload is None
+        if if_cache_return is not None and if_cache_return["return"] is not None:
             return if_cache_return["return"]

     # Send the request to Ollama
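The same one-line guard is applied to all three wrappers. Previously, a cache record whose "return" payload was None still satisfied `if if_cache_return is not None`, so the function returned None to the caller instead of re-querying the API. A minimal sketch of the failure mode, using a hypothetical in-memory `DictKV` and a simplified `compute_args_hash` (neither is the repository's real implementation):

```python
import asyncio
import json

def compute_args_hash(*args) -> str:
    # Simplified stand-in for the repo's hash helper.
    return str(hash(json.dumps([str(a) for a in args])))

class DictKV:
    """Hypothetical in-memory stand-in for hashing_kv."""
    def __init__(self, data=None):
        self._data = data or {}
    async def get_by_id(self, key):
        return self._data.get(key)

async def demo():
    key = compute_args_hash("gpt-4o", [{"role": "user", "content": "hi"}])
    kv = DictKV({key: {"return": None}})  # poisoned cache entry
    hit = await kv.get_by_id(key)
    # Old guard: the record is truthy, so None leaks back to the caller.
    old = hit["return"] if hit is not None else "CACHE MISS"
    # New guard: a None payload now falls through to a fresh API call.
    new = hit["return"] if (hit is not None and hit["return"] is not None) else "CACHE MISS"
    print(f"old guard -> {old!r}, new guard -> {new!r}")

asyncio.run(demo())
```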
videorag/_op.py (182 changed lines)
@@ -742,3 +742,185 @@ async def videorag_query(
        system_prompt=sys_prompt,
    )
    return response

async def videorag_query_multiple_choice(
    query,
    entities_vdb,
    text_chunks_db,
    chunks_vdb,
    video_path_db,
    video_segments,
    video_segment_feature_vdb,
    knowledge_graph_inst,
    caption_model,
    caption_tokenizer,
    query_param: QueryParam,
    global_config: dict,
) -> dict:
    """A copy of the videorag_query function, with several updates for
    handling multiple-choice queries; returns the parsed JSON answer.
    """
    use_model_func = global_config["llm"]["best_model_func"]

    # naive chunks
    results = await chunks_vdb.query(query, top_k=query_param.top_k)
    # NOTE: updated so that an empty result set is also handled (see the else branch)
    if len(results):
        chunks_ids = [r["id"] for r in results]
        chunks = await text_chunks_db.get_by_ids(chunks_ids)

        maybe_trun_chunks = truncate_list_by_token_size(
            chunks,
            key=lambda x: x["content"],
            max_token_size=query_param.naive_max_token_for_text_unit,
        )
        logger.info(f"Truncate {len(chunks)} to {len(maybe_trun_chunks)} chunks")
        section = "-----New Chunk-----\n".join([c["content"] for c in maybe_trun_chunks])
        retrieved_chunk_context = section
    else:
        retrieved_chunk_context = "No Content"
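    # (Context: truncate_list_by_token_size keeps a prefix of the chunk list
    # whose cumulative token count fits max_token_size; e.g. with a 4000-token
    # budget and chunks of 1800/1500/1200 tokens, only the first two survive.
    # The helper's implementation is not part of this diff.)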

    # entity retrieval
    query_for_entity_retrieval = await _refine_entity_retrieval_query(
        query,
        query_param,
        global_config,
    )
    entity_results = await entities_vdb.query(query_for_entity_retrieval, top_k=query_param.top_k)
    entity_retrieved_segments = set()
    if len(entity_results):
        node_datas = await asyncio.gather(
            *[knowledge_graph_inst.get_node(r["entity_name"]) for r in entity_results]
        )
        if not all([n is not None for n in node_datas]):
            logger.warning("Some nodes are missing, maybe the storage is damaged")
        node_degrees = await asyncio.gather(
            *[knowledge_graph_inst.node_degree(r["entity_name"]) for r in entity_results]
        )
        node_datas = [
            {**n, "entity_name": k["entity_name"], "rank": d}
            for k, n, d in zip(entity_results, node_datas, node_degrees)
            if n is not None
        ]
        entity_retrieved_segments = entity_retrieved_segments.union(
            await _find_most_related_segments_from_entities(
                global_config["retrieval_topk_chunks"], node_datas, text_chunks_db, knowledge_graph_inst
            )
        )

    # visual retrieval
    query_for_visual_retrieval = await _refine_visual_retrieval_query(
        query,
        query_param,
        global_config,
    )
    segment_results = await video_segment_feature_vdb.query(query_for_visual_retrieval)
    visual_retrieved_segments = set()
    if len(segment_results):
        for n in segment_results:
            visual_retrieved_segments.add(n['__id__'])

    # caption
    retrieved_segments = list(entity_retrieved_segments.union(visual_retrieved_segments))
    retrieved_segments = sorted(
        retrieved_segments,
        key=lambda x: (
            '_'.join(x.split('_')[:-1]),  # video_name
            int(x.split('_')[-1]),        # segment index
        )
    )
    print(query_for_entity_retrieval)
    print(f"Retrieved Text Segments {entity_retrieved_segments}")
    print(query_for_visual_retrieval)
    print(f"Retrieved Visual Segments {visual_retrieved_segments}")
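    # (Sort illustration: segment IDs follow "<video_name>_<index>", and video
    # names may themselves contain underscores, so the key rejoins everything
    # before the last "_". ["clip_a_10", "clip_a_2"] therefore sorts to
    # ["clip_a_2", "clip_a_10"]: indices compare numerically, not as strings.)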

    already_processed = 0

    async def _filter_single_segment(knowledge: str, segment_key_dp: tuple[str, str]):
        nonlocal use_model_func, already_processed
        segment_key = segment_key_dp[0]
        segment_content = segment_key_dp[1]
        filter_prompt = PROMPTS["filtering_segment"]
        filter_prompt = filter_prompt.format(caption=segment_content, knowledge=knowledge)
        result = await use_model_func(filter_prompt)
        already_processed += 1
        now_ticks = PROMPTS["process_tickers"][
            already_processed % len(PROMPTS["process_tickers"])
        ]
        print(
            f"{now_ticks} Checked {already_processed} segments\r",
            end="",
            flush=True,
        )
        return (segment_key, result)

    rough_captions = {}
    for s_id in retrieved_segments:
        video_name = '_'.join(s_id.split('_')[:-1])
        index = s_id.split('_')[-1]
        rough_captions[s_id] = video_segments._data[video_name][index]["content"]
    results = await asyncio.gather(
        *[_filter_single_segment(query, (s_id, rough_captions[s_id])) for s_id in rough_captions]
    )
    remain_segments = [x[0] for x in results if 'yes' in x[1].lower()]
    print(f"{len(remain_segments)} Video Segments remain after filtering")
    if len(remain_segments) == 0:
        print("Since no segments remain after filtering, all retrieved segments are used.")
        remain_segments = retrieved_segments
    print(f"Remain segments {remain_segments}")

    # keyword extraction for fine-grained captioning
    keywords_for_caption = await _extract_keywords_query(
        query,
        query_param,
        global_config,
    )
    print(f"Keywords: {keywords_for_caption}")
    caption_results = retrieved_segment_caption(
        caption_model,
        caption_tokenizer,
        keywords_for_caption,
        remain_segments,
        video_path_db,
        video_segments,
        num_sampled_frames=global_config['fine_num_frames_per_segment']
    )

    ## data table
    text_units_section_list = [["video_name", "start_time", "end_time", "content"]]
    for s_id in caption_results:
        video_name = '_'.join(s_id.split('_')[:-1])
        index = s_id.split('_')[-1]
        start_time = int(video_segments._data[video_name][index]["time"].split('-')[0])
        end_time = int(video_segments._data[video_name][index]["time"].split('-')[1])
        # format seconds as H:M:S (not zero-padded, e.g. 65 -> "0:1:5")
        start_time = f"{start_time // 3600}:{(start_time % 3600) // 60}:{start_time % 60}"
        end_time = f"{end_time // 3600}:{(end_time % 3600) // 60}:{end_time % 60}"
        text_units_section_list.append([video_name, start_time, end_time, caption_results[s_id]])
    text_units_context = list_of_list_to_csv(text_units_section_list)

    retrieved_video_context = f"\n-----Retrieved Knowledge From Videos-----\n```csv\n{text_units_context}\n```\n"

    # NOTE: updated to use a dedicated system prompt for multiple-choice questions
    sys_prompt_temp = PROMPTS["videorag_response_for_multiple_choice_question"]

    sys_prompt = sys_prompt_temp.format(
        video_data=retrieved_video_context,
        chunk_data=retrieved_chunk_context,
        response_type=query_param.response_type
    )
    response = await use_model_func(
        query,
        system_prompt=sys_prompt,
        use_cache=False,
    )
    # retry until the model returns valid JSON with the expected keys
    while True:
        try:
            json_response = json.loads(response)
            assert "Answer" in json_response and "Explanation" in json_response
            return json_response
        except Exception as e:
            logger.info(f"Response is not valid JSON for query {query}. Found {e}. Retrying...")
            response = await use_model_func(
                query,
                system_prompt=sys_prompt,
                use_cache=False,
            )
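A note on the loop above: it retries forever if the model keeps returning malformed JSON. A bounded variant is sketched below; `call_model` is a hypothetical stand-in for `use_model_func(query, system_prompt=..., use_cache=False)`.

```python
import asyncio
import json

async def call_model(query: str) -> str:
    # Hypothetical stand-in; a real call would hit the LLM with the system prompt.
    return '{"Answer": "B", "Explanation": "The motor is installed before the wheels."}'

async def query_with_retries(query: str, max_retries: int = 3) -> dict:
    last_error = None
    for _ in range(max_retries):
        response = await call_model(query)
        try:
            parsed = json.loads(response)
            assert "Answer" in parsed and "Explanation" in parsed
            return parsed
        except (json.JSONDecodeError, AssertionError) as e:
            last_error = e  # malformed reply; ask again
    raise RuntimeError(f"No valid JSON after {max_retries} attempts: {last_error}")

print(asyncio.run(query_with_retries("Which option is correct?")))
```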
@@ -400,4 +400,43 @@ Do not include information where the supporting evidence for it is not provided.

---Notice---
Please add sections and commentary as appropriate for the length and format if necessary. Format the response in Markdown.
"""

PROMPTS[
    "videorag_response_for_multiple_choice_question"
] = """---Role---

You are a helpful assistant responding to a multiple-choice question with retrieved knowledge.

---Goal---

Generate a concise response that addresses the user's question by summarizing relevant information derived from the retrieved text and video content. Ensure the response aligns with the specified format and length.
Please note that only one choice is correct.

---Target response length and format---

{response_type}

---Retrieved Information From Videos---

{video_data}

---Retrieved Text Chunks---

{chunk_data}

---Goal---

Generate a concise response that addresses the user's question by summarizing relevant information derived from the retrieved text and video content. Ensure the response aligns with the specified format and length.
Please note that only one choice is correct.

---Notice---
Please provide your answer in JSON format as follows:
{{
    "Answer": "The label of the answer, like A/B/C/D or 1/2/3/4 or others, depending on the given query",
    "Explanation": "Provide explanations for your choice. Use sections and commentary as needed to ensure clarity and depth. Format the response in Markdown."
}}
Key points:
1. Ensure that the "Answer" reflects the correct label format.
2. Structure the "Explanation" for clarity, using Markdown for any necessary formatting.
"""
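Note the doubled braces around the JSON skeleton: this template is later passed through `str.format` (see `sys_prompt_temp.format(...)` above), which treats `{{` and `}}` as literal braces while still substituting the single-brace placeholders. A quick illustration:

```python
template = 'Answer as JSON: {{"Answer": "..."}} given {chunk_data}'
print(template.format(chunk_data="No Content"))
# -> Answer as JSON: {"Answer": "..."} given No Content
```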
@@ -23,6 +23,7 @@ from ._op import (
     extract_entities,
     get_chunks,
     videorag_query,
+    videorag_query_multiple_choice,
 )
 from ._storage import (
     JsonKVStorage,

@@ -319,6 +320,22 @@ class VideoRAG:
                 param,
                 asdict(self),
             )
+        # NOTE: dispatch branch for the new multiple-choice mode
+        elif param.mode == "videorag_multiple_choice":
+            response = await videorag_query_multiple_choice(
+                query,
+                self.entities_vdb,
+                self.text_chunks,
+                self.chunks_vdb,
+                self.video_path_db,
+                self.video_segments,
+                self.video_segment_feature_vdb,
+                self.chunk_entity_relation_graph,
+                self.caption_model,
+                self.caption_tokenizer,
+                param,
+                asdict(self),
+            )
         else:
             raise ValueError(f"Unknown mode {param.mode}")
         await self._query_done()
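With this branch in place, callers select the new mode through `QueryParam`. A minimal usage sketch, assuming `QueryParam` exposes a `mode` field (as the dispatch above implies), that the package exports both classes, and that `VideoRAG` offers the usual synchronous `query` wrapper; the working directory and question are illustrative:

```python
from videorag import QueryParam, VideoRAG  # import path assumed from the repo layout

# Assumes videos were already inserted and indexed into this working directory.
rag = VideoRAG(working_dir="./videorag-workdir")

param = QueryParam(mode="videorag_multiple_choice")
answer = rag.query(
    "What does the presenter assemble first? A. frame B. wheels C. motor D. seat",
    param,
)
# videorag_query_multiple_choice returns the parsed JSON dict.
print(answer["Answer"])
print(answer["Explanation"])
```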