Implement retry for LLMClient (#44)

* implement retry

* chore: Refactor tenacity retry logic and improve LLMClient error handling

* poetry

* remove unnecessary try
This commit is contained in:
Daniel Chalef
2024-08-26 12:53:16 -07:00
committed by GitHub
parent 895afc7be1
commit fc4bf3bde2
3 changed files with 40 additions and 2 deletions

View File

@@ -20,7 +20,9 @@ import logging
import typing
from abc import ABC, abstractmethod
import httpx
from diskcache import Cache
from tenacity import retry, retry_if_exception, stop_after_attempt, wait_exponential
from ..prompts.models import Message
from .config import LLMConfig
@@ -31,6 +33,12 @@ DEFAULT_CACHE_DIR = './llm_cache'
logger = logging.getLogger(__name__)
def is_server_error(exception):
return (
isinstance(exception, httpx.HTTPStatusError) and 500 <= exception.response.status_code < 600
)
class LLMClient(ABC):
def __init__(self, config: LLMConfig | None, cache: bool = False):
if config is None:
@@ -47,6 +55,20 @@ class LLMClient(ABC):
def get_embedder(self) -> typing.Any:
pass
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception(is_server_error),
)
async def _generate_response_with_retry(self, messages: list[Message]) -> dict[str, typing.Any]:
try:
return await self._generate_response(messages)
except httpx.HTTPStatusError as e:
if not is_server_error(e):
raise Exception(f'LLM request error: {e}') from e
else:
raise
@abstractmethod
async def _generate_response(self, messages: list[Message]) -> dict[str, typing.Any]:
pass
@@ -66,7 +88,7 @@ class LLMClient(ABC):
logger.debug(f'Cache hit for {cache_key}')
return cached_response
response = await self._generate_response(messages)
response = await self._generate_response_with_retry(messages)
if self.cache_enabled:
self.cache_dir.set(cache_key, response)

17
poetry.lock generated
View File

@@ -3253,6 +3253,21 @@ mpmath = ">=1.1.0,<1.4"
[package.extras]
dev = ["hypothesis (>=6.70.0)", "pytest (>=7.1.0)"]
[[package]]
name = "tenacity"
version = "9.0.0"
description = "Retry code until it succeeds"
optional = false
python-versions = ">=3.8"
files = [
{file = "tenacity-9.0.0-py3-none-any.whl", hash = "sha256:93de0c98785b27fcf659856aa9f54bfbd399e29969b0621bc7f762bd441b4539"},
{file = "tenacity-9.0.0.tar.gz", hash = "sha256:807f37ca97d62aa361264d497b0e31e92b8027044942bfa756160d908320d73b"},
]
[package.extras]
doc = ["reno", "sphinx"]
test = ["pytest", "tornado (>=4.5)", "typeguard"]
[[package]]
name = "terminado"
version = "0.18.1"
@@ -3743,4 +3758,4 @@ test = ["websockets"]
[metadata]
lock-version = "2.0"
python-versions = "^3.10"
content-hash = "5b90bb6d58d36a2553f5410c418b179aa1c86b55078567c33aaa6fddf6a8c6c6"
content-hash = "001663dfc8078ad473675c994b15191db1f53a844e23f40ffa4a704379a61132"

View File

@@ -23,6 +23,7 @@ diskcache = "^5.6.3"
arrow = "^1.3.0"
openai = "^1.38.0"
anthropic = "^0.34.1"
tenacity = "^9.0.0"
[tool.poetry.dev-dependencies]
pytest = "^8.3.2"