diff --git a/videorag/_llm.py b/videorag/_llm.py
index a8ff0df..2b23d1f 100755
--- a/videorag/_llm.py
+++ b/videorag/_llm.py
@@ -101,7 +101,8 @@ async def openai_complete_if_cache(
     if hashing_kv is not None:
         args_hash = compute_args_hash(model, messages)
         if_cache_return = await hashing_kv.get_by_id(args_hash)
-        if if_cache_return is not None:
+        # NOTE: also treat a cached entry whose "return" value is None as a cache miss
+        if if_cache_return is not None and if_cache_return["return"] is not None:
             return if_cache_return["return"]

     response = await openai_async_client.chat.completions.create(
@@ -211,7 +212,8 @@ async def azure_openai_complete_if_cache(
     if hashing_kv is not None:
         args_hash = compute_args_hash(deployment_name, messages)
         if_cache_return = await hashing_kv.get_by_id(args_hash)
-        if if_cache_return is not None:
+        # NOTE: also treat a cached entry whose "return" value is None as a cache miss
+        if if_cache_return is not None and if_cache_return["return"] is not None:
             return if_cache_return["return"]

     response = await azure_openai_client.chat.completions.create(
@@ -309,7 +311,8 @@ async def ollama_complete_if_cache(
     if hashing_kv is not None:
         args_hash = compute_args_hash(model, messages)
         if_cache_return = await hashing_kv.get_by_id(args_hash)
-        if if_cache_return is not None:
+        # NOTE: also treat a cached entry whose "return" value is None as a cache miss
+        if if_cache_return is not None and if_cache_return["return"] is not None:
             return if_cache_return["return"]

     # Send the request to Ollama
diff --git a/videorag/_op.py b/videorag/_op.py
index e752325..d0fd9aa 100755
--- a/videorag/_op.py
+++ b/videorag/_op.py
@@ -742,3 +742,185 @@ async def videorag_query(
         system_prompt=sys_prompt,
     )
     return response
+
+async def videorag_query_multiple_choice(
+    query,
+    entities_vdb,
+    text_chunks_db,
+    chunks_vdb,
+    video_path_db,
+    video_segments,
+    video_segment_feature_vdb,
+    knowledge_graph_inst,
+    caption_model,
+    caption_tokenizer,
+    query_param: QueryParam,
+    global_config: dict,
+) -> dict:
+    """
+    A copy of videorag_query with several updates for handling multiple-choice queries.
+ """ + use_model_func = global_config["llm"]["best_model_func"] + query = query + + # naive chunks + results = await chunks_vdb.query(query, top_k=query_param.top_k) + # NOTE: I update here, not len results can also process + if len(results): + chunks_ids = [r["id"] for r in results] + chunks = await text_chunks_db.get_by_ids(chunks_ids) + + maybe_trun_chunks = truncate_list_by_token_size( + chunks, + key=lambda x: x["content"], + max_token_size=query_param.naive_max_token_for_text_unit, + ) + logger.info(f"Truncate {len(chunks)} to {len(maybe_trun_chunks)} chunks") + section = "-----New Chunk-----\n".join([c["content"] for c in maybe_trun_chunks]) + retreived_chunk_context = section + else: + retreived_chunk_context = "No Content" + + # visual retrieval + query_for_entity_retrieval = await _refine_entity_retrieval_query( + query, + query_param, + global_config, + ) + entity_results = await entities_vdb.query(query_for_entity_retrieval, top_k=query_param.top_k) + entity_retrieved_segments = set() + if len(entity_results): + node_datas = await asyncio.gather( + *[knowledge_graph_inst.get_node(r["entity_name"]) for r in entity_results] + ) + if not all([n is not None for n in node_datas]): + logger.warning("Some nodes are missing, maybe the storage is damaged") + node_degrees = await asyncio.gather( + *[knowledge_graph_inst.node_degree(r["entity_name"]) for r in entity_results] + ) + node_datas = [ + {**n, "entity_name": k["entity_name"], "rank": d} + for k, n, d in zip(entity_results, node_datas, node_degrees) + if n is not None + ] + entity_retrieved_segments = entity_retrieved_segments.union(await _find_most_related_segments_from_entities( + global_config["retrieval_topk_chunks"], node_datas, text_chunks_db, knowledge_graph_inst + )) + + # visual retrieval + query_for_visual_retrieval = await _refine_visual_retrieval_query( + query, + query_param, + global_config, + ) + segment_results = await video_segment_feature_vdb.query(query_for_visual_retrieval) + visual_retrieved_segments = set() + if len(segment_results): + for n in segment_results: + visual_retrieved_segments.add(n['__id__']) + + # caption + retrieved_segments = list(entity_retrieved_segments.union(visual_retrieved_segments)) + retrieved_segments = sorted( + retrieved_segments, + key=lambda x: ( + '_'.join(x.split('_')[:-1]), # video_name + eval(x.split('_')[-1]) # index + ) + ) + print(query_for_entity_retrieval) + print(f"Retrieved Text Segments {entity_retrieved_segments}") + print(query_for_visual_retrieval) + print(f"Retrieved Visual Segments {visual_retrieved_segments}") + + already_processed = 0 + async def _filter_single_segment(knowledge: str, segment_key_dp: tuple[str, str]): + nonlocal use_model_func, already_processed + segment_key = segment_key_dp[0] + segment_content = segment_key_dp[1] + filter_prompt = PROMPTS["filtering_segment"] + filter_prompt = filter_prompt.format(caption=segment_content, knowledge=knowledge) + result = await use_model_func(filter_prompt) + already_processed += 1 + now_ticks = PROMPTS["process_tickers"][ + already_processed % len(PROMPTS["process_tickers"]) + ] + print( + f"{now_ticks} Checked {already_processed} segments\r", + end="", + flush=True, + ) + return (segment_key, result) + + rough_captions = {} + for s_id in retrieved_segments: + video_name = '_'.join(s_id.split('_')[:-1]) + index = s_id.split('_')[-1] + rough_captions[s_id] = video_segments._data[video_name][index]["content"] + results = await asyncio.gather( + *[_filter_single_segment(query, (s_id, rough_captions[s_id])) for 
+    )
+    remain_segments = [x[0] for x in results if 'yes' in x[1].lower()]
+    print(f"{len(remain_segments)} Video Segments remain after filtering")
+    if len(remain_segments) == 0:
+        print("Since no segments remain after filtering, we utilize all the retrieved segments.")
+        remain_segments = retrieved_segments
+    print(f"Remaining segments {remain_segments}")
+
+    # keyword extraction for fine-grained captioning
+    keywords_for_caption = await _extract_keywords_query(
+        query,
+        query_param,
+        global_config,
+    )
+    print(f"Keywords: {keywords_for_caption}")
+    caption_results = retrieved_segment_caption(
+        caption_model,
+        caption_tokenizer,
+        keywords_for_caption,
+        remain_segments,
+        video_path_db,
+        video_segments,
+        num_sampled_frames=global_config['fine_num_frames_per_segment']
+    )
+
+    ## data table
+    text_units_section_list = [["video_name", "start_time", "end_time", "content"]]
+    for s_id in caption_results:
+        video_name = '_'.join(s_id.split('_')[:-1])
+        index = s_id.split('_')[-1]
+        start_time = int(video_segments._data[video_name][index]["time"].split('-')[0])
+        end_time = int(video_segments._data[video_name][index]["time"].split('-')[1])
+        start_time = f"{start_time // 3600}:{(start_time % 3600) // 60:02d}:{start_time % 60:02d}"
+        end_time = f"{end_time // 3600}:{(end_time % 3600) // 60:02d}:{end_time % 60:02d}"
+        text_units_section_list.append([video_name, start_time, end_time, caption_results[s_id]])
+    text_units_context = list_of_list_to_csv(text_units_section_list)
+
+    retrieved_video_context = f"\n-----Retrieved Knowledge From Videos-----\n```csv\n{text_units_context}\n```\n"
+
+    # NOTE: use a dedicated prompt for multiple-choice questions
+    sys_prompt_temp = PROMPTS["videorag_response_for_multiple_choice_question"]
+
+    sys_prompt = sys_prompt_temp.format(
+        video_data=retrieved_video_context,
+        chunk_data=retrieved_chunk_context,
+        response_type=query_param.response_type
+    )
+    response = await use_model_func(
+        query,
+        system_prompt=sys_prompt,
+        use_cache=False,
+    )
+    while True:
+        try:
+            json_response = json.loads(response)
+            assert "Answer" in json_response and "Explanation" in json_response
+            return json_response
+        except Exception as e:
+            logger.info(f"Response is not valid JSON for query {query}. Found {e}. Retrying...")
+            response = await use_model_func(
+                query,
+                system_prompt=sys_prompt,
+                use_cache=False,
+            )
\ No newline at end of file
diff --git a/videorag/prompt.py b/videorag/prompt.py
index f65136a..283c440 100755
--- a/videorag/prompt.py
+++ b/videorag/prompt.py
@@ -400,4 +400,43 @@ Do not include information where the supporting evidence for it is not provided.

 ---Notice---
 Please add sections and commentary as appropriate for the length and format if necessary. Format the response in Markdown.
+"""
+
+PROMPTS[
+    "videorag_response_for_multiple_choice_question"
+] = """---Role---
+
+You are a helpful assistant responding to a multiple-choice question with retrieved knowledge.
+
+---Goal---
+
+Generate a concise response that addresses the user's question by summarizing relevant information derived from the retrieved text and video content. Ensure the response aligns with the specified format and length.
+Please note that only one choice is correct.
+
+---Target response length and format---
+
+{response_type}
+
+---Retrieved Information From Videos---
+
+{video_data}
+
+---Retrieved Text Chunks---
+
+{chunk_data}
+
+---Goal---
+
+Generate a concise response that addresses the user's question by summarizing relevant information derived from the retrieved text and video content. Ensure the response aligns with the specified format and length.
+Please note that only one choice is correct.
+
+---Notice---
+Please provide your answer in JSON format as follows:
+{{
+    "Answer": "The label of the answer, such as A/B/C/D or 1/2/3/4, depending on the given query",
+    "Explanation": "Provide the reasoning for your choice. Use sections and commentary as needed to ensure clarity and depth. Format the response in Markdown."
+}}
+Key points:
+1. Ensure that the "Answer" matches the label format used in the question.
+2. Structure the "Explanation" for clarity, using Markdown for any necessary formatting.
 """
\ No newline at end of file
diff --git a/videorag/videorag.py b/videorag/videorag.py
index 72da296..0ceb6cd 100755
--- a/videorag/videorag.py
+++ b/videorag/videorag.py
@@ -23,6 +23,7 @@ from ._op import (
     extract_entities,
     get_chunks,
     videorag_query,
+    videorag_query_multiple_choice,
 )
 from ._storage import (
     JsonKVStorage,
@@ -319,6 +320,22 @@ class VideoRAG:
                 param,
                 asdict(self),
             )
+        # NOTE: new branch dispatching multiple-choice queries
+        elif param.mode == "videorag_multiple_choice":
+            response = await videorag_query_multiple_choice(
+                query,
+                self.entities_vdb,
+                self.text_chunks,
+                self.chunks_vdb,
+                self.video_path_db,
+                self.video_segments,
+                self.video_segment_feature_vdb,
+                self.chunk_entity_relation_graph,
+                self.caption_model,
+                self.caption_tokenizer,
+                param,
+                asdict(self),
+            )
         else:
             raise ValueError(f"Unknown mode {param.mode}")
         await self._query_done()
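
For completeness, here is a minimal usage sketch of the new mode (not part of the patch). It assumes `VideoRAG` accepts a `working_dir` argument and exposes a synchronous `query()` wrapper around the async dispatch shown above; adjust those names to match your checkout. The mode string `videorag_multiple_choice` and the returned `Answer`/`Explanation` keys come directly from this diff.

```python
# Minimal sketch, assuming VideoRAG(working_dir=...) and a sync query() wrapper;
# only the mode string and result keys are guaranteed by this patch.
from videorag import VideoRAG, QueryParam

rag = VideoRAG(working_dir="./videorag-workdir")  # hypothetical setup path

question = (
    "Which activity appears first in the video?\n"
    "A. Cooking  B. Driving  C. Swimming  D. Reading"
)
param = QueryParam(mode="videorag_multiple_choice")  # mode added by this patch

# videorag_query_multiple_choice returns the parsed JSON as a dict with
# "Answer" and "Explanation" keys (enforced by the retry loop in _op.py).
result = rag.query(question, param)
print(result["Answer"])       # e.g. "B"
print(result["Explanation"])  # Markdown-formatted rationale
```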