Mirror of https://github.com/MadcowD/ell.git (synced 2024-09-22 16:14:36 +03:00)
somewhat working claude
examples/claude.py (new file, 11 lines)
@@ -0,0 +1,11 @@
import ell

@ell.simple(model="claude-3-opus-20240229", max_tokens=100)
def hello_from_claude():
    """You are an AI assistant. Your task is to respond to the user's message with a friendly greeting."""
    return "Say hello to the world!"


if __name__ == "__main__":
    ell.init(verbose=True, store="./logdir", autocommit=True)
    hello_from_claude()
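
If `ANTHROPIC_API_KEY` is not set in the environment, the same example should also work with an explicit client, mirroring the `client=` pattern shown in the warning text later in this commit. A minimal sketch (the key string is a placeholder):

```python
import anthropic
import ell

# Placeholder key; any anthropic.Anthropic instance should work here.
client = anthropic.Anthropic(api_key="sk-ant-...")

@ell.simple(model="claude-3-opus-20240229", max_tokens=100, client=client)
def hello_from_claude():
    """You are an AI assistant. Your task is to respond to the user's message with a friendly greeting."""
    return "Say hello to the world!"
```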

src/ell/models/__init__.py
@@ -7,4 +7,5 @@ For example, to register an OpenAI model:
 """

 import ell.models.openai
+import ell.models.anthropic
 import ell.models.ollama

src/ell/models/anthropic.py (new file, 44 lines)
@@ -0,0 +1,44 @@
from ell.configurator import config
import logging

logger = logging.getLogger(__name__)


try:
    import anthropic

    def register(client: anthropic.Anthropic):
        """
        Register Anthropic models with the provided client.

        This function takes an Anthropic client and registers various Anthropic models
        with the global configuration. It allows the system to use these models
        for different AI tasks.

        Args:
            client (anthropic.Anthropic): An instance of the Anthropic client to be used
                for model registration.

        Note:
            The function doesn't return anything but updates the global
            configuration with the registered models.
        """
        model_data = [
            ('claude-3-opus-20240229', 'anthropic'),
            ('claude-3-sonnet-20240229', 'anthropic'),
            ('claude-3-haiku-20240307', 'anthropic'),
            ('claude-3-5-sonnet-20240620', 'anthropic'),
        ]
        for model_id, owned_by in model_data:
            config.register_model(model_id, client)

    try:
        default_client = anthropic.Anthropic()
        register(default_client)
    except Exception as e:
        # logger.warning(f"Failed to create default Anthropic client: {e}")
        pass


except ImportError:
    pass
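
Callers that don't want the auto-created default client can point `register` at their own instance before defining LMPs. A usage sketch (hedged; the key is a placeholder, and `register` only exists when the `anthropic` package imported successfully):

```python
import anthropic
import ell.models.anthropic

# Explicitly register the Claude model ids against a custom client.
custom_client = anthropic.Anthropic(api_key="sk-ant-...")
ell.models.anthropic.register(custom_client)
```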

src/ell/providers/__init__.py
@@ -1,5 +1,5 @@
 import ell.providers.openai
-# import ell.providers.anthropic
+import ell.providers.anthropic
 # import ell.providers.groq
 # import ell.providers.mistral
 # import ell.providers.cohere
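
Read together with `src/ell/models/anthropic.py` above, the wiring appears to be two registries: `config.register_model(model_id, client)` maps a model name to its default client, while `register_provider(...)` plus `get_client_type()` map a client's type to a provider class. A rough sketch of that second lookup (a hypothetical helper, not ell's actual code):

```python
from typing import Any, Sequence

def resolve_provider(client: Any, providers: Sequence[type]) -> type:
    # Pick the first registered provider whose declared client type
    # matches this client instance (compare get_client_type below).
    return next(p for p in providers if isinstance(client, p.get_client_type()))
```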

src/ell/providers/anthropic.py (new file, 198 lines)
@@ -0,0 +1,198 @@
from typing import Any, Dict, List, Optional, Tuple, Type
from ell.provider import APICallResult, Provider
from ell.types import Message, ContentBlock, ToolCall
from ell.types._lstr import _lstr
from ell.types.message import LMP
from ell.configurator import register_provider
from ell.util.serialization import serialize_image
import base64
from io import BytesIO
import json

try:
    import anthropic
    from anthropic import Anthropic

    class AnthropicProvider(Provider):
        @classmethod
        def call_model(
            cls,
            client: Anthropic,
            model: str,
            messages: List[Message],
            api_params: Dict[str, Any],
            tools: Optional[List[LMP]] = None,
        ) -> APICallResult:
            final_call_params = api_params.copy()
            assert final_call_params.get("max_tokens") is not None, f"max_tokens is required for anthropic calls, pass it to the @ell.simple/complex decorator, e.g. @ell.simple(..., max_tokens=your_max_tokens) or pass it to the model directly as a parameter when calling your LMP: your_lmp(..., lm_params=({{'max_tokens': your_max_tokens}}))."

            anthropic_messages = [message_to_anthropic_format(message) for message in messages]
            system_message = None
            if anthropic_messages and anthropic_messages[0]["role"] == "system":
                system_message = anthropic_messages.pop(0)

            if system_message:
                final_call_params["system"] = system_message["content"][0]["text"]

            actual_n = api_params.get("n", 1)
            final_call_params["model"] = model
            final_call_params["messages"] = anthropic_messages

            if tools:
                final_call_params["tools"] = [
                    {
                        "name": tool.__name__,
                        "description": tool.__doc__,
                        "input_schema": tool.__ell_params_model__.model_json_schema(),
                    }
                    for tool in tools
                ]

            # Pop `stream` so it isn't forwarded as-is; when streaming is
            # requested, ask the SDK for a raw event stream (stream=True) so the
            # chunk types handled in process_response (message_start,
            # content_block_delta, ...) match what gets iterated.
            stream = final_call_params.pop("stream", False)
            if stream:
                response = client.messages.create(**final_call_params, stream=True)
            else:
                response = client.messages.create(**final_call_params)

            return APICallResult(
                response=response,
                actual_streaming=stream,
                actual_n=actual_n,
                final_call_params=final_call_params,
            )

        @classmethod
        def process_response(
            cls, call_result: APICallResult, _invocation_origin: str, logger: Optional[Any] = None, tools: Optional[List[LMP]] = None,
        ) -> Tuple[List[Message], Dict[str, Any]]:
            metadata = {}
            tracked_results = []

            if call_result.actual_streaming:
                content = []
                current_block = None
                message_metadata = {}

                for chunk in call_result.response:
                    if chunk.type == "message_start":
                        message_metadata = chunk.message.dict()
                        message_metadata.pop("content", None)  # Remove content as we'll build it separately
                        message_metadata.pop("role", None)  # Role is passed explicitly below; drop it to avoid a duplicate kwarg
elif chunk.type == "content_block_start":
|
||||
current_block = {"type": chunk.content_block.type, "content": ""}
|
||||
|
||||
elif chunk.type == "content_block_delta":
|
||||
if current_block["type"] == "text":
|
||||
current_block["content"] += chunk.delta.text
|
||||
elif current_block["type"] == "tool_use":
|
||||
current_block.setdefault("input", "")
|
||||
current_block["input"] += chunk.delta.partial_json
|
||||
|
||||
elif chunk.type == "content_block_stop":
|
||||
if current_block["type"] == "text":
|
||||
content.append(ContentBlock(text=_lstr(current_block["content"], _origin_trace=_invocation_origin)))
|
||||
elif current_block["type"] == "tool_use":
|
||||
try:
|
||||
tool_input = json.loads(current_block["input"])
|
||||
tool_call = ToolCall(
|
||||
tool=next((t for t in tools if t.__name__ == current_block["name"]), None),
|
||||
tool_call_id=current_block.get("id"),
|
||||
params=tool_input
|
||||
)
|
||||
content.append(ContentBlock(tool_call=tool_call))
|
||||
except json.JSONDecodeError:
|
||||
# Handle partial JSON if necessary
|
||||
pass
|
||||
current_block = None
|
||||
|
||||
elif chunk.type == "message_delta":
|
||||
message_metadata.update(chunk.delta.dict())
|
||||
if chunk.usage:
|
||||
metadata.update(chunk.usage.dict())
|
||||
|
||||
elif chunk.type == "message_stop":
|
||||
tracked_results.append(Message(role="assistant", content=content, **message_metadata))
|
||||
|
||||

                    # Only delta chunks carry chunk.delta.text; guard on the chunk
                    # type so start/stop events don't trip an attribute error.
                    if logger and chunk.type == "content_block_delta" and current_block and current_block["type"] == "text":
                        logger(chunk.delta.text)

            else:
                # Non-streaming response processing (unchanged)
                cbs = []
                for content_block in call_result.response.content:
                    if content_block.type == "text":
                        cbs.append(ContentBlock(text=_lstr(content_block.text, _origin_trace=_invocation_origin)))
                    elif content_block.type == "tool_use":
                        assert tools is not None, "Tools were not provided to the model when calling it and yet anthropic returned a tool use."
                        tool_call = ToolCall(
                            tool=next((t for t in tools if t.__name__ == content_block.name), None),
                            tool_call_id=content_block.id,
                            params=content_block.input
                        )
                        cbs.append(ContentBlock(tool_call=tool_call))
                tracked_results.append(Message(role="assistant", content=cbs))
                if logger:
                    logger(tracked_results[0].text)

                metadata = call_result.response.usage.dict() if call_result.response.usage else {}

            return tracked_results, metadata

        @classmethod
        def supports_streaming(cls) -> bool:
            return True

        @classmethod
        def get_client_type(cls) -> Type:
            return Anthropic

        @staticmethod
        def serialize_image_for_anthropic(img):
            buffer = BytesIO()
            img.save(buffer, format="PNG")
            return base64.b64encode(buffer.getvalue()).decode()

    register_provider(AnthropicProvider)
except ImportError:
    pass

"""
HELPERS
"""
def content_block_to_anthropic_format(content_block: ContentBlock) -> Optional[Dict[str, Any]]:
    if content_block.image:
        base64_image = AnthropicProvider.serialize_image_for_anthropic(content_block.image)
        return {
            "type": "image",
            "source": {
                "type": "base64",
                "media_type": "image/png",
                "data": base64_image
            }
        }
    elif content_block.text:
        return {
            "type": "text",
            "text": content_block.text
        }
    elif content_block.parsed:
        return {
            "type": "text",
            "text": json.dumps(content_block.parsed.model_dump())
        }
    else:
        return None


def message_to_anthropic_format(message: Message) -> Dict[str, Any]:
    anthropic_message = {
        "role": message.role,
        "content": list(filter(None, [
            content_block_to_anthropic_format(c) for c in message.content
        ]))
    }
    return anthropic_message
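
To see what the helpers at the bottom of the file produce, a small worked example (hedged; it assumes `Message` and `ContentBlock` accept these constructor arguments, which matches how they are used above):

```python
from ell.types import Message, ContentBlock

msg = Message(role="user", content=[ContentBlock(text="What is the capital of France?")])
print(message_to_anthropic_format(msg))
# Expected shape: {'role': 'user', 'content': [{'type': 'text', 'text': 'What is the capital of France?'}]}
```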

src/ell/providers/openai.py
@@ -12,6 +12,8 @@ from ell.util.serialization import serialize_image
 try:
     import openai
     class OpenAIProvider(Provider):
+
+        # XXX: This content block conversion etc might need to happen on a per model basis for providers like groq etc. We will think about this at a future date.
         @staticmethod
         def content_block_to_openai_format(content_block: ContentBlock) -> Dict[str, Any]:
             if content_block.image:
@@ -229,10 +231,6 @@ try:
         def get_client_type(cls) -> Type:
             return openai.Client

-        @classmethod
-        def __init_subclass__(cls, **kwargs):
-            super().__init_subclass__(**kwargs)
-            cls.register()

     register_provider(OpenAIProvider)
 except ImportError:
@@ -8,23 +8,25 @@ logger = logging.getLogger(__name__)
 def _no_api_key_warning(model, name, client_to_use, long=False, error=False):
     color = Fore.RED if error else Fore.LIGHTYELLOW_EX
     prefix = "ERROR" if error else "WARNING"
-    return f"""{color}{prefix}: No API key found for model `{model}` used by LMP `{name}` using client `{client_to_use}`""" + (""".
+    client_to_use_name = client_to_use.__class__.__name__
+    client_to_use_module = client_to_use.__class__.__module__
+    return f"""{color}{prefix}: No API key found for model `{model}` used by LMP `{name}` using client `{client_to_use}`""" + (f""".

 To fix this:
-* Or, set your API key in the environment variable `OPENAI_API_KEY`, `ANTHROPIC_API_KEY`, etc.
+* Set your API key in the appropriate environment variable for your chosen provider
 * Or, specify a client explicitly in the decorator:
 ```
 import ell
-import openai
+from {client_to_use_module} import {client_to_use_name}

-ell.simple(model, client=openai.Client(api_key=my_key))
+@ell.simple(model="{model}", client={client_to_use_name}(api_key=your_api_key))
+def {name}(...):
+    ...
 ```
-* Or explicitly specify the client when the calling the LMP:
+* Or explicitly specify the client when calling the LMP:

 ```
-ell.simple(model, client=openai.Client(api_key=my_key))(...)
+{name}(..., client={client_to_use_name}(api_key=your_api_key))
 ```
 """ if long else " at time of definition. Can be okay if custom client specified later! https://docs.ell.so/core_concepts/models_and_api_clients.html ") + f"{Style.RESET_ALL}"
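
A quick way to eyeball the reworked long-form message (hedged; the key is a placeholder, since only the client's class name and module feed the f-string):

```python
from anthropic import Anthropic

client = Anthropic(api_key="placeholder")
print(_no_api_key_warning("claude-3-opus-20240229", "hello_from_claude", client, long=True))
```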
@@ -52,7 +52,6 @@ def call(
     model_usage_logger_post_start(_logging_color, call_result.actual_n)

     with model_usage_logger_post_intermediate(_logging_color, call_result.actual_n) as _logger:
-
         tracked_results, metadata = provider_class.process_response(call_result, _invocation_origin, _logger if config.verbose and not _exempt_from_tracking else None, tools)