Merge pull request #20 from yashshah035/main

correcting imports to minirag and missing files
This commit is contained in:
Tianyu Fan
2025-02-12 14:12:39 +08:00
committed by GitHub
9 changed files with 122 additions and 14 deletions

View File

@@ -6,7 +6,7 @@ import os
from minirag import MiniRAG, QueryParam
from minirag.llm import (
hf_model_complete,
hf_embedding,
hf_embed,
)
from minirag.utils import EmbeddingFunc
from transformers import AutoModel, AutoTokenizer
@@ -63,7 +63,7 @@ rag = MiniRAG(
embedding_func=EmbeddingFunc(
embedding_dim=384,
max_token_size=1000,
func=lambda texts: hf_embedding(
func=lambda texts: hf_embed(
texts,
tokenizer=AutoTokenizer.from_pretrained(EMBEDDING_MODEL),
embed_model=AutoModel.from_pretrained(EMBEDDING_MODEL),

View File

@@ -1,8 +1,8 @@
from dataclasses import dataclass, field
from typing import TypedDict, Union, Literal, Generic, TypeVar
from enum import Enum
from typing import Any, TypedDict, Optional, Union, Literal, Generic, TypeVar
import os
import numpy as np
from .utils import EmbeddingFunc
TextChunkSchema = TypedDict(
@@ -138,3 +138,52 @@ class BaseGraphStorage(StorageNameSpace):
async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
    """Reject a node-embedding request.

    The ``algorithm`` argument is never inspected; the call fails
    unconditionally.

    Raises:
        NotImplementedError: always.
    """
    reason = "Node embedding is not used in minirag."
    raise NotImplementedError(reason)
class DocStatus(str, Enum):
    """States a document can be in while it is being processed.

    Members are also ``str`` instances, so each one compares equal to its
    lowercase string value.
    """

    PENDING = "pending"
    PROCESSING = "processing"
    PROCESSED = "processed"
    FAILED = "failed"
@dataclass
class DocProcessingStatus:
    """Record describing one document and its processing status.

    The first six fields are required; ``chunks_count``, ``error`` and
    ``metadata`` are optional and default to ``None``, ``None`` and an
    empty dict respectively.
    """

    # Original content of the document.
    content: str
    # First 100 chars of document content, used for preview.
    content_summary: str
    # Total length of document.
    content_length: int
    # Current processing status.
    status: DocStatus
    # ISO format timestamp when document was created.
    created_at: str
    # ISO format timestamp when document was last updated.
    updated_at: str
    # Number of chunks after splitting, used for processing.
    chunks_count: Optional[int] = None
    # Error message if failed.
    error: Optional[str] = None
    # Additional metadata; the factory gives every instance its own dict.
    metadata: dict[str, Any] = field(default_factory=dict)
class DocStatusStorage(BaseKVStorage):
    """Key-value storage interface for document processing status.

    Every query method defined here raises ``NotImplementedError``, so a
    subclass has to override them to be usable.
    """

    async def get_status_counts(self) -> dict[str, int]:
        """Return the number of documents in each status.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError()

    async def get_failed_docs(self) -> dict[str, DocProcessingStatus]:
        """Return all failed documents.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError()

    async def get_pending_docs(self) -> dict[str, DocProcessingStatus]:
        """Return all pending documents.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError()

58
minirag/exceptions.py Normal file
View File

@@ -0,0 +1,58 @@
import httpx
from typing import Literal
class APIStatusError(Exception):
    """Raised when an API response has a status code of 4xx or 5xx.

    Attributes:
        response: The response object passed to the constructor.
        status_code: Copied from ``response.status_code``.
        request_id: Value of the ``x-request-id`` response header, or
            ``None`` when the header is absent.
        request: Copied from ``response.request``.
        body: The ``body`` argument, stored unchanged.
        message: The ``message`` argument; also what ``str(error)`` yields.
    """

    response: httpx.Response
    status_code: int
    request_id: str | None
    request: httpx.Request
    body: object | None
    message: str

    def __init__(
        self, message: str, *, response: httpx.Response, body: object | None
    ) -> None:
        """Build the error from a message, the failing response and its body.

        Args:
            message: Human-readable description of the failure.
            response: Response whose status code and headers are recorded.
            body: Payload associated with the failure, or ``None``.
        """
        # Exception.__init__ accepts no keyword arguments, so forwarding
        # ``body=`` to it raised TypeError before the error could even be
        # constructed. Pass only the message and keep the rest as attributes.
        super().__init__(message)
        self.message = message
        self.request = response.request
        self.body = body
        self.response = response
        self.status_code = response.status_code
        self.request_id = response.headers.get("x-request-id")
class APIConnectionError(Exception):
    """Error raised for a request that failed at the connection level.

    Attributes:
        message: The error message; also what ``str(error)`` yields.
        request: The ``request`` argument, stored unchanged.
        body: Always ``None`` for this error type.
    """

    def __init__(
        self, *, message: str = "Connection error.", request: httpx.Request
    ) -> None:
        """Build the error.

        Args:
            message: Description of the failure; defaults to
                ``"Connection error."``.
            request: The request associated with the failure.
        """
        # Exception.__init__ accepts no keyword arguments, so the previous
        # ``body=None`` keyword made every construction raise TypeError.
        super().__init__(message)
        self.message = message
        self.request = request
        self.body = None
class BadRequestError(APIStatusError):
    """API status error whose class-level status code is 400."""

    # Narrows the inherited ``status_code`` annotation to the literal 400.
    status_code: Literal[400] = 400  # pyright: ignore[reportIncompatibleVariableOverride]
class AuthenticationError(APIStatusError):
    """API status error whose class-level status code is 401."""

    # Narrows the inherited ``status_code`` annotation to the literal 401.
    status_code: Literal[401] = 401  # pyright: ignore[reportIncompatibleVariableOverride]
class PermissionDeniedError(APIStatusError):
    """API status error whose class-level status code is 403."""

    # Narrows the inherited ``status_code`` annotation to the literal 403.
    status_code: Literal[403] = 403  # pyright: ignore[reportIncompatibleVariableOverride]
class NotFoundError(APIStatusError):
    """API status error whose class-level status code is 404."""

    # Narrows the inherited ``status_code`` annotation to the literal 404.
    status_code: Literal[404] = 404  # pyright: ignore[reportIncompatibleVariableOverride]
class ConflictError(APIStatusError):
    """API status error whose class-level status code is 409."""

    # Narrows the inherited ``status_code`` annotation to the literal 409.
    status_code: Literal[409] = 409  # pyright: ignore[reportIncompatibleVariableOverride]
class UnprocessableEntityError(APIStatusError):
    """API status error whose class-level status code is 422."""

    # Narrows the inherited ``status_code`` annotation to the literal 422.
    status_code: Literal[422] = 422  # pyright: ignore[reportIncompatibleVariableOverride]
class RateLimitError(APIStatusError):
    """API status error whose class-level status code is 429."""

    # Narrows the inherited ``status_code`` annotation to the literal 429.
    status_code: Literal[429] = 429  # pyright: ignore[reportIncompatibleVariableOverride]
class APITimeoutError(APIConnectionError):
    """Connection error that always uses the message "Request timed out."."""

    def __init__(self, request: httpx.Request) -> None:
        """Forward ``request`` to the parent together with the fixed message."""
        timeout_message = "Request timed out."
        super().__init__(message=timeout_message, request=request)

View File

@@ -52,13 +52,13 @@ import asyncio
import os
from dataclasses import dataclass
from lightrag.utils import (
from minirag.utils import (
logger,
load_json,
write_json,
)
from lightrag.base import (
from minirag.base import (
BaseKVStorage,
)

View File

@@ -52,13 +52,13 @@ import os
from dataclasses import dataclass
from typing import Union, Dict
from lightrag.utils import (
from minirag.utils import (
logger,
load_json,
write_json,
)
from lightrag.base import (
from minirag.base import (
DocStatus,
DocProcessingStatus,
DocStatusStorage,

View File

@@ -61,12 +61,12 @@ if not pm.is_installed("nano-vectordb"):
from nano_vectordb import NanoVectorDB
import time
from lightrag.utils import (
from minirag.utils import (
logger,
compute_mdhash_id,
)
from lightrag.base import (
from minirag.base import (
BaseVectorStorage,
)

View File

@@ -56,11 +56,11 @@ import networkx as nx
import numpy as np
from lightrag.utils import (
from minirag.utils import (
logger,
)
from lightrag.base import (
from minirag.base import (
BaseGraphStorage,
)

View File

@@ -60,12 +60,12 @@ from tenacity import (
wait_exponential,
retry_if_exception_type,
)
from lightrag.exceptions import (
from minirag.exceptions import (
APIConnectionError,
RateLimitError,
APITimeoutError,
)
from lightrag.utils import (
from minirag.utils import (
locate_json_string_body_from_string,
)
import torch

View File

@@ -4,6 +4,7 @@ aiohttp
configparser
graspologic
json_repair
httpx
# database packages
networkx