shitty provider implementation to work from

This commit is contained in:
William Guss
2024-09-11 13:17:49 -07:00
parent 7023307834
commit 16c6a814c8
15 changed files with 648 additions and 201 deletions

View File

@@ -0,0 +1,153 @@
""""
Example originally from Instructor docs https://python.useinstructor.com/examples/knowledge_graph/
All rights reserved to the original author.
"""
from graphviz import Digraph
from pydantic import BaseModel, Field
from typing import List, Optional
class Node(BaseModel):
id: int
label: str
color: str
class Edge(BaseModel):
source: int
target: int
label: str
color: str = Field(description="The color of the edge. Defaults to black.")
class KnowledgeGraph(BaseModel):
nodes: Optional[List[Node]] = Field(..., default_factory=list)
edges: Optional[List[Edge]] = Field(..., default_factory=list)
def update(self, other: "KnowledgeGraph") -> "KnowledgeGraph":
"""Updates the current graph with the other graph, deduplicating nodes and edges."""
# Create dictionaries to store unique nodes and edges
unique_nodes = {node.id: node for node in self.nodes}
unique_edges = {(edge.source, edge.target, edge.label): edge for edge in self.edges}
# Update with nodes and edges from the other graph
for node in other.nodes:
unique_nodes[node.id] = node
for edge in other.edges:
unique_edges[(edge.source, edge.target, edge.label)] = edge
return KnowledgeGraph(
nodes=list(unique_nodes.values()),
edges=list(unique_edges.values()),
)
def draw(self, prefix: Optional[str] = None):
dot = Digraph(comment="Knowledge Graph")
for node in self.nodes:
dot.node(str(node.id), node.label, color=node.color)
for edge in self.edges:
dot.edge(
str(edge.source), str(edge.target), label=edge.label, color=edge.color
)
dot.render(prefix, format="png", view=True)
import ell
@ell.complex(model="gpt-4o-2024-08-06", response_format=KnowledgeGraph)
def update_knowledge_graph(cur_state: KnowledgeGraph, inp: str, i: int, num_iterations: int):
return [
ell.system("""You are an iterative knowledge graph builder.
You are given the current state of the graph, and you must append the nodes and edges
to it. Do not provide any duplicates and try to reuse nodes as much as possible."""),
ell.user(f"""Extract any new nodes and edges from the following:
# Part {i}/{num_iterations} of the input:
{inp}"""),
ell.user(f"""Here is the current state of the graph:
{cur_state.model_dump_json(indent=2)}""")
]
def generate_graph(input: List[str]) -> KnowledgeGraph:
cur_state = KnowledgeGraph()
num_iterations = len(input)
for i, inp in enumerate(input):
new_updates = update_knowledge_graph(cur_state, inp, i, num_iterations).parsed[0]
cur_state = cur_state.update(new_updates)
cur_state.draw(prefix=f"iteration_{i}")
return cur_state
if __name__ == "__main__":
ell.init(verbose=True, store='./logdir', autocommit=True)
generate_graph(["This is a test", "This is another test", "This is a third test"])
# Compare to: Original instructor example.
# def generate_graph(input: List[str]) -> KnowledgeGraph:
# cur_state = KnowledgeGraph()
# num_iterations = len(input)
# for i, inp in enumerate(input):
# new_updates = client.chat.completions.create(
# model="gpt-3.5-turbo-16k",
# messages=[
# {
# "role": "system",
# "content": """You are an iterative knowledge graph builder.
# You are given the current state of the graph, and you must append the nodes and edges
# to it. Do not provide any duplicates and try to reuse nodes as much as possible.""",
# },
# {
# "role": "user",
# "content": f"""Extract any new nodes and edges from the following:
# # Part {i}/{num_iterations} of the input:
# {inp}""",
# },
# {
# "role": "user",
# "content": f"""Here is the current state of the graph:
# {cur_state.model_dump_json(indent=2)}""",
# },
# ],
# response_model=KnowledgeGraph,
# ) # type: ignore
# # Update the current state
# cur_state = cur_state.update(new_updates)
# cur_state.draw(prefix=f"iteration_{i}")
# return cur_state
# Bonus: Generate with a chat history.
# XXX: This illustrates the need for a dedicated chat type in ell.
# @ell.complex(model="gpt-4o-2024-08-06", response_format=KnowledgeGraph)
# def update_knowledge_graph_with_chat_history(cur_state: KnowledgeGraph, inp: str, i: int, num_iterations: int, chat_history):
# return [
# ell.system("""You are an iterative knowledge graph builder.
# You are given the current state of the graph, and you must append the nodes and edges
# to it. Do not provide any duplicates and try to reuse nodes as much as possible."""),
# *chat_history,
# ell.user(f"""Extract any new nodes and edges from the following:
# # Part {i}/{num_iterations} of the input:
# {inp}"""),
# ell.user(f"""Here is the current state of the graph:
# {cur_state.model_dump_json(indent=2)}""")
# ]
# def generate_graph_with_chat_history(input: List[str]) -> KnowledgeGraph:
# chat_history = []
# cur_state = KnowledgeGraph()
# num_iterations = len(input)
# for i, inp in enumerate(input):
# new_updates = update_knowledge_graph_with_chat_history(cur_state, inp, i, num_iterations, chat_history)
# cur_state = cur_state.update(new_updates.parsed[0])
# chat_history.append(new_updates)
# cur_state.draw(prefix=f"iteration_{i}")
# return cur_state

View File

@@ -6,8 +6,6 @@ import numpy as np
from ell.stores.sql import SQLiteStore
import openai
@ell.simple(model="gpt-4o-mini")
def come_up_with_a_premise_for_a_joke_about(topic : str):
"""You are an incredibly funny comedian. Come up with a premise for a joke about topic"""

View File

@@ -11,6 +11,7 @@ from ell.types.message import system, user, assistant, Message, ContentBlock
from ell.__version__ import __version__
# Import all models
import ell.providers
import ell.models

View File

@@ -1,11 +1,12 @@
from functools import wraps
from typing import Dict, Any, Optional, Union
from typing import Dict, Any, Optional, Tuple, Union, Type
import openai
import logging
from contextlib import contextmanager
import threading
from pydantic import BaseModel, ConfigDict, Field
from ell.store import Store
from ell.provider import Provider
_config_logger = logging.getLogger(__name__)
class Config(BaseModel):
@@ -20,13 +21,14 @@ class Config(BaseModel):
default_lm_params: Dict[str, Any] = Field(default_factory=dict, description="Default parameters for language models.")
default_system_prompt: str = Field(default="You are a helpful AI assistant.", description="The default system prompt used for AI interactions.")
default_client: Optional[openai.Client] = Field(default=None, description="The default OpenAI client used when a specific model client is not found.")
providers: Dict[Type, Type[Provider]] = Field(default_factory=dict, description="A dictionary mapping client types to provider classes.")
def __init__(self, **data):
super().__init__(**data)
self._lock = threading.Lock()
self._local = threading.local()
def register_model(self, model_name: str, client: openai.Client) -> None:
def register_model(self, model_name: str, client: Any) -> None:
"""
Register an OpenAI client for a specific model name.
@@ -49,7 +51,7 @@ class Config(BaseModel):
return self.store is not None
@contextmanager
def model_registry_override(self, overrides: Dict[str, openai.Client]):
def model_registry_override(self, overrides: Dict[str, Any]):
"""
Temporarily override the model registry with new client mappings.
@@ -70,7 +72,7 @@ class Config(BaseModel):
finally:
self._local.stack.pop()
def get_client_for(self, model_name: str) -> Optional[openai.Client]:
def get_client_for(self, model_name: str) -> Tuple[Optional[Any], bool]:
"""
Get the OpenAI client for a specific model name.
@@ -154,6 +156,26 @@ class Config(BaseModel):
"""
self.default_client = client
def register_provider(self, provider_class: Type[Provider]) -> None:
"""
Register a provider class for a specific client type.
:param provider_class: The provider class to register.
:type provider_class: Type[Provider]
"""
with self._lock:
self.providers[provider_class.get_client_type()] = provider_class
def get_provider_for(self, client: Any) -> Optional[Type[Provider]]:
"""
Get the provider class for a specific client instance.
:param client: The client instance to get the provider for.
:type client: Any
:return: The provider class for the specified client, or None if not found.
:rtype: Optional[Type[Provider]]
"""
return next((provider for client_type, provider in self.providers.items() if isinstance(client, client_type)), None)
# Singleton instance
config = Config()
@@ -218,3 +240,6 @@ def set_default_system_prompt(*args, **kwargs) -> None:
return config.set_default_system_prompt(*args, **kwargs)
# You can add more helper functions here if needed
@wraps(config.register_provider)
def register_provider(*args, **kwargs) -> None:
return config.register_provider(*args, **kwargs)
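As a usage illustration (not part of the diff): once a provider class is registered for a client type, get_provider_for resolves the provider for any client instance via an isinstance check, so subclasses of the registered client type also match. A minimal sketch, assuming the OpenAIProvider added later in this commit has been registered by importing ell (the API key value is a placeholder):

import openai
import ell  # importing ell pulls in ell.providers, which registers OpenAIProvider
from ell.configurator import config

client = openai.Client(api_key="sk-placeholder")
provider_cls = config.get_provider_for(client)  # matched by isinstance(client, openai.Client)
if provider_cls is None:
    raise RuntimeError(f"No provider registered for client type {type(client)}")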

View File

@@ -9,12 +9,10 @@ from ell.util.api import call
from ell.util.verbosity import compute_color, model_usage_logger_pre
import openai
from functools import wraps
from typing import Any, Dict, Optional, List, Callable, Union
def complex(model: str, client: Optional[openai.Client] = None, exempt_from_tracking=False, tools: Optional[List[Callable]] = None, post_callback: Optional[Callable] = None, **api_params):
def complex(model: str, client: Optional[Any] = None, exempt_from_tracking=False, tools: Optional[List[Callable]] = None, post_callback: Optional[Callable] = None, **api_params):
"""
A sophisticated language model programming decorator for complex LLM interactions.
@@ -228,7 +226,7 @@ def complex(model: str, client: Optional[openai.Client] = None, exempt_from_trac
def model_call(
*fn_args,
_invocation_origin : str = None,
client: Optional[openai.Client] = None,
client: Optional[Any] = None,
lm_params: Optional[LMPParams] = {},
invocation_api_params=False,
**fn_kwargs,

View File

@@ -1,11 +1,10 @@
from functools import wraps
from typing import Optional
from typing import Any, Optional
import openai
from ell.lmp.complex import complex
def simple(model: str, client: Optional[openai.Client] = None, exempt_from_tracking=False, **api_params):
def simple(model: str, client: Optional[Any] = None, exempt_from_tracking=False, **api_params):
"""
The fundamental unit of language model programming in ell.

src/ell/provider.py Normal file (56 lines)
View File

@@ -0,0 +1,56 @@
from abc import ABC, abstractmethod
from collections import defaultdict
from typing import Any, Dict, List, Optional, Tuple, Type, Union
from ell.types import Message, ContentBlock, ToolCall
from ell.types._lstr import _lstr
import json
from dataclasses import dataclass
from ell.types.message import LMP
@dataclass
class APICallResult:
response: Any
actual_streaming: bool
actual_n: int
final_call_params: Dict[str, Any]
class Provider(ABC):
"""
Abstract base class for all providers. Providers are API interfaces to language models, not necessarily tied to a single API vendor.
For example, the OpenAI provider is an API interface to OpenAI's API, but it also works with Ollama and Azure OpenAI, which expose the same API.
"""
@classmethod
@abstractmethod
def call_model(
cls,
client: Any,
model: str,
messages: List[Any],
api_params: Dict[str, Any],
tools: Optional[list[LMP]] = None,
) -> APICallResult:
"""Make the API call to the language model and return the result along with actual streaming, n values, and final call parameters."""
pass
@classmethod
@abstractmethod
def process_response(
cls, call_result: APICallResult, _invocation_origin: str, logger: Optional[Any] = None, tools: Optional[List[LMP]] = None,
) -> Tuple[List[Message], Dict[str, Any]]:
"""Process the API response and convert it to ell format."""
pass
@classmethod
@abstractmethod
def supports_streaming(cls) -> bool:
"""Check if the provider supports streaming."""
pass
@classmethod
@abstractmethod
def get_client_type(cls) -> Type:
"""Return the type of client this provider supports."""
pass
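To make the contract above concrete, here is a minimal sketch of what a third-party provider might look like. EchoClient and EchoProvider are hypothetical names invented for this example; only the Provider abstract methods and the APICallResult fields come from the code above.

from typing import Any, Dict, List, Optional, Tuple, Type

from ell.configurator import register_provider
from ell.provider import APICallResult, Provider
from ell.types import ContentBlock, Message


class EchoClient:
    """Hypothetical client that simply echoes the prompt back."""


class EchoProvider(Provider):
    @classmethod
    def call_model(cls, client: Any, model: str, messages: List[Any],
                   api_params: Dict[str, Any], tools=None) -> APICallResult:
        # No network call: the "response" is just the text of the last message.
        last = messages[-1] if messages else None
        last_text = "".join(c.text or "" for c in last.content) if last else ""
        return APICallResult(response=last_text, actual_streaming=False, actual_n=1,
                             final_call_params={"model": model, **api_params})

    @classmethod
    def process_response(cls, call_result: APICallResult, _invocation_origin: str,
                         logger: Optional[Any] = None, tools=None) -> Tuple[List[Message], Dict[str, Any]]:
        message = Message(role="assistant", content=[ContentBlock(text=call_result.response)])
        return [message], {}

    @classmethod
    def supports_streaming(cls) -> bool:
        return False

    @classmethod
    def get_client_type(cls) -> Type:
        return EchoClient


register_provider(EchoProvider)  # config.get_provider_for(EchoClient()) now returns EchoProvider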

View File

@@ -0,0 +1,9 @@
import ell.providers.openai
# import ell.providers.anthropic
# import ell.providers.groq
# import ell.providers.mistral
# import ell.providers.cohere
# import ell.providers.gemini
# import ell.providers.elevenlabs
# import ell.providers.replicate
# import ell.providers.huggingface

src/ell/providers/openai.py Normal file (239 lines)
View File

@@ -0,0 +1,239 @@
from abc import ABC, abstractmethod
from collections import defaultdict
from typing import Any, Dict, List, Optional, Tuple, Type, Union
from ell.provider import APICallResult, Provider
from ell.types import Message, ContentBlock, ToolCall
from ell.types._lstr import _lstr
import json
from ell.configurator import config, register_provider
from ell.types.message import LMP
from ell.util.serialization import serialize_image
try:
import openai
class OpenAIProvider(Provider):
@staticmethod
def content_block_to_openai_format(content_block: ContentBlock) -> Dict[str, Any]:
if content_block.image:
base64_image = serialize_image(content_block.image)
return {
"type": "image_url",
"image_url": {
"url": base64_image
}
}
elif content_block.text:
return {
"type": "text",
"text": content_block.text
}
elif content_block.parsed:
return {
"type": "text",
"text": content_block.parsed.model_dump_json()
}
# Tool calls handled in message_to_openai_format.
#XXX: Feel free to refactor this.
else:
return None
@staticmethod
def message_to_openai_format(message: Message) -> Dict[str, Any]:
openai_message = {
"role": "tool" if message.tool_results else message.role,
"content": list(filter(None, [
OpenAIProvider.content_block_to_openai_format(c) for c in message.content
]))
}
if message.tool_calls:
try:
openai_message["tool_calls"] = [
{
"id": tool_call.tool_call_id,
"type": "function",
"function": {
"name": tool_call.tool.__name__,
"arguments": json.dumps(tool_call.params.model_dump())
}
} for tool_call in message.tool_calls
]
except TypeError as e:
print(f"Error serializing tool calls: {e}. Did you fully type your @ell.tool decorated functions?")
raise
openai_message["content"] = None # Set content to null when there are tool calls
if message.tool_results:
openai_message["tool_call_id"] = message.tool_results[0].tool_call_id
openai_message["content"] = message.tool_results[0].result[0].text
assert len(message.tool_results[0].result) == 1, "Tool result should only have one content block"
assert message.tool_results[0].result[0].type == "text", "Tool result should only have one text content block"
return openai_message
@classmethod
def call_model(
cls,
client: Any,
model: str,
messages: List[Message],
api_params: Dict[str, Any],
tools: Optional[list[LMP]] = None,
) -> APICallResult:
final_call_params = api_params.copy()
openai_messages = [cls.message_to_openai_format(message) for message in messages]
actual_n = api_params.get("n", 1)
final_call_params["model"] = model
final_call_params["messages"] = openai_messages
if final_call_params.get("response_format"):
final_call_params.pop("stream", None)
final_call_params.pop("stream_options", None)
response = client.beta.chat.completions.parse(**final_call_params)
else:
# Tools are not working with the structured outputs API yet.
if tools:
final_call_params["tool_choice"] = "auto"
final_call_params["tools"] = [
{
"type": "function",
"function": {
"name": tool.__name__,
"description": tool.__doc__,
"parameters": tool.__ell_params_model__.model_json_schema(),
},
}
for tool in tools
]
final_call_params.pop("stream", None)
final_call_params.pop("stream_options", None)
else:
final_call_params["stream_options"] = {"include_usage": True}
final_call_params["stream"] = True
response = client.chat.completions.create(**final_call_params)
return APICallResult(
response=response,
actual_streaming=isinstance(response, openai.Stream),
actual_n=actual_n,
final_call_params=final_call_params,
)
@classmethod
def process_response(
cls, call_result: APICallResult, _invocation_origin: str, logger : Optional[Any] = None, tools: Optional[List[LMP]] = None,
) -> Tuple[List[Message], Dict[str, Any]]:
choices_progress = defaultdict(list)
api_params = call_result.final_call_params
metadata = {}
#XXX: Remove logger and refactor this API
if not call_result.actual_streaming:
response = [call_result.response]
else:
response = call_result.response
for chunk in response:
if hasattr(chunk, "usage") and chunk.usage:
metadata = chunk.to_dict()
if call_result.actual_streaming:
continue
for choice in chunk.choices:
choices_progress[choice.index].append(choice)
if choice.index == 0 and logger:
# print(choice, streaming)
logger(choice.delta.content if call_result.actual_streaming else
choice.message.content or getattr(choice.message, "refusal", ""), is_refusal=getattr(choice.message, "refusal", False) if not call_result.actual_streaming else False)
tracked_results = []
for _, choice_deltas in sorted(choices_progress.items(), key=lambda x: x[0]):
content = []
if call_result.actual_streaming:
text_content = "".join(
(choice.delta.content or "" for choice in choice_deltas)
)
if text_content:
content.append(
ContentBlock(
text=_lstr(
content=text_content, _origin_trace=_invocation_origin
)
)
)
else:
choice = choice_deltas[0].message
if choice.refusal:
raise ValueError(choice.refusal)
if api_params.get("response_format"):
content.append(ContentBlock(parsed=choice.parsed))
elif choice.content:
content.append(
ContentBlock(
text=_lstr(
content=choice.content, _origin_trace=_invocation_origin
)
)
)
if not call_result.actual_streaming and hasattr(choice, "tool_calls") and choice.tool_calls:
assert tools is not None, "Tools not provided, yet tool calls in response. Did you manually specify a tool spec without using ell.tool?"
for tool_call in choice.tool_calls:
matching_tool = next(
(
tool
for tool in tools
if tool.__name__ == tool_call.function.name
),
None,
)
if matching_tool:
params = matching_tool.__ell_params_model__(
**json.loads(tool_call.function.arguments)
)
content.append(
ContentBlock(
tool_call=ToolCall(
tool=matching_tool,
tool_call_id=_lstr(
tool_call.id, _origin_trace=_invocation_origin
),
params=params,
)
)
)
tracked_results.append(
Message(
role=(
choice.role
if not call_result.actual_streaming
else choice_deltas[0].delta.role
),
content=content,
)
)
return tracked_results, metadata
@classmethod
def supports_streaming(cls) -> bool:
return True
@classmethod
def get_client_type(cls) -> Type:
return openai.Client
@classmethod
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
cls.register()
register_provider(OpenAIProvider)
except ImportError:
pass
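A brief note on the dispatch in call_model above: when response_format is set, the request goes through client.beta.chat.completions.parse and streaming is disabled; when tools are supplied, client.chat.completions.create is used with tool_choice="auto" and streaming is likewise disabled; otherwise the request streams with stream_options={"include_usage": True}. A rough usage sketch of the plain streaming path (the model name is illustrative and an OPENAI_API_KEY is assumed to be set in the environment):

import openai

from ell.providers.openai import OpenAIProvider
from ell.types import ContentBlock, Message

client = openai.Client()  # reads OPENAI_API_KEY from the environment
messages = [Message(role="user", content=[ContentBlock(text="Say hello.")])]

# No response_format and no tools, so call_model streams the response.
result = OpenAIProvider.call_model(client, "gpt-4o-mini", messages, api_params={})
assert result.actual_streaming

# process_response collapses the streamed chunks back into ell Message objects.
tracked_messages, metadata = OpenAIProvider.process_response(result, "example-origin")
print(tracked_messages[0].content[0].text)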

View File

@@ -48,7 +48,7 @@ class ContentBlock(BaseModel):
image: Optional[Union[PILImage.Image, str, np.ndarray]] = Field(default=None)
audio: Optional[Union[np.ndarray, List[float]]] = Field(default=None)
tool_call: Optional[ToolCall] = Field(default=None)
parsed: Optional[Union[Type[BaseModel], BaseModel]] = Field(default=None)
parsed: Optional[BaseModel] = Field(default=None)
tool_result: Optional[ToolResult] = Field(default=None)
@model_validator(mode='after')
@@ -174,29 +174,6 @@ class ContentBlock(BaseModel):
return serialize_image(image)
def to_openai_content_block(self):
if self.image:
base64_image = self.serialize_image(self.image, None)
return {
"type": "image_url",
"image_url": {
"url": base64_image
}
}
elif self.text:
return {
"type": "text",
"text": self.text
}
elif self.parsed:
return {
"type": "text",
"json": self.parsed.model_dump_json()
}
else:
return None
def coerce_content_list(content: Union[str, List[ContentBlock], List[Union[str, ContentBlock, ToolCall, ToolResult, BaseModel]]] = None, **content_block_kwargs) -> List[ContentBlock]:
if not content:
content = [ContentBlock(**content_block_kwargs)]
@@ -309,36 +286,6 @@ class Message(BaseModel):
content = [c.tool_call.call_and_collect_as_message_block() for c in self.content if c.tool_call]
return Message(role="user", content=content)
def to_openai_message(self) -> Dict[str, Any]:
message = {
"role": "tool" if self.tool_results else self.role,
"content": list(filter(None, [
c.to_openai_content_block() for c in self.content
]))
}
if self.tool_calls:
message["tool_calls"] = [
{
"id": tool_call.tool_call_id,
"type": "function",
"function": {
"name": tool_call.tool.__name__,
"arguments": json.dumps(tool_call.params.model_dump())
}
} for tool_call in self.tool_calls
]
message["content"] = None # Set content to null when there are tool calls
if self.tool_results:
message["tool_call_id"] = self.tool_results[0].tool_call_id
# message["name"] = self.tool_results[0].tool_call_id.split('-')[0] # Assuming the tool name is the first part of the tool_call_id
message["content"] = self.tool_results[0].result[0].text
# Let's assert there is no other type of content block in the tool result
assert len(self.tool_results[0].result) == 1, "Tool result should only have one content block"
assert self.tool_results[0].result[0].type == "text", "Tool result should only have one text content block"
return message
# HELPERS
def system(content: Union[str, List[ContentBlock]]) -> Message:
"""

View File

@@ -139,7 +139,7 @@ class InvocationContentsBase(SQLModel):
]
total_size = sum(
len(json.dumps(field, default=(lambda x: x.model_dump_json() if isinstance(x, BaseModel) else str(x))).encode('utf-8')) for field in json_fields if field is not None
len(json.dumps(field, default=(lambda x: json.dumps(x.model_dump(), default=str) if isinstance(x, BaseModel) else str(x))).encode('utf-8')) for field in json_fields if field is not None
)
# print("total_size", total_size)

View File

@@ -17,16 +17,16 @@ To fix this:
import ell
import openai
ell.lm(model, client=openai.Client(api_key=my_key))
ell.simple(model, client=openai.Client(api_key=my_key))
def {name}(...):
...
```
* Or explicitly specify the client when calling the LMP:
```
ell.lm(model, client=openai.Client(api_key=my_key))(...)
ell.simple(model, client=openai.Client(api_key=my_key))(...)
```
""" if long else " at time of definition. Can be okay if custom client specified later! <TODO: add link to docs> ") + f"{Style.RESET_ALL}"
""" if long else " at time of definition. Can be okay if custom client specified later! https://docs.ell.so/core_concepts/models_and_api_clients.html ") + f"{Style.RESET_ALL}"
def _warnings(model, fn, default_client_from_decorator):
@@ -40,14 +40,14 @@ def _warnings(model, fn, default_client_from_decorator):
* If this is a mistake either specify a client explicitly in the decorator:
```python
import ell
ell.lm(model, client=my_client)
ell.simple(model, client=my_client)
def {fn.__name__}(...):
...
```
or explicitly specify the client when calling the LMP:
```python
ell.lm(model, client=my_client)(...)
ell.simple(model, client=my_client)(...)
```
{Style.RESET_ALL}""")
elif (client_to_use := config.registry[model]) is None or not client_to_use.api_key:

View File

@@ -1,170 +1,63 @@
from functools import partial
import json
# import anthropic
from ell.configurator import config
import openai
from collections import defaultdict
from ell.types._lstr import _lstr
from ell.types import Message, ContentBlock, ToolCall
from typing import Any, Dict, Iterable, Optional, Tuple, Union
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union, Type
from ell.types.message import LMP, LMPParams, MessageOrDict
from ell.util.verbosity import model_usage_logger_post_end, model_usage_logger_post_intermediate, model_usage_logger_post_start
from ell.util._warnings import _no_api_key_warning
from ell.provider import APICallResult, Provider
import logging
logger = logging.getLogger(__name__)
def process_messages_for_client(messages: list[Message], client: Any):
if isinstance(client, openai.Client):
return [
message.to_openai_message()
for message in messages]
# elif isinstance(client, anthropic.Anthropic):
# return messages
# XXX: or some such.
def call(
*,
model: str,
messages: list[Message],
api_params: Dict[str, Any],
tools: Optional[list[LMP]] = None,
client: Optional[openai.Client] = None,
client: Optional[Any] = None,
_invocation_origin: str,
_exempt_from_tracking: bool,
_logging_color=None,
_name: str = None,
) -> Tuple[Union[_lstr, Iterable[_lstr]], Optional[Dict[str, Any]]]:
_logging_color: Optional[str] = None,
_name: Optional[str] = None,
) -> Tuple[Union[Message, List[Message]], Dict[str, Any], Dict[str, Any]]:
"""
Helper function to run the language model with the provided messages and parameters.
"""
# TODO: Decide whether the client specified via the context-manager default registry or the client specified via the LMP invocation args should take precedence.
if not client:
client, was_fallback = config.get_client_for(model)
if not client and not was_fallback:
# The model was registered with a None client and there is no fallback default client.
raise RuntimeError(_no_api_key_warning(model, _name, '', long=True, error=True))
metadata = dict()
if client is None:
raise ValueError(f"No client found for model '{model}'. Ensure the model is registered using 'register_model' in 'config.py' or specify a client directly using the 'client' argument in the decorator or function call.")
if not client.api_key:
raise RuntimeError(_no_api_key_warning(model, _name, client, long=True, error=True))
# TODO: add support for streaming APIs that don't report a final usage in the API.
# print(api_params)
if api_params.get("response_format", False):
model_call = client.beta.chat.completions.parse
api_params.pop("stream", None)
api_params.pop("stream_options", None)
elif tools:
model_call = client.chat.completions.create
api_params["tools"] = [
{
"type": "function",
"function": {
"name": tool.__name__,
"description": tool.__doc__,
"parameters": tool.__ell_params_model__.model_json_schema()
}
} for tool in tools
]
api_params["tool_choice"] = "auto"
api_params.pop("stream", None)
api_params.pop("stream_options", None)
else:
model_call = client.chat.completions.create
api_params["stream"] = True
api_params["stream_options"] = {"include_usage": True}
provider_class: Type[Provider] = config.get_provider_for(client)
client_safe_messages_messages = process_messages_for_client(messages, client)
# print(api_params)
model_result = model_call(
model=model, messages=client_safe_messages_messages, **api_params
)
streaming = api_params.get("stream", False)
if not streaming:
model_result = [model_result]
choices_progress = defaultdict(list)
n = api_params.get("n", 1)
# XXX: Could actually delete this.
call_result = provider_class.call_model(client, model, messages, api_params, tools)
if config.verbose and not _exempt_from_tracking:
model_usage_logger_post_start(_logging_color, n)
model_usage_logger_post_start(_logging_color, call_result.actual_n)
with model_usage_logger_post_intermediate(_logging_color, n) as _logger:
for chunk in model_result:
if hasattr(chunk, "usage") and chunk.usage:
# Todo: is this a good decision.
metadata = chunk.to_dict()
with model_usage_logger_post_intermediate(_logging_color, call_result.actual_n) as _logger:
if streaming:
continue
tracked_results, metadata = provider_class.process_response(call_result, _invocation_origin, _logger if config.verbose and not _exempt_from_tracking else None, tools)
for choice in chunk.choices:
choices_progress[choice.index].append(choice)
if config.verbose and choice.index == 0 and not _exempt_from_tracking:
# print(choice, streaming)
_logger(choice.delta.content if streaming else
choice.message.content or getattr(choice.message, "refusal", ""), is_refusal=getattr(choice.message, "refusal", False) if not streaming else False)
if config.verbose and not _exempt_from_tracking:
model_usage_logger_post_end()
n_choices = len(choices_progress)
# coerce the streaming into a final message type
tracked_results = []
for _, choice_deltas in sorted(choices_progress.items(), key=lambda x: x[0]):
content = []
# Handle text content
if streaming:
text_content = "".join((choice.delta.content or "" for choice in choice_deltas))
if text_content:
content.append(ContentBlock(
text=_lstr(content=text_content, _origin_trace=_invocation_origin)
))
else:
choice = choice_deltas[0].message
if choice.refusal:
raise ValueError(choice.refusal)
# XXX: is this the best practice? try catch a parser?
if api_params.get("response_format", False):
content.append(ContentBlock(
parsed=choice.parsed
))
elif choice.content:
content.append(ContentBlock(
text=_lstr(content=choice.content, _origin_trace=_invocation_origin)
))
# Handle tool calls
if not streaming and hasattr(choice, 'tool_calls'):
for tool_call in choice.tool_calls or []:
matching_tool = None
for tool in tools:
if tool.__name__ == tool_call.function.name:
matching_tool = tool
break
if matching_tool:
params = matching_tool.__ell_params_model__(**json.loads(tool_call.function.arguments))
content.append(ContentBlock(
tool_call=ToolCall(tool=matching_tool, tool_call_id=_lstr(tool_call.id, _origin_trace=_invocation_origin), params=params)
))
tracked_results.append(Message(
role=choice.role if not streaming else choice_deltas[0].delta.role,
content=content
))
api_params = dict(model=model, messages=client_safe_messages_messages, api_params=api_params)
return tracked_results[0] if n_choices == 1 else tracked_results, api_params, metadata
return (tracked_results[0] if len(tracked_results) == 1 else tracked_results), call_result.final_call_params, metadata
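For reference, the refactored call returns a (result, final_call_params, metadata) triple matching the updated annotation above, with final_call_params now coming from the provider's APICallResult. A rough sketch of a direct invocation under the new signature (keyword names come from the signature above; the model name and parameter values are illustrative):

import ell
import openai
from ell.util.api import call

result, final_call_params, metadata = call(
    model="gpt-4o-mini",                      # illustrative model name
    messages=[ell.user("Hello!")],
    api_params={"temperature": 0.0},
    client=openai.Client(),                   # or None to fall back to config.get_client_for(model)
    _invocation_origin="example-invocation",
    _exempt_from_tracking=True,
    _name="example",
)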

View File

@@ -43,7 +43,7 @@ def check_version_and_log():
import ell
try:
response = requests.get("https://docs.ell.so/_static/ell_version.txt", timeout=0.5)
response = requests.get("https://docs.ell.so/_static/ell_version.txt", timeout=0.1)
if response.status_code == 200:
latest_version = response.text.strip()
if latest_version != ell.__version__:

View File

@@ -0,0 +1,129 @@
import pytest
from unittest.mock import Mock, patch
from ell.provider import APICallResult
from ell.providers.openai import OpenAIProvider
from ell.types import Message, ContentBlock, ToolCall
from ell.types.message import LMP, ToolResult
from pydantic import BaseModel
import json
import ell
class DummyParams(BaseModel):
param1: str
param2: int
@pytest.fixture
def mock_openai_client():
return Mock()
import openai
def test_content_block_to_openai_format():
# Test text content
text_block = ContentBlock(text="Hello, world!")
assert OpenAIProvider.content_block_to_openai_format(text_block) == {
"type": "text",
"text": "Hello, world!"
}
# Test parsed content
class DummyParsed(BaseModel):
field: str
parsed_block = ContentBlock(parsed=DummyParsed(field="value"))
res = OpenAIProvider.content_block_to_openai_format(parsed_block)
assert res["type"] == "text"
assert (res["text"]) == '{"field":"value"}'
# Test image content (mocked)
with patch('ell.providers.openai.serialize_image', return_value="base64_image_data"):
# Test random image content
import numpy as np
from PIL import Image
# Generate a random image
random_image = np.random.randint(0, 256, (100, 100, 3), dtype=np.uint8)
pil_image = Image.fromarray(random_image)
with patch('ell.providers.openai.serialize_image', return_value="random_base64_image_data"):
random_image_block = ContentBlock(image=pil_image)
assert OpenAIProvider.content_block_to_openai_format(random_image_block) == {
"type": "image_url",
"image_url": {
"url": "random_base64_image_data"
}
}
def test_message_to_openai_format():
# Test simple message
simple_message = Message(role="user", content=[ContentBlock(text="Hello")])
assert OpenAIProvider.message_to_openai_format(simple_message) == {
"role": "user",
"content": [{"type": "text", "text": "Hello"}]
}
# Test message with tool calls
def dummy_tool(param1: str, param2: int): pass
tool_call = ToolCall(tool=dummy_tool, tool_call_id="123", params=DummyParams(param1="test", param2=42))
tool_message = Message(role="assistant", content=[tool_call])
formatted = OpenAIProvider.message_to_openai_format(tool_message)
assert formatted["role"] == "assistant"
assert formatted["content"] is None
assert len(formatted["tool_calls"]) == 1
assert formatted["tool_calls"][0]["id"] == "123"
assert formatted["tool_calls"][0]["function"]["name"] == "dummy_tool"
assert json.loads(formatted["tool_calls"][0]["function"]["arguments"]) == {"param1": "test", "param2": 42}
# Test message with tool results
tool_result_message = Message(
role="user",
content=[ToolResult(tool_call_id="123", result=[ContentBlock(text="Tool output")])],
)
formatted = OpenAIProvider.message_to_openai_format(tool_result_message)
assert formatted["role"] == "tool"
assert formatted["tool_call_id"] == "123"
assert formatted["content"] == "Tool output"
def test_call_model(mock_openai_client):
messages = [Message(role="user", content=[ContentBlock(text="Hello")], refusal=None)]
api_params = {"temperature": 0.7}
# Mock the client's chat.completions.create method
mock_openai_client.chat.completions.create.return_value = Mock(choices=[Mock(message=Mock(content="Response", refusal=None))])
@ell.tool()
def dummy_tool(param1: str, param2: int): pass
result = OpenAIProvider.call_model(mock_openai_client, "gpt-3.5-turbo", messages, api_params, tools=[dummy_tool])
assert isinstance(result, APICallResult)
assert not "stream" in result.final_call_params
assert not result.actual_streaming
assert result.actual_n == 1
assert "messages" in result.final_call_params
assert result.final_call_params["model"] == "gpt-3.5-turbo"
def test_process_response():
# Mock APICallResult
mock_response = Mock(
choices=[Mock(message=Mock(role="assistant", content="Hello, world!", refusal=None, tool_calls=None))]
)
call_result = APICallResult(
response=mock_response,
actual_streaming=False,
actual_n=1,
final_call_params={}
)
processed_messages, metadata = OpenAIProvider.process_response(call_result, "test_origin")
assert len(processed_messages) == 1
assert processed_messages[0].role == "assistant"
assert len(processed_messages[0].content) == 1
assert processed_messages[0].content[0].text == "Hello, world!"
def test_supports_streaming():
assert OpenAIProvider.supports_streaming() == True
# Add more tests as needed for other methods and edge cases