anthropic

This commit is contained in:
William Guss
2024-09-20 18:25:34 -07:00
parent c9a293d910
commit 63ce6e4581
5 changed files with 50 additions and 103 deletions

View File

@@ -8,5 +8,5 @@ def hello_from_claude():
if __name__ == "__main__":
ell.init(verbose=True, store="./logdir", autocommit=True)
hello_from_claude()
print(hello_from_claude())

View File

@@ -61,7 +61,7 @@ class Provider(ABC):
################################
@abstractmethod
def provider_call_function(
self, api_call_params: Optional[Dict[str, Any]] = None
self, client: Any, api_call_params: Optional[Dict[str, Any]] = None
) -> Callable[..., Any]:
"""
Implement this method to return the function that makes the API call to the language model.
@@ -75,8 +75,8 @@ class Provider(ABC):
"""
return frozenset({"messages", "tools", "model", "stream", "stream_options"})
def available_api_params(self, api_params: Optional[Dict[str, Any]] = None):
params = _call_params(self.provider_call_function(api_params))
def available_api_params(self, client: Any, api_params: Optional[Dict[str, Any]] = None):
params = _call_params(self.provider_call_function(client, api_params))
return frozenset(params.keys()) - self.disallowed_api_params()
################################
@@ -116,7 +116,7 @@ class Provider(ABC):
final_api_call_params = self.translate_to_provider(ell_call)
call = self.provider_call_function(final_api_call_params)
call = self.provider_call_function(ell_call.client, final_api_call_params)
assert self.dangerous_disable_validation or _validate_provider_call_params(final_api_call_params, call)

View File

@@ -1,5 +1,5 @@
import ell.providers.openai
# import ell.providers.anthropic
import ell.providers.anthropic
# import ell.providers.groq
# import ell.providers.mistral
# import ell.providers.cohere

View File

@@ -1,5 +1,5 @@
from typing import Any, Dict, List, Optional, Tuple, Type
from ell.provider import APICallResult, Provider
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union, cast
from ell.provider import EllCallParams, Metadata, Provider
from ell.types import Message, ContentBlock, ToolCall
from ell.types._lstr import _lstr
from ell.types.message import LMP
@@ -12,21 +12,19 @@ import json
try:
import anthropic
from anthropic import Anthropic
from anthropic.types import Message as AnthropicMessage, MessageCreateParams, RawMessageStreamEvent
from anthropic._streaming import Stream
class AnthropicProvider(Provider):
@classmethod
def call(
cls,
client: Anthropic,
model: str,
messages: List[Message],
api_params: Dict[str, Any],
tools: Optional[list[LMP]] = None,
) -> APICallResult:
final_call_params = api_params.copy()
def provider_call_function(self, client : Anthropic, api_call_params : Optional[Dict[str, Any]] = None) -> Callable[..., Any]:
return client.messages.create
def translate_to_provider(self, ell_call : EllCallParams) -> Dict[str, Any]:
final_call_params = ell_call.api_params.copy()
assert final_call_params.get("max_tokens") is not None, f"max_tokens is required for anthropic calls, pass it to the @ell.simple/complex decorator, e.g. @ell.simple(..., max_tokens=your_max_tokens) or pass it to the model directly as a parameter when calling your LMP: your_lmp(..., api_params=({{'max_tokens': your_max_tokens}}))."
anthropic_messages = [message_to_anthropic_format(message) for message in messages]
anthropic_messages = [message_to_anthropic_format(message) for message in ell_call.messages]
system_message = None
if anthropic_messages and anthropic_messages[0]["role"] == "system":
system_message = anthropic_messages.pop(0)
@@ -34,49 +32,43 @@ try:
if system_message:
final_call_params["system"] = system_message["content"][0]["text"]
actual_n = api_params.get("n", 1)
final_call_params["model"] = model
# XXX: untils streaming is implemented.
final_call_params['stream'] = True
final_call_params["model"] = ell_call.model
final_call_params["messages"] = anthropic_messages
if tools:
if ell_call.tools:
final_call_params["tools"] = [
{
"name": tool.__name__,
"description": tool.__doc__,
"input_schema": tool.__ell_params_model__.model_json_schema(),
}
for tool in tools
for tool in ell_call.tools
]
# Streaming unsupported.
# XXX: Support soon.
stream = True
if stream:
response = client.messages.stream(**final_call_params)
else:
response = client.messages.create(**final_call_params)
return APICallResult(
response=response,
actual_streaming=stream,
actual_n=actual_n,
final_call_params=final_call_params,
)
@classmethod
def process_response(
cls, call_result: APICallResult, _invocation_origin: str, logger: Optional[Any] = None, tools: Optional[List[LMP]] = None,
) -> Tuple[List[Message], Dict[str, Any]]:
return final_call_params
def translate_from_provider(
self,
provider_response : Union[Stream[RawMessageStreamEvent], AnthropicMessage],
ell_call: EllCallParams,
provider_call_params: Dict[str, Any],
origin_id: Optional[str] = None,
logger: Optional[Callable[..., None]] = None,
) -> Tuple[List[Message], Metadata]:
usage = {}
tracked_results = []
metadata = {}
if call_result.actual_streaming:
if provider_call_params.get("stream", False):
content = []
current_block: Optional[Dict[str, Any]] = None
message_metadata = {}
with call_result.response as stream:
with cast(Stream[RawMessageStreamEvent], provider_response) as stream:
for chunk in stream:
if chunk.type == "message_start":
message_metadata = chunk.message.dict()
@@ -90,35 +82,26 @@ try:
if current_block is not None:
if current_block["type"] == "text":
current_block["content"] += chunk.delta.text
logger(chunk.delta.text)
elif chunk.type == "content_block_stop":
if current_block is not None:
if current_block["type"] == "text":
content.append(ContentBlock(text=_lstr(current_block["content"],origin_trace=_invocation_origin)))
content.append(ContentBlock(text=_lstr(current_block["content"],origin_trace=origin_id)))
elif current_block["type"] == "tool_use":
try:
final_cb = chunk.content_block
matching_tool = next(
(
tool
for tool in tools
if tool.__name__ == final_cb.name
),
None,
)
matching_tool = ell_call.get_tool_by_name(final_cb.name)
if matching_tool:
params = matching_tool.__ell_params_model__(
**final_cb.input
)
content.append(
ContentBlock(
tool_call=ToolCall(
tool=matching_tool,
tool_call_id=_lstr(
final_cb.id,origin_trace=_invocation_origin
final_cb.id,origin_trace=origin_id
),
params=params,
params=final_cb.input,
)
)
)
@@ -139,35 +122,8 @@ try:
elif chunk.type == "message_stop":
tracked_results.append(Message(role="assistant", content=content))
if logger and current_block:
if chunk.type == "text" and current_block["type"] == "text":
logger(chunk.text)
# print(chunk)
metadata = message_metadata
else:
# Non-streaming response processing (unchanged)
cbs = []
for content_block in call_result.response.content:
if content_block.type == "text":
cbs.append(ContentBlock(text=_lstr(content_block.text,origin_trace=_invocation_origin)))
elif content_block.type == "tool_use":
assert tools is not None, "Tools were not provided to the model when calling it and yet anthropic returned a tool use."
tool_call = ToolCall(
tool=next((t for t in tools if t.__name__ == content_block.name), None) ,
tool_call_id=content_block.id,
params=content_block.input
)
cbs.append(ContentBlock(tool_call=tool_call))
tracked_results.append(Message(role="assistant", content=cbs))
if logger:
logger(tracked_results[0].text)
usage = call_result.response.usage.dict() if call_result.response.usage else {}
metadata = call_result.response.model_dump()
del metadata["content"]
# process metadata for ell
# XXX: Unify an ell metadata format for ell studio.
@@ -178,28 +134,19 @@ try:
metadata["usage"] = usage
return tracked_results, metadata
@classmethod
def supports_streaming(cls) -> bool:
return True
@classmethod
def get_client_type(cls) -> Type:
return Anthropic
@staticmethod
def serialize_image_for_anthropic(img):
buffer = BytesIO()
img.save(buffer, format="PNG")
return base64.b64encode(buffer.getvalue()).decode()
register_provider(AnthropicProvider)
register_provider(AnthropicProvider(), Anthropic)
except ImportError:
pass
def serialize_image_for_anthropic(img):
buffer = BytesIO()
img.save(buffer, format="PNG")
return base64.b64encode(buffer.getvalue()).decode()
def content_block_to_anthropic_format(content_block: ContentBlock) -> Dict[str, Any]:
if content_block.image:
base64_image = AnthropicProvider.serialize_image_for_anthropic(content_block.image)
base64_image = serialize_image_for_anthropic(content_block.image)
return {
"type": "image",
"source": {

View File

@@ -18,11 +18,11 @@ try:
class OpenAIProvider(Provider):
dangerous_disable_validation = True
def provider_call_function(self, api_call_params : Optional[Dict[str, Any]] = None) -> Callable[..., Any]:
def provider_call_function(self, client : openai.Client, api_call_params : Optional[Dict[str, Any]] = None) -> Callable[..., Any]:
if api_call_params and api_call_params.get("response_format"):
return openai.beta.chat.completions.parse
return client.beta.chat.completions.parse
else:
return openai.chat.completions.create
return client.chat.completions.create
def translate_to_provider(self, ell_call : EllCallParams) -> Dict[str, Any]:
final_call_params = ell_call.api_params.copy()