multiple choice question

This commit is contained in:
Re-bin
2025-05-06 20:59:39 +08:00
parent 2d94a37e40
commit ed1a95ee9e
4 changed files with 244 additions and 3 deletions

View File

@@ -101,7 +101,8 @@ async def openai_complete_if_cache(
if hashing_kv is not None:
args_hash = compute_args_hash(model, messages)
if_cache_return = await hashing_kv.get_by_id(args_hash)
if if_cache_return is not None:
# NOTE: I update here to avoid the if_cache_return["return"] is None
if if_cache_return is not None and if_cache_return["return"] is not None:
return if_cache_return["return"]
response = await openai_async_client.chat.completions.create(
@@ -211,7 +212,8 @@ async def azure_openai_complete_if_cache(
if hashing_kv is not None:
args_hash = compute_args_hash(deployment_name, messages)
if_cache_return = await hashing_kv.get_by_id(args_hash)
if if_cache_return is not None:
# NOTE: I update here to avoid the if_cache_return["return"] is None
if if_cache_return is not None and if_cache_return["return"] is not None:
return if_cache_return["return"]
response = await azure_openai_client.chat.completions.create(
@@ -309,7 +311,8 @@ async def ollama_complete_if_cache(
if hashing_kv is not None:
args_hash = compute_args_hash(model, messages)
if_cache_return = await hashing_kv.get_by_id(args_hash)
if if_cache_return is not None:
# NOTE: I update here to avoid the if_cache_return["return"] is None
if if_cache_return is not None and if_cache_return["return"] is not None:
return if_cache_return["return"]
# Send the request to Ollama

View File

@@ -742,3 +742,185 @@ async def videorag_query(
system_prompt=sys_prompt,
)
return response
async def videorag_query_multiple_choice(
query,
entities_vdb,
text_chunks_db,
chunks_vdb,
video_path_db,
video_segments,
video_segment_feature_vdb,
knowledge_graph_inst,
caption_model,
caption_tokenizer,
query_param: QueryParam,
global_config: dict,
) -> str:
"""_summary_
A copy of the videorag_query function with several updates for handling multiple-choice queries.
"""
use_model_func = global_config["llm"]["best_model_func"]
query = query
# naive chunks
results = await chunks_vdb.query(query, top_k=query_param.top_k)
# NOTE: I update here, not len results can also process
if len(results):
chunks_ids = [r["id"] for r in results]
chunks = await text_chunks_db.get_by_ids(chunks_ids)
maybe_trun_chunks = truncate_list_by_token_size(
chunks,
key=lambda x: x["content"],
max_token_size=query_param.naive_max_token_for_text_unit,
)
logger.info(f"Truncate {len(chunks)} to {len(maybe_trun_chunks)} chunks")
section = "-----New Chunk-----\n".join([c["content"] for c in maybe_trun_chunks])
retreived_chunk_context = section
else:
retreived_chunk_context = "No Content"
# visual retrieval
query_for_entity_retrieval = await _refine_entity_retrieval_query(
query,
query_param,
global_config,
)
entity_results = await entities_vdb.query(query_for_entity_retrieval, top_k=query_param.top_k)
entity_retrieved_segments = set()
if len(entity_results):
node_datas = await asyncio.gather(
*[knowledge_graph_inst.get_node(r["entity_name"]) for r in entity_results]
)
if not all([n is not None for n in node_datas]):
logger.warning("Some nodes are missing, maybe the storage is damaged")
node_degrees = await asyncio.gather(
*[knowledge_graph_inst.node_degree(r["entity_name"]) for r in entity_results]
)
node_datas = [
{**n, "entity_name": k["entity_name"], "rank": d}
for k, n, d in zip(entity_results, node_datas, node_degrees)
if n is not None
]
entity_retrieved_segments = entity_retrieved_segments.union(await _find_most_related_segments_from_entities(
global_config["retrieval_topk_chunks"], node_datas, text_chunks_db, knowledge_graph_inst
))
# visual retrieval
query_for_visual_retrieval = await _refine_visual_retrieval_query(
query,
query_param,
global_config,
)
segment_results = await video_segment_feature_vdb.query(query_for_visual_retrieval)
visual_retrieved_segments = set()
if len(segment_results):
for n in segment_results:
visual_retrieved_segments.add(n['__id__'])
# caption
retrieved_segments = list(entity_retrieved_segments.union(visual_retrieved_segments))
retrieved_segments = sorted(
retrieved_segments,
key=lambda x: (
'_'.join(x.split('_')[:-1]), # video_name
eval(x.split('_')[-1]) # index
)
)
print(query_for_entity_retrieval)
print(f"Retrieved Text Segments {entity_retrieved_segments}")
print(query_for_visual_retrieval)
print(f"Retrieved Visual Segments {visual_retrieved_segments}")
already_processed = 0
async def _filter_single_segment(knowledge: str, segment_key_dp: tuple[str, str]):
nonlocal use_model_func, already_processed
segment_key = segment_key_dp[0]
segment_content = segment_key_dp[1]
filter_prompt = PROMPTS["filtering_segment"]
filter_prompt = filter_prompt.format(caption=segment_content, knowledge=knowledge)
result = await use_model_func(filter_prompt)
already_processed += 1
now_ticks = PROMPTS["process_tickers"][
already_processed % len(PROMPTS["process_tickers"])
]
print(
f"{now_ticks} Checked {already_processed} segments\r",
end="",
flush=True,
)
return (segment_key, result)
rough_captions = {}
for s_id in retrieved_segments:
video_name = '_'.join(s_id.split('_')[:-1])
index = s_id.split('_')[-1]
rough_captions[s_id] = video_segments._data[video_name][index]["content"]
results = await asyncio.gather(
*[_filter_single_segment(query, (s_id, rough_captions[s_id])) for s_id in rough_captions]
)
remain_segments = [x[0] for x in results if 'yes' in x[1].lower()]
print(f"{len(remain_segments)} Video Segments remain after filtering")
if len(remain_segments) == 0:
print("Since no segments remain after filtering, we utilized all the retrieved segments.")
remain_segments = retrieved_segments
print(f"Remain segments {remain_segments}")
# visual retrieval
keywords_for_caption = await _extract_keywords_query(
query,
query_param,
global_config,
)
print(f"Keywords: {keywords_for_caption}")
caption_results = retrieved_segment_caption(
caption_model,
caption_tokenizer,
keywords_for_caption,
remain_segments,
video_path_db,
video_segments,
num_sampled_frames=global_config['fine_num_frames_per_segment']
)
## data table
text_units_section_list = [["video_name", "start_time", "end_time", "content"]]
for s_id in caption_results:
video_name = '_'.join(s_id.split('_')[:-1])
index = s_id.split('_')[-1]
start_time = eval(video_segments._data[video_name][index]["time"].split('-')[0])
end_time = eval(video_segments._data[video_name][index]["time"].split('-')[1])
start_time = f"{start_time // 3600}:{(start_time % 3600) // 60}:{start_time % 60}"
end_time = f"{end_time // 3600}:{(end_time % 3600) // 60}:{end_time % 60}"
text_units_section_list.append([video_name, start_time, end_time, caption_results[s_id]])
text_units_context = list_of_list_to_csv(text_units_section_list)
retreived_video_context = f"\n-----Retrieved Knowledge From Videos-----\n```csv\n{text_units_context}\n```\n"
# NOTE: I update here to use a different prompt
sys_prompt_temp = PROMPTS["videorag_response_for_multiple_choice_question"]
sys_prompt = sys_prompt_temp.format(
video_data=retreived_video_context,
chunk_data=retreived_chunk_context,
response_type=query_param.response_type
)
response = await use_model_func(
query,
system_prompt=sys_prompt,
use_cache=False,
)
while True:
try:
json_response = json.loads(response)
assert "Answer" in json_response and "Explanation" in json_response
return json_response
except Exception as e:
logger.info(f"Response is not valid JSON for query {query}. Found {e}. Retrying...")
response = await use_model_func(
query,
system_prompt=sys_prompt,
use_cache=False,
)

View File

@@ -400,4 +400,43 @@ Do not include information where the supporting evidence for it is not provided.
---Notice---
Please add sections and commentary as appropriate for the length and format if necessary. Format the response in Markdown.
"""
PROMPTS[
"videorag_response_for_multiple_choice_question"
] = """---Role---
You are a helpful assistant responding to a multiple-choice question with retrieved knowledge.
---Goal---
Generate a concise response that addresses the user's question by summarizing relevant information derived from the retrieved text and video content. Ensure the response aligns with the specified format and length.
Please note that there is only one choice is correct.
---Target response length and format---
{response_type}
---Retrieved Information From Videos---
{video_data}
---Retrieved Text Chunks---
{chunk_data}
---Goal---
Generate a concise response that addresses the user's question by summarizing relevant information derived from the retrieved text and video content. Ensure the response aligns with the specified format and length.
Please note that there is only one choice is correct.
---Notice---
Please provide your answer in JSON format as follows:
{{
"Answer": "The label of the answer, like A/B/C/D or 1/2/3/4 or others, depending on the given query"
"Explanation": "Provide explanations for your choice. Use sections and commentary as needed to ensure clarity and depth. Format the response in Markdown."
}}
Key points:
1. Ensure that the "Answer" reflects the correct label format.
2. Structure the "Explanation" for clarity, using Markdown for any necessary formatting.
"""

View File

@@ -23,6 +23,7 @@ from ._op import (
extract_entities,
get_chunks,
videorag_query,
videorag_query_multiple_choice,
)
from ._storage import (
JsonKVStorage,
@@ -319,6 +320,22 @@ class VideoRAG:
param,
asdict(self),
)
# NOTE: update here
elif param.mode == "videorag_multiple_choice":
response = await videorag_query_multiple_choice(
query,
self.entities_vdb,
self.text_chunks,
self.chunks_vdb,
self.video_path_db,
self.video_segments,
self.video_segment_feature_vdb,
self.chunk_entity_relation_graph,
self.caption_model,
self.caption_tokenizer,
param,
asdict(self),
)
else:
raise ValueError(f"Unknown mode {param.mode}")
await self._query_done()