Merge pull request #20 from yashshah035/main
correcting imports to minirag and missing files
main.py (4 changes)
@@ -6,7 +6,7 @@ import os
 from minirag import MiniRAG, QueryParam
 from minirag.llm import (
     hf_model_complete,
-    hf_embedding,
+    hf_embed,
 )
 from minirag.utils import EmbeddingFunc
 from transformers import AutoModel, AutoTokenizer
@@ -63,7 +63,7 @@ rag = MiniRAG(
     embedding_func=EmbeddingFunc(
         embedding_dim=384,
         max_token_size=1000,
-        func=lambda texts: hf_embedding(
+        func=lambda texts: hf_embed(
            texts,
            tokenizer=AutoTokenizer.from_pretrained(EMBEDDING_MODEL),
            embed_model=AutoModel.from_pretrained(EMBEDDING_MODEL),
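For orientation, here is a minimal sketch of how main.py wires the renamed hf_embed helper into EmbeddingFunc after this patch. The working directory, the LLM settings, and the EMBEDDING_MODEL value below are placeholders rather than values from this diff, and the insert/query calls assume the usual LightRAG-style MiniRAG API.

from minirag import MiniRAG, QueryParam
from minirag.llm import hf_model_complete, hf_embed
from minirag.utils import EmbeddingFunc
from transformers import AutoModel, AutoTokenizer

EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"  # placeholder 384-dim model

rag = MiniRAG(
    working_dir="./minirag_workdir",       # placeholder path, not from this diff
    llm_model_func=hf_model_complete,      # assumed LLM wiring, not shown in this hunk
    llm_model_name="PLACEHOLDER/hf-llm",   # placeholder model id
    embedding_func=EmbeddingFunc(
        embedding_dim=384,
        max_token_size=1000,
        func=lambda texts: hf_embed(       # renamed from hf_embedding by this PR
            texts,
            tokenizer=AutoTokenizer.from_pretrained(EMBEDDING_MODEL),
            embed_model=AutoModel.from_pretrained(EMBEDDING_MODEL),
        ),
    ),
)

rag.insert("Some text to index.")                                     # assumed LightRAG-style API
print(rag.query("What was indexed?", param=QueryParam(mode="mini")))  # assumed query mode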
@@ -1,8 +1,8 @@
 from dataclasses import dataclass, field
-from typing import TypedDict, Union, Literal, Generic, TypeVar
+from enum import Enum
+from typing import Any, TypedDict, Optional, Union, Literal, Generic, TypeVar
 import os
 import numpy as np

 from .utils import EmbeddingFunc

 TextChunkSchema = TypedDict(
@@ -138,3 +138,52 @@ class BaseGraphStorage(StorageNameSpace):

     async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
         raise NotImplementedError("Node embedding is not used in minirag.")
+
+
+class DocStatus(str, Enum):
+    """Document processing status enum"""
+
+    PENDING = "pending"
+    PROCESSING = "processing"
+    PROCESSED = "processed"
+    FAILED = "failed"
+
+
+@dataclass
+class DocProcessingStatus:
+    """Document processing status data structure"""
+
+    content: str
+    """Original content of the document"""
+    content_summary: str
+    """First 100 chars of document content, used for preview"""
+    content_length: int
+    """Total length of document"""
+    status: DocStatus
+    """Current processing status"""
+    created_at: str
+    """ISO format timestamp when document was created"""
+    updated_at: str
+    """ISO format timestamp when document was last updated"""
+    chunks_count: Optional[int] = None
+    """Number of chunks after splitting, used for processing"""
+    error: Optional[str] = None
+    """Error message if failed"""
+    metadata: dict[str, Any] = field(default_factory=dict)
+    """Additional metadata"""
+
+
+class DocStatusStorage(BaseKVStorage):
+    """Base class for document status storage"""
+
+    async def get_status_counts(self) -> dict[str, int]:
+        """Get counts of documents in each status"""
+        raise NotImplementedError
+
+    async def get_failed_docs(self) -> dict[str, DocProcessingStatus]:
+        """Get all failed documents"""
+        raise NotImplementedError
+
+    async def get_pending_docs(self) -> dict[str, DocProcessingStatus]:
+        """Get all pending documents"""
+        raise NotImplementedError
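As a quick illustration of the new data structures (not part of the PR itself), a caller might create and inspect a status record as below; new_doc_status is a hypothetical helper introduced only for this example.

from datetime import datetime, timezone

def new_doc_status(content: str) -> DocProcessingStatus:
    """Hypothetical helper: build the initial status record for a freshly added document."""
    now = datetime.now(timezone.utc).isoformat()
    return DocProcessingStatus(
        content=content,
        content_summary=content[:100],   # per the field docstring: first 100 chars for preview
        content_length=len(content),
        status=DocStatus.PENDING,
        created_at=now,
        updated_at=now,
    )

doc = new_doc_status("Example document text ...")
assert doc.status == DocStatus.PENDING
assert doc.chunks_count is None and doc.metadata == {}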
minirag/exceptions.py (new file, 58 lines)
@@ -0,0 +1,58 @@
+import httpx
+from typing import Literal
+
+
+class APIStatusError(Exception):
+    """Raised when an API response has a status code of 4xx or 5xx."""
+
+    response: httpx.Response
+    status_code: int
+    request_id: str | None
+
+    def __init__(
+        self, message: str, *, response: httpx.Response, body: object | None
+    ) -> None:
+        super().__init__(message, response.request, body=body)
+        self.response = response
+        self.status_code = response.status_code
+        self.request_id = response.headers.get("x-request-id")
+
+
+class APIConnectionError(Exception):
+    def __init__(
+        self, *, message: str = "Connection error.", request: httpx.Request
+    ) -> None:
+        super().__init__(message, request, body=None)
+
+
+class BadRequestError(APIStatusError):
+    status_code: Literal[400] = 400  # pyright: ignore[reportIncompatibleVariableOverride]
+
+
+class AuthenticationError(APIStatusError):
+    status_code: Literal[401] = 401  # pyright: ignore[reportIncompatibleVariableOverride]
+
+
+class PermissionDeniedError(APIStatusError):
+    status_code: Literal[403] = 403  # pyright: ignore[reportIncompatibleVariableOverride]
+
+
+class NotFoundError(APIStatusError):
+    status_code: Literal[404] = 404  # pyright: ignore[reportIncompatibleVariableOverride]
+
+
+class ConflictError(APIStatusError):
+    status_code: Literal[409] = 409  # pyright: ignore[reportIncompatibleVariableOverride]
+
+
+class UnprocessableEntityError(APIStatusError):
+    status_code: Literal[422] = 422  # pyright: ignore[reportIncompatibleVariableOverride]
+
+
+class RateLimitError(APIStatusError):
+    status_code: Literal[429] = 429  # pyright: ignore[reportIncompatibleVariableOverride]
+
+
+class APITimeoutError(APIConnectionError):
+    def __init__(self, request: httpx.Request) -> None:
+        super().__init__(message="Request timed out.", request=request)
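These exception classes pair with the retry logic patched in the llm module further down, where tenacity retries only the transient failures (rate limits, connection errors, timeouts) and lets other APIStatusError subclasses propagate. A hypothetical sketch of that pattern follows; call_llm, its body, and the exact stop/wait settings are illustrative, not code from this PR.

from tenacity import (
    retry,
    stop_after_attempt,
    wait_exponential,
    retry_if_exception_type,
)
from minirag.exceptions import APIConnectionError, RateLimitError, APITimeoutError

@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    retry=retry_if_exception_type((RateLimitError, APIConnectionError, APITimeoutError)),
)
async def call_llm(prompt: str) -> str:
    # Placeholder body: a real wrapper would perform the HTTP/model call here and
    # raise RateLimitError / APIConnectionError / APITimeoutError on transient failure.
    return f"(model output for: {prompt!r})"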
@@ -52,13 +52,13 @@ import asyncio
 import os
 from dataclasses import dataclass

-from lightrag.utils import (
+from minirag.utils import (
     logger,
     load_json,
     write_json,
 )

-from lightrag.base import (
+from minirag.base import (
     BaseKVStorage,
 )
@@ -52,13 +52,13 @@ import os
 from dataclasses import dataclass
 from typing import Union, Dict

-from lightrag.utils import (
+from minirag.utils import (
     logger,
     load_json,
     write_json,
 )

-from lightrag.base import (
+from minirag.base import (
     DocStatus,
     DocProcessingStatus,
     DocStatusStorage,
@@ -61,12 +61,12 @@ if not pm.is_installed("nano-vectordb"):
 from nano_vectordb import NanoVectorDB
 import time

-from lightrag.utils import (
+from minirag.utils import (
     logger,
     compute_mdhash_id,
 )

-from lightrag.base import (
+from minirag.base import (
     BaseVectorStorage,
 )
@@ -56,11 +56,11 @@ import networkx as nx
 import numpy as np


-from lightrag.utils import (
+from minirag.utils import (
     logger,
 )

-from lightrag.base import (
+from minirag.base import (
     BaseGraphStorage,
 )
@@ -60,12 +60,12 @@ from tenacity import (
     wait_exponential,
     retry_if_exception_type,
 )
-from lightrag.exceptions import (
+from minirag.exceptions import (
     APIConnectionError,
     RateLimitError,
     APITimeoutError,
 )
-from lightrag.utils import (
+from minirag.utils import (
     locate_json_string_body_from_string,
 )
 import torch
@@ -4,6 +4,7 @@ aiohttp
 configparser
 graspologic
 json_repair
+httpx

 # database packages
 networkx