"""
Query functionality for RAGAnything
Contains all query-related methods for both text and multimodal queries
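Example (illustrative; assumes a configured RAGAnything instance named rag):
    result = await rag.aquery("What is machine learning?", mode="mix")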
"""
import json
import hashlib
import re
from typing import Dict, List, Any, Optional
from pathlib import Path
from lightrag import QueryParam
from lightrag.utils import always_get_an_event_loop
from raganything.prompt import PROMPTS
from raganything.utils import (
get_processor_for_type,
encode_image_to_base64,
validate_image_file,
)
class QueryMixin:
"""QueryMixin class containing query functionality for RAGAnything"""
def _generate_multimodal_cache_key(
self, query: str, multimodal_content: List[Dict[str, Any]], mode: str, **kwargs
) -> str:
"""
Generate cache key for multimodal query
Args:
query: Base query text
multimodal_content: List of multimodal content
mode: Query mode
**kwargs: Additional parameters
Returns:
str: Cache key hash
"""
# Create a normalized representation of the query parameters
cache_data = {
"query": query.strip(),
"mode": mode,
}
# Normalize multimodal content for stable caching
normalized_content = []
if multimodal_content:
for item in multimodal_content:
if isinstance(item, dict):
normalized_item = {}
for key, value in item.items():
# For file paths, use basename to make cache more portable
if key in [
"img_path",
"image_path",
"file_path",
] and isinstance(value, str):
normalized_item[key] = Path(value).name
# For large content, create a hash instead of storing directly
elif (
key in ["table_data", "table_body"]
and isinstance(value, str)
and len(value) > 200
):
normalized_item[f"{key}_hash"] = hashlib.md5(
value.encode()
).hexdigest()
else:
normalized_item[key] = value
normalized_content.append(normalized_item)
else:
normalized_content.append(item)
cache_data["multimodal_content"] = normalized_content
# Add relevant kwargs to cache data
relevant_kwargs = {
k: v
for k, v in kwargs.items()
if k
in [
"stream",
"response_type",
"top_k",
"max_tokens",
"temperature",
# "only_need_context",
# "only_need_prompt",
]
}
cache_data.update(relevant_kwargs)
# Generate hash from the cache data
cache_str = json.dumps(cache_data, sort_keys=True, ensure_ascii=False)
cache_hash = hashlib.md5(cache_str.encode()).hexdigest()
return f"multimodal_query:{cache_hash}"
async def aquery(self, query: str, mode: str = "mix", **kwargs) -> str:
"""
Pure text query - directly calls LightRAG's query functionality
Args:
query: Query text
mode: Query mode ("local", "global", "hybrid", "naive", "mix", "bypass")
**kwargs: Other query parameters; passed through to QueryParam
- vlm_enhanced: bool, defaults to True when vision_model_func is available.
If True, image paths found in the retrieved context are replaced with
base64-encoded images for VLM processing.
Returns:
str: Query result
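Example (illustrative):
    result = await rag.aquery("What is machine learning?", mode="hybrid")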
"""
if self.lightrag is None:
raise ValueError(
"No LightRAG instance available. Please process documents first or provide a pre-initialized LightRAG instance."
)
# Check if VLM enhanced query should be used
vlm_enhanced = kwargs.pop("vlm_enhanced", None)
# Auto-determine VLM enhanced based on availability
if vlm_enhanced is None:
vlm_enhanced = (
hasattr(self, "vision_model_func")
and self.vision_model_func is not None
)
# Use VLM enhanced query if enabled and available
if (
vlm_enhanced
and hasattr(self, "vision_model_func")
and self.vision_model_func
):
return await self.aquery_vlm_enhanced(query, mode=mode, **kwargs)
elif vlm_enhanced and (
not hasattr(self, "vision_model_func") or not self.vision_model_func
):
self.logger.warning(
"VLM enhanced query requested but vision_model_func is not available, falling back to normal query"
)
# Create query parameters
query_param = QueryParam(mode=mode, **kwargs)
self.logger.info(f"Executing text query: {query[:100]}...")
self.logger.info(f"Query mode: {mode}")
# Call LightRAG's query method
result = await self.lightrag.aquery(query, param=query_param)
self.logger.info("Text query completed")
return result
async def aquery_with_multimodal(
self,
query: str,
multimodal_content: Optional[List[Dict[str, Any]]] = None,
mode: str = "mix",
**kwargs,
) -> str:
"""
Multimodal query - combines text and multimodal content for querying
Args:
query: Base query text
multimodal_content: List of multimodal content, each element contains:
- type: Content type ("image", "table", "equation", etc.)
- Other fields depend on type (e.g., img_path, table_data, latex, etc.)
mode: Query mode ("local", "global", "hybrid", "naive", "mix", "bypass")
**kwargs: Other query parameters; passed through to QueryParam
Returns:
str: Query result
Examples:
# Pure text query
result = await rag.aquery_with_multimodal("What is machine learning?")
# Image query
result = await rag.aquery_with_multimodal(
"Analyze the content in this image",
multimodal_content=[{
"type": "image",
"img_path": "./image.jpg"
}]
)
# Table query
result = await rag.aquery_with_multimodal(
"Analyze the data trends in this table",
multimodal_content=[{
"type": "table",
"table_data": "Name,Age\nAlice,25\nBob,30"
}]
)
"""
# Ensure LightRAG is initialized
await self._ensure_lightrag_initialized()
self.logger.info(f"Executing multimodal query: {query[:100]}...")
self.logger.info(f"Query mode: {mode}")
# If no multimodal content, fall back to pure text query
if not multimodal_content:
self.logger.info("No multimodal content provided, executing text query")
return await self.aquery(query, mode=mode, **kwargs)
# Generate cache key for multimodal query
cache_key = self._generate_multimodal_cache_key(
query, multimodal_content, mode, **kwargs
)
# Check cache if available and enabled
cached_result = None
if (
hasattr(self, "lightrag")
and self.lightrag
and hasattr(self.lightrag, "llm_response_cache")
and self.lightrag.llm_response_cache
):
if self.lightrag.llm_response_cache.global_config.get(
"enable_llm_cache", True
):
try:
cached_result = await self.lightrag.llm_response_cache.get_by_id(
cache_key
)
if cached_result and isinstance(cached_result, dict):
result_content = cached_result.get("return")
if result_content:
self.logger.info(
f"Multimodal query cache hit: {cache_key[:16]}..."
)
return result_content
except Exception as e:
self.logger.debug(f"Error accessing multimodal query cache: {e}")
# Process multimodal content to generate enhanced query text
enhanced_query = await self._process_multimodal_query_content(
query, multimodal_content
)
self.logger.info(
f"Generated enhanced query length: {len(enhanced_query)} characters"
)
# Execute enhanced query
result = await self.aquery(enhanced_query, mode=mode, **kwargs)
# Save to cache if available and enabled
if (
hasattr(self, "lightrag")
and self.lightrag
and hasattr(self.lightrag, "llm_response_cache")
and self.lightrag.llm_response_cache
):
if self.lightrag.llm_response_cache.global_config.get(
"enable_llm_cache", True
):
try:
# Create cache entry for multimodal query
cache_entry = {
"return": result,
"cache_type": "multimodal_query",
"original_query": query,
"multimodal_content_count": len(multimodal_content),
"mode": mode,
}
await self.lightrag.llm_response_cache.upsert(
{cache_key: cache_entry}
)
self.logger.info(
f"Saved multimodal query result to cache: {cache_key[:16]}..."
)
except Exception as e:
self.logger.debug(f"Error saving multimodal query to cache: {e}")
# Ensure cache is persisted to disk
if (
hasattr(self, "lightrag")
and self.lightrag
and hasattr(self.lightrag, "llm_response_cache")
and self.lightrag.llm_response_cache
):
try:
await self.lightrag.llm_response_cache.index_done_callback()
except Exception as e:
self.logger.debug(f"Error persisting multimodal query cache: {e}")
self.logger.info("Multimodal query completed")
return result
async def aquery_vlm_enhanced(self, query: str, mode: str = "mix", **kwargs) -> str:
"""
VLM enhanced query - replaces image paths in retrieved context with base64 encoded images for VLM processing
Args:
query: User query
mode: Underlying LightRAG query mode
**kwargs: Other query parameters
Returns:
str: VLM query result
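Example (illustrative):
    answer = await rag.aquery_vlm_enhanced("Describe the figures related to revenue", mode="mix")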
"""
# Ensure VLM is available
if not hasattr(self, "vision_model_func") or not self.vision_model_func:
raise ValueError(
"VLM enhanced query requires vision_model_func. "
"Please provide a vision model function when initializing RAGAnything."
)
# Ensure LightRAG is initialized
await self._ensure_lightrag_initialized()
self.logger.info(f"Executing VLM enhanced query: {query[:100]}...")
# Clear previous image cache
if hasattr(self, "_current_images_base64"):
delattr(self, "_current_images_base64")
# 1. Get original retrieval prompt (without generating final answer)
query_param = QueryParam(mode=mode, only_need_prompt=True, **kwargs)
raw_prompt = await self.lightrag.aquery(query, param=query_param)
self.logger.debug("Retrieved raw prompt from LightRAG")
# 2. Extract and process image paths
enhanced_prompt, images_found = await self._process_image_paths_for_vlm(
raw_prompt
)
if not images_found:
self.logger.info("No valid images found, falling back to normal query")
# Fall back to normal query
query_param = QueryParam(mode=mode, **kwargs)
return await self.lightrag.aquery(query, param=query_param)
self.logger.info(f"Processed {images_found} images for VLM")
# 3. Build VLM message format
messages = self._build_vlm_messages_with_images(enhanced_prompt, query)
# 4. Call VLM for question answering
result = await self._call_vlm_with_multimodal_content(messages)
self.logger.info("VLM enhanced query completed")
return result
async def _process_multimodal_query_content(
self, base_query: str, multimodal_content: List[Dict[str, Any]]
) -> str:
"""
Process multimodal query content to generate enhanced query text
Args:
base_query: Base query text
multimodal_content: List of multimodal content
Returns:
str: Enhanced query text
"""
self.logger.info("Starting multimodal query content processing...")
enhanced_parts = [f"User query: {base_query}"]
for i, content in enumerate(multimodal_content):
content_type = content.get("type", "unknown")
self.logger.info(
f"Processing multimodal content {i + 1}/{len(multimodal_content)}: {content_type}"
)
try:
# Get appropriate processor
processor = get_processor_for_type(self.modal_processors, content_type)
if processor:
# Generate content description
description = await self._generate_query_content_description(
processor, content, content_type
)
enhanced_parts.append(
f"\nRelated {content_type} content: {description}"
)
else:
# If no appropriate processor, use basic description
basic_desc = str(content)[:200]
enhanced_parts.append(
f"\nRelated {content_type} content: {basic_desc}"
)
except Exception as e:
self.logger.error(f"Error processing multimodal content: {str(e)}")
# Continue processing other content
continue
enhanced_query = "\n".join(enhanced_parts)
enhanced_query += PROMPTS["QUERY_ENHANCEMENT_SUFFIX"]
self.logger.info("Multimodal query content processing completed")
return enhanced_query
async def _generate_query_content_description(
self, processor, content: Dict[str, Any], content_type: str
) -> str:
"""
Generate content description for query
Args:
processor: Multimodal processor
content: Content data
content_type: Content type
Returns:
str: Content description
"""
try:
if content_type == "image":
return await self._describe_image_for_query(processor, content)
elif content_type == "table":
return await self._describe_table_for_query(processor, content)
elif content_type == "equation":
return await self._describe_equation_for_query(processor, content)
else:
return await self._describe_generic_for_query(
processor, content, content_type
)
except Exception as e:
self.logger.error(f"Error generating {content_type} description: {str(e)}")
return f"{content_type} content: {str(content)[:100]}"
async def _describe_image_for_query(
self, processor, content: Dict[str, Any]
) -> str:
"""Generate image description for query"""
image_path = content.get("img_path")
captions = content.get("img_caption", [])
footnotes = content.get("img_footnote", [])
if image_path and Path(image_path).exists():
# If image exists, use vision model to generate description
image_base64 = processor._encode_image_to_base64(image_path)
if image_base64:
prompt = PROMPTS["QUERY_IMAGE_DESCRIPTION"]
description = await processor.modal_caption_func(
prompt,
image_data=image_base64,
system_prompt=PROMPTS["QUERY_IMAGE_ANALYST_SYSTEM"],
)
return description
# If image doesn't exist or processing failed, use existing information
parts = []
if image_path:
parts.append(f"Image path: {image_path}")
if captions:
parts.append(f"Image captions: {', '.join(captions)}")
if footnotes:
parts.append(f"Image footnotes: {', '.join(footnotes)}")
return "; ".join(parts) if parts else "Image content information incomplete"
async def _describe_table_for_query(
self, processor, content: Dict[str, Any]
) -> str:
"""Generate table description for query"""
table_data = content.get("table_data", "")
table_caption = content.get("table_caption", "")
prompt = PROMPTS["QUERY_TABLE_ANALYSIS"].format(
table_data=table_data, table_caption=table_caption
)
description = await processor.modal_caption_func(
prompt, system_prompt=PROMPTS["QUERY_TABLE_ANALYST_SYSTEM"]
)
return description
async def _describe_equation_for_query(
self, processor, content: Dict[str, Any]
) -> str:
"""Generate equation description for query"""
latex = content.get("latex", "")
equation_caption = content.get("equation_caption", "")
prompt = PROMPTS["QUERY_EQUATION_ANALYSIS"].format(
latex=latex, equation_caption=equation_caption
)
description = await processor.modal_caption_func(
prompt, system_prompt=PROMPTS["QUERY_EQUATION_ANALYST_SYSTEM"]
)
return description
async def _describe_generic_for_query(
self, processor, content: Dict[str, Any], content_type: str
) -> str:
"""Generate generic content description for query"""
content_str = str(content)
prompt = PROMPTS["QUERY_GENERIC_ANALYSIS"].format(
content_type=content_type, content_str=content_str
)
description = await processor.modal_caption_func(
prompt,
system_prompt=PROMPTS["QUERY_GENERIC_ANALYST_SYSTEM"].format(
content_type=content_type
),
)
return description
async def _process_image_paths_for_vlm(self, prompt: str) -> tuple[str, int]:
"""
Process image paths in prompt, keeping original paths and adding VLM markers
Args:
prompt: Original prompt
Returns:
tuple: (processed prompt, image count)
"""
enhanced_prompt = prompt
images_processed = 0
# Initialize image cache
self._current_images_base64 = []
# Enhanced regex pattern for matching image paths
# Matches only paths that end with a known image file extension
image_path_pattern = (
r"Image Path:\s*([^\r\n]*?\.(?:jpg|jpeg|png|gif|bmp|webp|tiff|tif))"
)
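# Illustrative: matches a retrieved-context line such as "Image Path: ./images/figure_1.png"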
# Log how many image path matches the pattern finds before substitution
matches = re.findall(image_path_pattern, prompt)
self.logger.info(f"Found {len(matches)} image path matches in prompt")
def replace_image_path(match):
nonlocal images_processed
image_path = match.group(1).strip()
self.logger.debug(f"Processing image path: '{image_path}'")
# Validate path format (basic check)
if not image_path or len(image_path) < 3:
self.logger.warning(f"Invalid image path format: {image_path}")
return match.group(0) # Keep original
# Use utility function to validate image file
self.logger.debug(f"Calling validate_image_file for: {image_path}")
is_valid = validate_image_file(image_path)
self.logger.debug(f"Validation result for {image_path}: {is_valid}")
if not is_valid:
self.logger.warning(f"Image validation failed for: {image_path}")
return match.group(0) # Keep original if validation fails
try:
# Encode image to base64 using utility function
self.logger.debug(f"Attempting to encode image: {image_path}")
image_base64 = encode_image_to_base64(image_path)
if image_base64:
images_processed += 1
# Save base64 to instance variable for later use
self._current_images_base64.append(image_base64)
# Keep original path info and add VLM marker
result = f"Image Path: {image_path}\n[VLM_IMAGE_{images_processed}]"
self.logger.debug(
f"Successfully processed image {images_processed}: {image_path}"
)
return result
else:
self.logger.error(f"Failed to encode image: {image_path}")
return match.group(0) # Keep original if encoding failed
except Exception as e:
self.logger.error(f"Failed to process image {image_path}: {e}")
return match.group(0) # Keep original
# Execute replacement
enhanced_prompt = re.sub(
image_path_pattern, replace_image_path, enhanced_prompt
)
return enhanced_prompt, images_processed
def _build_vlm_messages_with_images(
self, enhanced_prompt: str, user_query: str
) -> List[Dict]:
"""
Build VLM messages, using markers to align images with their positions in the text
Args:
enhanced_prompt: Enhanced prompt with image markers
user_query: User query
Returns:
List[Dict]: VLM message format
"""
images_base64 = getattr(self, "_current_images_base64", [])
if not images_base64:
# Pure text mode
return [
{
"role": "user",
"content": f"Context:\n{enhanced_prompt}\n\nUser Question: {user_query}",
}
]
# Build multimodal content
content_parts = []
# Split text at image markers and insert images
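# Each marker was inserted as "[VLM_IMAGE_<n>]" by _process_image_paths_for_vlm,
# so after splitting, every part except the first begins with "<n>]" (e.g. "1]\n...").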
text_parts = enhanced_prompt.split("[VLM_IMAGE_")
for i, text_part in enumerate(text_parts):
if i == 0:
# First text part
if text_part.strip():
content_parts.append({"type": "text", "text": text_part})
else:
# Find marker number and insert corresponding image
marker_match = re.match(r"(\d+)\](.*)", text_part, re.DOTALL)
if marker_match:
image_num = (
int(marker_match.group(1)) - 1
) # Convert to 0-based index
remaining_text = marker_match.group(2)
# Insert corresponding image
if 0 <= image_num < len(images_base64):
content_parts.append(
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{images_base64[image_num]}"
},
}
)
# Insert remaining text
if remaining_text.strip():
content_parts.append({"type": "text", "text": remaining_text})
# Add user question
content_parts.append(
{
"type": "text",
"text": f"\n\nUser Question: {user_query}\n\nPlease answer based on the context and images provided.",
}
)
return [
{
"role": "system",
"content": "You are a helpful assistant that can analyze both text and image content to provide comprehensive answers.",
},
{"role": "user", "content": content_parts},
]
async def _call_vlm_with_multimodal_content(self, messages: List[Dict]) -> str:
"""
Call VLM to process multimodal content
Args:
messages: VLM message format
Returns:
str: VLM response result
"""
try:
user_message = messages[1]
content = user_message["content"]
system_prompt = messages[0]["content"]
if isinstance(content, str):
# Pure text mode
result = await self.vision_model_func(
content, system_prompt=system_prompt
)
else:
# Multimodal mode - pass complete messages directly to VLM
result = await self.vision_model_func(
"", # Empty prompt since we're using messages format
messages=messages,
)
return result
except Exception as e:
self.logger.error(f"VLM call failed: {e}")
raise
# Synchronous versions of query methods
def query(self, query: str, mode: str = "mix", **kwargs) -> str:
"""
Synchronous version of pure text query
Args:
query: Query text
mode: Query mode ("local", "global", "hybrid", "naive", "mix", "bypass")
**kwargs: Other query parameters; passed through to QueryParam
- vlm_enhanced: bool, defaults to True when vision_model_func is available.
If True, image paths found in the retrieved context are replaced with
base64-encoded images for VLM processing.
Returns:
str: Query result
"""
loop = always_get_an_event_loop()
return loop.run_until_complete(self.aquery(query, mode=mode, **kwargs))
def query_with_multimodal(
self,
query: str,
multimodal_content: Optional[List[Dict[str, Any]]] = None,
mode: str = "mix",
**kwargs,
) -> str:
"""
Synchronous version of multimodal query
Args:
query: Base query text
multimodal_content: List of multimodal content, each element contains:
- type: Content type ("image", "table", "equation", etc.)
- Other fields depend on type (e.g., img_path, table_data, latex, etc.)
mode: Query mode ("local", "global", "hybrid", "naive", "mix", "bypass")
**kwargs: Other query parameters; passed through to QueryParam
Returns:
str: Query result
"""
loop = always_get_an_event_loop()
return loop.run_until_complete(
self.aquery_with_multimodal(query, multimodal_content, mode=mode, **kwargs)
)