anthropic
@@ -8,5 +8,5 @@ def hello_from_claude():
 
 if __name__ == "__main__":
     ell.init(verbose=True, store="./logdir", autocommit=True)
-    hello_from_claude()
+    print(hello_from_claude())
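For orientation, the whole example file plausibly reads as follows after this change; only the `__main__` block appears in the diff, so the decorator arguments, model name, and prompt text are assumptions:

    import ell

    @ell.simple(model="claude-3-5-sonnet-20240620", max_tokens=100)  # model name assumed
    def hello_from_claude():
        return "Say hello to the world!"  # prompt text assumed

    if __name__ == "__main__":
        ell.init(verbose=True, store="./logdir", autocommit=True)
        print(hello_from_claude())

Since an `@ell.simple` LMP returns the completion string rather than printing it, wrapping the call in `print()` makes the example actually show Claude's reply.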
@@ -61,7 +61,7 @@ class Provider(ABC):
     ################################
     @abstractmethod
     def provider_call_function(
-        self, api_call_params: Optional[Dict[str, Any]] = None
+        self, client: Any, api_call_params: Optional[Dict[str, Any]] = None
     ) -> Callable[..., Any]:
         """
         Implement this method to return the function that makes the API call to the language model.
@@ -75,8 +75,8 @@ class Provider(ABC):
         """
         return frozenset({"messages", "tools", "model", "stream", "stream_options"})
 
-    def available_api_params(self, api_params: Optional[Dict[str, Any]] = None):
-        params = _call_params(self.provider_call_function(api_params))
+    def available_api_params(self, client: Any, api_params: Optional[Dict[str, Any]] = None):
+        params = _call_params(self.provider_call_function(client, api_params))
         return frozenset(params.keys()) - self.disallowed_api_params()
 
     ################################
@@ -116,7 +116,7 @@ class Provider(ABC):
 
         final_api_call_params = self.translate_to_provider(ell_call)
 
-        call = self.provider_call_function(final_api_call_params)
+        call = self.provider_call_function(ell_call.client, final_api_call_params)
         assert self.dangerous_disable_validation or _validate_provider_call_params(final_api_call_params, call)
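The signature change threads the SDK client through explicitly instead of relying on module-level state, so one provider instance can serve many configured clients. A minimal sketch of a conforming third-party provider under the new signature (the class name and the client's method surface are hypothetical; the other abstract methods such as `translate_to_provider` and `translate_from_provider` are elided):

    from typing import Any, Callable, Dict, Optional
    from ell.provider import Provider

    class MyProvider(Provider):
        def provider_call_function(
            self, client: Any, api_call_params: Optional[Dict[str, Any]] = None
        ) -> Callable[..., Any]:
            # Bind the call to the client ell resolved for this invocation
            # (ell_call.client), not to a global module.
            return client.chat.completions.create  # hypothetical client surface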
@@ -1,5 +1,5 @@
 import ell.providers.openai
-# import ell.providers.anthropic
+import ell.providers.anthropic
 # import ell.providers.groq
 # import ell.providers.mistral
 # import ell.providers.cohere
@@ -1,5 +1,5 @@
-from typing import Any, Dict, List, Optional, Tuple, Type
-from ell.provider import APICallResult, Provider
+from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union, cast
+from ell.provider import EllCallParams, Metadata, Provider
 from ell.types import Message, ContentBlock, ToolCall
 from ell.types._lstr import _lstr
 from ell.types.message import LMP
@@ -12,21 +12,19 @@ import json
 try:
     import anthropic
     from anthropic import Anthropic
+    from anthropic.types import Message as AnthropicMessage, MessageCreateParams, RawMessageStreamEvent
+    from anthropic._streaming import Stream
 
     class AnthropicProvider(Provider):
-        @classmethod
-        def call(
-            cls,
-            client: Anthropic,
-            model: str,
-            messages: List[Message],
-            api_params: Dict[str, Any],
-            tools: Optional[list[LMP]] = None,
-        ) -> APICallResult:
-            final_call_params = api_params.copy()
+        def provider_call_function(self, client : Anthropic, api_call_params : Optional[Dict[str, Any]] = None) -> Callable[..., Any]:
+            return client.messages.create
+
+        def translate_to_provider(self, ell_call : EllCallParams) -> Dict[str, Any]:
+            final_call_params = ell_call.api_params.copy()
             assert final_call_params.get("max_tokens") is not None, f"max_tokens is required for anthropic calls, pass it to the @ell.simple/complex decorator, e.g. @ell.simple(..., max_tokens=your_max_tokens) or pass it to the model directly as a parameter when calling your LMP: your_lmp(..., api_params=({{'max_tokens': your_max_tokens}}))."
 
-            anthropic_messages = [message_to_anthropic_format(message) for message in messages]
+            anthropic_messages = [message_to_anthropic_format(message) for message in ell_call.messages]
             system_message = None
             if anthropic_messages and anthropic_messages[0]["role"] == "system":
                 system_message = anthropic_messages.pop(0)
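`EllCallParams` replaces the loose `client`/`model`/`messages`/`api_params`/`tools` arguments of the old `call` classmethod. Judging only from the attribute accesses in this commit, it is shaped roughly like this sketch (field types and defaults are assumptions):

    from dataclasses import dataclass, field
    from typing import Any, Dict, List, Optional

    @dataclass
    class EllCallParamsSketch:
        client: Any                          # provider SDK client, e.g. Anthropic
        model: str                           # model name from the decorator
        messages: List["Message"]            # ell-format chat messages
        api_params: Dict[str, Any] = field(default_factory=dict)
        tools: Optional[List["LMP"]] = None  # decorated tool functions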
@@ -34,49 +32,43 @@ try:
             if system_message:
                 final_call_params["system"] = system_message["content"][0]["text"]
 
-            actual_n = api_params.get("n", 1)
-            final_call_params["model"] = model
+            # XXX: until streaming is implemented.
+            final_call_params['stream'] = True
+
+            final_call_params["model"] = ell_call.model
             final_call_params["messages"] = anthropic_messages
 
-            if tools:
+            if ell_call.tools:
                 final_call_params["tools"] = [
                     {
                         "name": tool.__name__,
                         "description": tool.__doc__,
                         "input_schema": tool.__ell_params_model__.model_json_schema(),
                     }
-                    for tool in tools
+                    for tool in ell_call.tools
                 ]
 
-            # Streaming unsupported.
-            # XXX: Support soon.
-            stream = True
-            if stream:
-                response = client.messages.stream(**final_call_params)
-            else:
-                response = client.messages.create(**final_call_params)
-
-            return APICallResult(
-                response=response,
-                actual_streaming=stream,
-                actual_n=actual_n,
-                final_call_params=final_call_params,
-            )
-
-        @classmethod
-        def process_response(
-            cls, call_result: APICallResult, _invocation_origin: str, logger: Optional[Any] = None, tools: Optional[List[LMP]] = None,
-        ) -> Tuple[List[Message], Dict[str, Any]]:
+            return final_call_params
+
+        def translate_from_provider(
+            self,
+            provider_response : Union[Stream[RawMessageStreamEvent], AnthropicMessage],
+            ell_call: EllCallParams,
+            provider_call_params: Dict[str, Any],
+            origin_id: Optional[str] = None,
+            logger: Optional[Callable[..., None]] = None,
+        ) -> Tuple[List[Message], Metadata]:
 
             usage = {}
             tracked_results = []
             metadata = {}
 
-            if call_result.actual_streaming:
+            if provider_call_params.get("stream", False):
                 content = []
                 current_block: Optional[Dict[str, Any]] = None
                 message_metadata = {}
 
-                with call_result.response as stream:
+                with cast(Stream[RawMessageStreamEvent], provider_response) as stream:
                     for chunk in stream:
                         if chunk.type == "message_start":
                             message_metadata = chunk.message.dict()
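The tools translation above emits Anthropic's Messages API tool schema (`name`, `description`, `input_schema` as JSON Schema). A hand-written equivalent of what `final_call_params["tools"]` ends up holding; the tool itself is illustrative:

    final_call_params: dict = {}
    final_call_params["tools"] = [
        {
            "name": "get_weather",                   # illustrative tool name
            "description": "Return the weather for a city.",
            "input_schema": {                        # JSON Schema, as Anthropic expects
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        }
    ]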
@@ -90,35 +82,26 @@ try:
                         if current_block is not None:
                             if current_block["type"] == "text":
                                 current_block["content"] += chunk.delta.text
                                 logger(chunk.delta.text)
 
                     elif chunk.type == "content_block_stop":
                         if current_block is not None:
                             if current_block["type"] == "text":
-                                content.append(ContentBlock(text=_lstr(current_block["content"],origin_trace=_invocation_origin)))
+                                content.append(ContentBlock(text=_lstr(current_block["content"],origin_trace=origin_id)))
                             elif current_block["type"] == "tool_use":
                                 try:
                                     final_cb = chunk.content_block
-                                    matching_tool = next(
-                                        (
-                                            tool
-                                            for tool in tools
-                                            if tool.__name__ == final_cb.name
-                                        ),
-                                        None,
-                                    )
+                                    matching_tool = ell_call.get_tool_by_name(final_cb.name)
                                     if matching_tool:
-                                        params = matching_tool.__ell_params_model__(
-                                            **final_cb.input
-                                        )
                                         content.append(
                                             ContentBlock(
                                                 tool_call=ToolCall(
                                                     tool=matching_tool,
                                                     tool_call_id=_lstr(
-                                                        final_cb.id,origin_trace=_invocation_origin
+                                                        final_cb.id,origin_trace=origin_id
                                                     ),
-                                                    params=params,
+                                                    params=final_cb.input,
                                                 )
                                             )
                                         )
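`ell_call.get_tool_by_name(...)` replaces the inline `next(...)` scan over `tools`. A plausible definition on `EllCallParams`, mirroring the removed code (this method body is an assumption, not shown in the diff):

    from typing import Optional

    def get_tool_by_name(self, name: str) -> Optional["LMP"]:
        """Return the bound tool whose function name matches, else None."""
        return next((t for t in (self.tools or []) if t.__name__ == name), None)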
@@ -139,35 +122,8 @@ try:
                     elif chunk.type == "message_stop":
                         tracked_results.append(Message(role="assistant", content=content))
 
-                        if logger and current_block:
-                            if chunk.type == "text" and current_block["type"] == "text":
-                                logger(chunk.text)
-                        # print(chunk)
-
                 metadata = message_metadata
-            else:
-                # Non-streaming response processing (unchanged)
-                cbs = []
-                for content_block in call_result.response.content:
-                    if content_block.type == "text":
-                        cbs.append(ContentBlock(text=_lstr(content_block.text,origin_trace=_invocation_origin)))
-                    elif content_block.type == "tool_use":
-                        assert tools is not None, "Tools were not provided to the model when calling it and yet anthropic returned a tool use."
-                        tool_call = ToolCall(
-                            tool=next((t for t in tools if t.__name__ == content_block.name), None),
-                            tool_call_id=content_block.id,
-                            params=content_block.input
-                        )
-                        cbs.append(ContentBlock(tool_call=tool_call))
-                tracked_results.append(Message(role="assistant", content=cbs))
-                if logger:
-                    logger(tracked_results[0].text)
-
-                usage = call_result.response.usage.dict() if call_result.response.usage else {}
-                metadata = call_result.response.model_dump()
-                del metadata["content"]
 
             # process metadata for ell
             # XXX: Unify an ell metadata format for ell studio.
@@ -178,28 +134,19 @@ try:
             metadata["usage"] = usage
             return tracked_results, metadata
 
-        @classmethod
-        def supports_streaming(cls) -> bool:
-            return True
-
-        @classmethod
-        def get_client_type(cls) -> Type:
-            return Anthropic
-
-        @staticmethod
-        def serialize_image_for_anthropic(img):
-            buffer = BytesIO()
-            img.save(buffer, format="PNG")
-            return base64.b64encode(buffer.getvalue()).decode()
-
-    register_provider(AnthropicProvider)
+    register_provider(AnthropicProvider(), Anthropic)
 except ImportError:
     pass
 
+def serialize_image_for_anthropic(img):
+    buffer = BytesIO()
+    img.save(buffer, format="PNG")
+    return base64.b64encode(buffer.getvalue()).decode()
+
 def content_block_to_anthropic_format(content_block: ContentBlock) -> Dict[str, Any]:
     if content_block.image:
-        base64_image = AnthropicProvider.serialize_image_for_anthropic(content_block.image)
+        base64_image = serialize_image_for_anthropic(content_block.image)
         return {
             "type": "image",
             "source": {
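The `source` dict is cut off by the hunk boundary. Since `serialize_image_for_anthropic` produces base64-encoded PNG bytes and Anthropic's Messages API takes images as base64 source blocks, the complete image block presumably looks like this sketch:

    def content_block_image_sketch(base64_image: str) -> dict:
        # Anthropic's Messages API expects images as base64 source blocks.
        return {
            "type": "image",
            "source": {
                "type": "base64",
                "media_type": "image/png",  # matches serialize_image_for_anthropic's PNG output
                "data": base64_image,
            },
        }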
@@ -18,11 +18,11 @@ try:
     class OpenAIProvider(Provider):
         dangerous_disable_validation = True
 
-        def provider_call_function(self, api_call_params : Optional[Dict[str, Any]] = None) -> Callable[..., Any]:
+        def provider_call_function(self, client : openai.Client, api_call_params : Optional[Dict[str, Any]] = None) -> Callable[..., Any]:
             if api_call_params and api_call_params.get("response_format"):
-                return openai.beta.chat.completions.parse
+                return client.beta.chat.completions.parse
             else:
-                return openai.chat.completions.create
+                return client.chat.completions.create
 
         def translate_to_provider(self, ell_call : EllCallParams) -> Dict[str, Any]:
             final_call_params = ell_call.api_params.copy()
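With call functions now bound to a `client` argument rather than the `openai` module, each LMP can use its own configured client. A usage sketch; the `client` kwarg on the decorator and the model name are assumptions here:

    import openai
    import ell

    # A dedicated client, e.g. a different API key or base_url per LMP.
    client = openai.Client(api_key="sk-...")  # placeholder credentials

    @ell.simple(model="gpt-4o", client=client, max_tokens=50)  # client kwarg assumed
    def haiku(topic: str):
        return f"Write a haiku about {topic}."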