somewhat working calude

This commit is contained in:
William Guss
2024-09-11 14:08:34 -07:00
parent 16c6a814c8
commit 51202c24d5
8 changed files with 265 additions and 12 deletions

11
examples/claude.py Normal file
View File

@@ -0,0 +1,11 @@
import ell
@ell.simple(model="claude-3-opus-20240229", max_tokens=100)
def hello_from_claude():
"""You are an AI assistant. Your task is to respond to the user's message with a friendly greeting."""
return "Say hello to the world!"
if __name__ == "__main__":
ell.init(verbose=True, store="./logdir", autocommit=True)
hello_from_claude()

View File

@@ -7,4 +7,5 @@ For example, to register an OpenAI model:
"""
import ell.models.openai
import ell.models.anthropic
import ell.models.ollama

View File

@@ -0,0 +1,44 @@
from ell.configurator import config
import logging
logger = logging.getLogger(__name__)
try:
import anthropic
def register(client: anthropic.Anthropic):
"""
Register Anthropic models with the provided client.
This function takes an Anthropic client and registers various Anthropic models
with the global configuration. It allows the system to use these models
for different AI tasks.
Args:
client (anthropic.Anthropic): An instance of the Anthropic client to be used
for model registration.
Note:
The function doesn't return anything but updates the global
configuration with the registered models.
"""
model_data = [
('claude-3-opus-20240229', 'anthropic'),
('claude-3-sonnet-20240229', 'anthropic'),
('claude-3-haiku-20240307', 'anthropic'),
('claude-3-5-sonnet-20240620', 'anthropic'),
]
for model_id, owned_by in model_data:
config.register_model(model_id, client)
try:
default_client = anthropic.Anthropic()
register(default_client)
except Exception as e:
# logger.warning(f"Failed to create default Anthropic client: {e}")
pass
except ImportError:
pass

View File

@@ -1,5 +1,5 @@
import ell.providers.openai
# import ell.providers.anthropic
import ell.providers.anthropic
# import ell.providers.groq
# import ell.providers.mistral
# import ell.providers.cohere

View File

@@ -0,0 +1,198 @@
from typing import Any, Dict, List, Optional, Tuple, Type
from ell.provider import APICallResult, Provider
from ell.types import Message, ContentBlock, ToolCall
from ell.types._lstr import _lstr
from ell.types.message import LMP
from ell.configurator import register_provider
from ell.util.serialization import serialize_image
import base64
from io import BytesIO
import json
try:
import anthropic
from anthropic import Anthropic
class AnthropicProvider(Provider):
@classmethod
def call_model(
cls,
client: Anthropic,
model: str,
messages: List[Message],
api_params: Dict[str, Any],
tools: Optional[list[LMP]] = None,
) -> APICallResult:
final_call_params = api_params.copy()
assert final_call_params.get("max_tokens") is not None, f"max_tokens is required for anthropic calls, pass it to the @ell.simple/complex decorator, e.g. @ell.simple(..., max_tokens=your_max_tokens) or pass it to the model directly as a parameter when calling your LMP: your_lmp(..., lm_params=({{'max_tokens': your_max_tokens}}))."
anthropic_messages = [message_to_anthropic_format(message) for message in messages]
system_message = None
if anthropic_messages and anthropic_messages[0]["role"] == "system":
system_message = anthropic_messages.pop(0)
if system_message:
final_call_params["system"] = system_message["content"][0]["text"]
actual_n = api_params.get("n", 1)
final_call_params["model"] = model
final_call_params["messages"] = anthropic_messages
if tools:
final_call_params["tools"] = [
{
"name": tool.__name__,
"description": tool.__doc__,
"input_schema": tool.__ell_params_model__.model_json_schema(),
}
for tool in tools
]
# Streaming unsupported.
stream = final_call_params.pop("stream", False)
if stream:
response = client.messages.stream(**final_call_params)
else:
response = client.messages.create(**final_call_params)
return APICallResult(
response=response,
actual_streaming=stream,
actual_n=actual_n,
final_call_params=final_call_params,
)
@classmethod
def process_response(
cls, call_result: APICallResult, _invocation_origin: str, logger: Optional[Any] = None, tools: Optional[List[LMP]] = None,
) -> Tuple[List[Message], Dict[str, Any]]:
metadata = {}
tracked_results = []
if call_result.actual_streaming:
content = []
current_block = None
message_metadata = {}
for chunk in call_result.response:
if chunk.type == "message_start":
message_metadata = chunk.message.dict()
message_metadata.pop("content", None) # Remove content as we'll build it separately
elif chunk.type == "content_block_start":
current_block = {"type": chunk.content_block.type, "content": ""}
elif chunk.type == "content_block_delta":
if current_block["type"] == "text":
current_block["content"] += chunk.delta.text
elif current_block["type"] == "tool_use":
current_block.setdefault("input", "")
current_block["input"] += chunk.delta.partial_json
elif chunk.type == "content_block_stop":
if current_block["type"] == "text":
content.append(ContentBlock(text=_lstr(current_block["content"], _origin_trace=_invocation_origin)))
elif current_block["type"] == "tool_use":
try:
tool_input = json.loads(current_block["input"])
tool_call = ToolCall(
tool=next((t for t in tools if t.__name__ == current_block["name"]), None),
tool_call_id=current_block.get("id"),
params=tool_input
)
content.append(ContentBlock(tool_call=tool_call))
except json.JSONDecodeError:
# Handle partial JSON if necessary
pass
current_block = None
elif chunk.type == "message_delta":
message_metadata.update(chunk.delta.dict())
if chunk.usage:
metadata.update(chunk.usage.dict())
elif chunk.type == "message_stop":
tracked_results.append(Message(role="assistant", content=content, **message_metadata))
if logger and current_block and current_block["type"] == "text":
logger(chunk.delta.text)
else:
# Non-streaming response processing (unchanged)
cbs = []
for content_block in call_result.response.content:
if content_block.type == "text":
cbs.append(ContentBlock(text=_lstr(content_block.text, _origin_trace=_invocation_origin)))
elif content_block.type == "tool_use":
assert tools is not None, "Tools were not provided to the model when calling it and yet anthropic returned a tool use."
tool_call = ToolCall(
tool=next((t for t in tools if t.__name__ == content_block.name), None),
tool_call_id=content_block.id,
params=content_block.input
)
cbs.append(ContentBlock(tool_call=tool_call))
tracked_results.append(Message(role="assistant", content=cbs))
if logger:
logger(tracked_results[0].text)
metadata = call_result.response.usage.dict() if call_result.response.usage else {}
return tracked_results, metadata
@classmethod
def supports_streaming(cls) -> bool:
return True
@classmethod
def get_client_type(cls) -> Type:
return Anthropic
@staticmethod
def serialize_image_for_anthropic(img):
buffer = BytesIO()
img.save(buffer, format="PNG")
return base64.b64encode(buffer.getvalue()).decode()
register_provider(AnthropicProvider)
except ImportError:
pass
"""
HELPERS
"""
def content_block_to_anthropic_format(content_block: ContentBlock) -> Dict[str, Any]:
if content_block.image:
base64_image = AnthropicProvider.serialize_image_for_anthropic(content_block.image)
return {
"type": "image",
"source": {
"type": "base64",
"media_type": "image/png",
"data": base64_image
}
}
elif content_block.text:
return {
"type": "text",
"text": content_block.text
}
elif content_block.parsed:
return {
"type": "text",
"text": json.dumps(content_block.parsed.model_dump())
}
else:
return None
def message_to_anthropic_format(message: Message) -> Dict[str, Any]:
anthropic_message = {
"role": message.role,
"content": list(filter(None, [
content_block_to_anthropic_format(c) for c in message.content
]))
}
return anthropic_message

View File

@@ -12,6 +12,8 @@ from ell.util.serialization import serialize_image
try:
import openai
class OpenAIProvider(Provider):
# XXX: This content block conversion etc might need to happen on a per model basis for providers like groq etc. We will think about this at a future date.
@staticmethod
def content_block_to_openai_format(content_block: ContentBlock) -> Dict[str, Any]:
if content_block.image:
@@ -229,10 +231,6 @@ try:
def get_client_type(cls) -> Type:
return openai.Client
@classmethod
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
cls.register()
register_provider(OpenAIProvider)
except ImportError:

View File

@@ -8,23 +8,25 @@ logger = logging.getLogger(__name__)
def _no_api_key_warning(model, name, client_to_use, long=False, error=False):
color = Fore.RED if error else Fore.LIGHTYELLOW_EX
prefix = "ERROR" if error else "WARNING"
return f"""{color}{prefix}: No API key found for model `{model}` used by LMP `{name}` using client `{client_to_use}`""" + (""".
client_to_use_name = client_to_use.__class__.__name__
client_to_use_module = client_to_use.__class__.__module__
return f"""{color}{prefix}: No API key found for model `{model}` used by LMP `{name}` using client `{client_to_use}`""" + (f""".
To fix this:
* Or, set your API key in the environment variable `OPENAI_API_KEY`, `ANTHROPIC_API_KEY`, etc.
* Set your API key in the appropriate environment variable for your chosen provider
* Or, specify a client explicitly in the decorator:
```
import ell
import openai
from {client_to_use_module} import {client_to_use_name}
ell.simple(model, client=openai.Client(api_key=my_key))
@ell.simple(model="{model}", client={client_to_use_name}(api_key=your_api_key))
def {name}(...):
...
```
* Or explicitly specify the client when the calling the LMP:
* Or explicitly specify the client when calling the LMP:
```
ell.simple(model, client=openai.Client(api_key=my_key))(...)
{name}(..., client={client_to_use_name}(api_key=your_api_key))
```
""" if long else " at time of definition. Can be okay if custom client specified later! https://docs.ell.so/core_concepts/models_and_api_clients.html ") + f"{Style.RESET_ALL}"

View File

@@ -52,7 +52,6 @@ def call(
model_usage_logger_post_start(_logging_color, call_result.actual_n)
with model_usage_logger_post_intermediate(_logging_color, call_result.actual_n) as _logger:
tracked_results, metadata = provider_class.process_response(call_result, _invocation_origin, _logger if config.verbose and not _exempt_from_tracking else None, tools)