Mirror of https://github.com/MadcowD/ell.git (synced 2024-09-22 16:14:36 +03:00)
shitty provider implementation to work from
153
examples/future/knowledge_graph.py
Normal file
@@ -0,0 +1,153 @@
"""
Example originally from Instructor docs https://python.useinstructor.com/examples/knowledge_graph/
All rights reserved to the original author.
"""

from graphviz import Digraph
from pydantic import BaseModel, Field
from typing import List, Optional


class Node(BaseModel):
    id: int
    label: str
    color: str


class Edge(BaseModel):
    source: int
    target: int
    label: str
    color: str = Field(description="The color of the edge. Defaults to black.")


class KnowledgeGraph(BaseModel):
    nodes: Optional[List[Node]] = Field(default_factory=list)
    edges: Optional[List[Edge]] = Field(default_factory=list)

    def update(self, other: "KnowledgeGraph") -> "KnowledgeGraph":
        """Updates the current graph with the other graph, deduplicating nodes and edges."""
        # Create dictionaries to store unique nodes and edges
        unique_nodes = {node.id: node for node in self.nodes}
        unique_edges = {(edge.source, edge.target, edge.label): edge for edge in self.edges}

        # Update with nodes and edges from the other graph
        for node in other.nodes:
            unique_nodes[node.id] = node
        for edge in other.edges:
            unique_edges[(edge.source, edge.target, edge.label)] = edge

        return KnowledgeGraph(
            nodes=list(unique_nodes.values()),
            edges=list(unique_edges.values()),
        )

    def draw(self, prefix: Optional[str] = None):
        dot = Digraph(comment="Knowledge Graph")

        for node in self.nodes:
            dot.node(str(node.id), node.label, color=node.color)

        for edge in self.edges:
            dot.edge(
                str(edge.source), str(edge.target), label=edge.label, color=edge.color
            )
        dot.render(prefix, format="png", view=True)


import ell

@ell.complex(model="gpt-4o-2024-08-06", response_format=KnowledgeGraph)
def update_knowledge_graph(cur_state: KnowledgeGraph, inp: str, i: int, num_iterations: int):
    return [
        ell.system("""You are an iterative knowledge graph builder.
        You are given the current state of the graph, and you must append the nodes and edges
        to it. Do not provide any duplicates and try to reuse nodes as much as possible."""),
        ell.user(f"""Extract any new nodes and edges from the following:
        # Part {i}/{num_iterations} of the input:

        {inp}"""),
        ell.user(f"""Here is the current state of the graph:
        {cur_state.model_dump_json(indent=2)}""")
    ]

def generate_graph(input: List[str]) -> KnowledgeGraph:
    cur_state = KnowledgeGraph()
    num_iterations = len(input)
    for i, inp in enumerate(input):
        new_updates = update_knowledge_graph(cur_state, inp, i, num_iterations).parsed[0]
        cur_state = cur_state.update(new_updates)
        cur_state.draw(prefix=f"iteration_{i}")
    return cur_state


if __name__ == "__main__":
    ell.init(verbose=True, store='./logdir', autocommit=True)
    generate_graph(["This is a test", "This is another test", "This is a third test"])

# Compare to: Original instructor example.
# def generate_graph(input: List[str]) -> KnowledgeGraph:
#     cur_state = KnowledgeGraph()
#     num_iterations = len(input)
#     for i, inp in enumerate(input):
#         new_updates = client.chat.completions.create(
#             model="gpt-3.5-turbo-16k",
#             messages=[
#                 {
#                     "role": "system",
#                     "content": """You are an iterative knowledge graph builder.
#                     You are given the current state of the graph, and you must append the nodes and edges
#                     to it. Do not provide any duplicates and try to reuse nodes as much as possible.""",
#                 },
#                 {
#                     "role": "user",
#                     "content": f"""Extract any new nodes and edges from the following:
#                     # Part {i}/{num_iterations} of the input:
#
#                     {inp}""",
#                 },
#                 {
#                     "role": "user",
#                     "content": f"""Here is the current state of the graph:
#                     {cur_state.model_dump_json(indent=2)}""",
#                 },
#             ],
#             response_model=KnowledgeGraph,
#         )  # type: ignore
#
#         # Update the current state
#         cur_state = cur_state.update(new_updates)
#         cur_state.draw(prefix=f"iteration_{i}")
#     return cur_state


# Bonus: Generate with a chat history.
# XXX: This illustrates the need for a dedicated chat type in ell.
# @ell.complex(model="gpt-4o-2024-08-06", response_format=KnowledgeGraph)
# def update_knowledge_graph_with_chat_history(cur_state: KnowledgeGraph, inp: str, i: int, num_iterations: int, chat_history):
#     return [
#         ell.system("""You are an iterative knowledge graph builder.
#         You are given the current state of the graph, and you must append the nodes and edges
#         to it. Do not provide any duplicates and try to reuse nodes as much as possible."""),
#         *chat_history,
#         ell.user(f"""Extract any new nodes and edges from the following:
#         # Part {i}/{num_iterations} of the input:
#
#         {inp}"""),
#         ell.user(f"""Here is the current state of the graph:
#         {cur_state.model_dump_json(indent=2)}""")
#     ]

# def generate_graph_with_chat_history(input: List[str]) -> KnowledgeGraph:
#     chat_history = []
#     cur_state = KnowledgeGraph()
#     num_iterations = len(input)
#     for i, inp in enumerate(input):
#         new_updates = update_knowledge_graph_with_chat_history(cur_state, inp, i, num_iterations, chat_history)
#         cur_state = cur_state.update(new_updates.parsed[0])
#         chat_history.append(new_updates)
#         cur_state.draw(prefix=f"iteration_{i}")
#     return cur_state
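A quick standalone illustration of the deduplication semantics of KnowledgeGraph.update above (the labels are made up; no LLM call is involved):

a = KnowledgeGraph(nodes=[Node(id=1, label="Alice", color="blue")], edges=[])
b = KnowledgeGraph(
    nodes=[Node(id=1, label="Alice Smith", color="blue"), Node(id=2, label="Bob", color="red")],
    edges=[Edge(source=1, target=2, label="knows", color="black")],
)
merged = a.update(b)
# Node id 1 is kept only once (the version from `b` wins), so the merged graph has 2 nodes and 1 edge.
assert len(merged.nodes) == 2 and len(merged.edges) == 1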
@@ -6,8 +6,6 @@ import numpy as np
from ell.stores.sql import SQLiteStore

import openai

@ell.simple(model="gpt-4o-mini")
def come_up_with_a_premise_for_a_joke_about(topic: str):
    """You are an incredibly funny comedian. Come up with a premise for a joke about topic"""
@@ -11,6 +11,7 @@ from ell.types.message import system, user, assistant, Message, ContentBlock
from ell.__version__ import __version__

# Import all models
+import ell.providers
import ell.models
@@ -1,11 +1,12 @@
from functools import wraps
-from typing import Dict, Any, Optional, Union
+from typing import Dict, Any, Optional, Tuple, Union, Type
import openai
import logging
from contextlib import contextmanager
import threading
from pydantic import BaseModel, ConfigDict, Field
from ell.store import Store
+from ell.provider import Provider

_config_logger = logging.getLogger(__name__)
class Config(BaseModel):
@@ -20,13 +21,14 @@ class Config(BaseModel):
    default_lm_params: Dict[str, Any] = Field(default_factory=dict, description="Default parameters for language models.")
    default_system_prompt: str = Field(default="You are a helpful AI assistant.", description="The default system prompt used for AI interactions.")
    default_client: Optional[openai.Client] = Field(default=None, description="The default OpenAI client used when a specific model client is not found.")
+    providers: Dict[Type, Type[Provider]] = Field(default_factory=dict, description="A dictionary mapping client types to provider classes.")

    def __init__(self, **data):
        super().__init__(**data)
        self._lock = threading.Lock()
        self._local = threading.local()

-    def register_model(self, model_name: str, client: openai.Client) -> None:
+    def register_model(self, model_name: str, client: Any) -> None:
        """
        Register an OpenAI client for a specific model name.
@@ -49,7 +51,7 @@ class Config(BaseModel):
        return self.store is not None

    @contextmanager
-    def model_registry_override(self, overrides: Dict[str, openai.Client]):
+    def model_registry_override(self, overrides: Dict[str, Any]):
        """
        Temporarily override the model registry with new client mappings.
@@ -70,7 +72,7 @@ class Config(BaseModel):
        finally:
            self._local.stack.pop()

-    def get_client_for(self, model_name: str) -> Optional[openai.Client]:
+    def get_client_for(self, model_name: str) -> Tuple[Optional[Any], bool]:
        """
        Get the OpenAI client for a specific model name.
@@ -154,6 +156,26 @@ class Config(BaseModel):
        """
        self.default_client = client

+    def register_provider(self, provider_class: Type[Provider]) -> None:
+        """
+        Register a provider class for a specific client type.
+
+        :param provider_class: The provider class to register.
+        :type provider_class: Type[Provider]
+        """
+        with self._lock:
+            self.providers[provider_class.get_client_type()] = provider_class
+
+    def get_provider_for(self, client: Any) -> Optional[Type[Provider]]:
+        """
+        Get the provider class for a specific client instance.
+
+        :param client: The client instance to get the provider for.
+        :type client: Any
+        :return: The provider class for the specified client, or None if not found.
+        :rtype: Optional[Type[Provider]]
+        """
+        return next((provider for client_type, provider in self.providers.items() if isinstance(client, client_type)), None)

# Singleton instance
config = Config()
@@ -218,3 +240,6 @@ def set_default_system_prompt(*args, **kwargs) -> None:
    return config.set_default_system_prompt(*args, **kwargs)

# You can add more helper functions here if needed
+@wraps(config.register_provider)
+def register_provider(*args, **kwargs) -> None:
+    return config.register_provider(*args, **kwargs)
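For orientation, a rough usage sketch of the registry hooks added above; FakeClient is a made-up stand-in for an SDK client object, and the exact fallback semantics are an assumption rather than something verified against this commit:

from ell.configurator import config

class FakeClient:
    api_key = "sk-test"  # ell later checks clients for an api_key attribute

# Map a model name to a client instance (clients are now typed as Any, not openai.Client).
config.register_model("my-model", FakeClient())

# Resolve a client for a model; the bool presumably signals whether the default client was used as a fallback.
client, was_fallback = config.get_client_for("my-model")

# Temporarily swap the mapping within a scope.
with config.model_registry_override({"my-model": FakeClient()}):
    client, _ = config.get_client_for("my-model")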
@@ -9,12 +9,10 @@ from ell.util.api import call
from ell.util.verbosity import compute_color, model_usage_logger_pre


import openai

from functools import wraps
from typing import Any, Dict, Optional, List, Callable, Union

-def complex(model: str, client: Optional[openai.Client] = None, exempt_from_tracking=False, tools: Optional[List[Callable]] = None, post_callback: Optional[Callable] = None, **api_params):
+def complex(model: str, client: Optional[Any] = None, exempt_from_tracking=False, tools: Optional[List[Callable]] = None, post_callback: Optional[Callable] = None, **api_params):
    """
    A sophisticated language model programming decorator for complex LLM interactions.
@@ -228,7 +226,7 @@ def complex(model: str, client: Optional[openai.Client] = None, exempt_from_trac
    def model_call(
        *fn_args,
        _invocation_origin: str = None,
-        client: Optional[openai.Client] = None,
+        client: Optional[Any] = None,
        lm_params: Optional[LMPParams] = {},
        invocation_api_params=False,
        **fn_kwargs,
@@ -1,11 +1,10 @@
from functools import wraps
-from typing import Optional
+from typing import Any, Optional

import openai
from ell.lmp.complex import complex


-def simple(model: str, client: Optional[openai.Client] = None, exempt_from_tracking=False, **api_params):
+def simple(model: str, client: Optional[Any] = None, exempt_from_tracking=False, **api_params):
    """
    The fundamental unit of language model programming in ell.
56
src/ell/provider.py
Normal file
@@ -0,0 +1,56 @@
from abc import ABC, abstractmethod
from collections import defaultdict
from typing import Any, Dict, List, Optional, Tuple, Type, Union
from ell.types import Message, ContentBlock, ToolCall
from ell.types._lstr import _lstr
import json
from dataclasses import dataclass
from ell.types.message import LMP


@dataclass
class APICallResult:
    response: Any
    actual_streaming: bool
    actual_n: int
    final_call_params: Dict[str, Any]


class Provider(ABC):
    """
    Abstract base class for all providers. Providers are API interfaces to language models, not necessarily API providers.
    For example, the OpenAI provider is an API interface to OpenAI's API but also to Ollama and Azure OpenAI.
    """

    @classmethod
    @abstractmethod
    def call_model(
        cls,
        client: Any,
        model: str,
        messages: List[Any],
        api_params: Dict[str, Any],
        tools: Optional[list[LMP]] = None,
    ) -> APICallResult:
        """Make the API call to the language model and return the result along with actual streaming, n values, and final call parameters."""
        pass

    @classmethod
    @abstractmethod
    def process_response(
        cls, call_result: APICallResult, _invocation_origin: str, logger: Optional[Any] = None, tools: Optional[List[LMP]] = None,
    ) -> Tuple[List[Message], Dict[str, Any]]:
        """Process the API response and convert it to ell format."""
        pass

    @classmethod
    @abstractmethod
    def supports_streaming(cls) -> bool:
        """Check if the provider supports streaming."""
        pass

    @classmethod
    @abstractmethod
    def get_client_type(cls) -> Type:
        """Return the type of client this provider supports."""
        pass
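For reference, a minimal sketch of what a concrete subclass of this ABC could look like; EchoClient and EchoProvider are invented for illustration and are not part of this commit:

from typing import Any, Dict, List, Optional, Tuple, Type
from ell.provider import APICallResult, Provider
from ell.types import Message, ContentBlock

class EchoClient:
    api_key = "unused"  # ell checks clients for an api_key attribute before calling

class EchoProvider(Provider):
    @classmethod
    def call_model(cls, client: Any, model: str, messages: List[Any],
                   api_params: Dict[str, Any], tools: Optional[list] = None) -> APICallResult:
        # No network call: echo the text of the last message back as the "response".
        last = messages[-1].content[0].text if messages else ""
        return APICallResult(response=last, actual_streaming=False, actual_n=1,
                             final_call_params={"model": model, **api_params})

    @classmethod
    def process_response(cls, call_result: APICallResult, _invocation_origin: str,
                         logger: Optional[Any] = None, tools: Optional[List] = None) -> Tuple[List[Message], Dict[str, Any]]:
        # Wrap the raw response in ell's Message/ContentBlock types.
        message = Message(role="assistant", content=[ContentBlock(text=call_result.response)])
        return [message], {}

    @classmethod
    def supports_streaming(cls) -> bool:
        return False

    @classmethod
    def get_client_type(cls) -> Type:
        return EchoClient

A provider like this would then be registered through ell.configurator's register_provider(EchoProvider), mirroring what the OpenAI provider below does.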
@@ -0,0 +1,9 @@
import ell.providers.openai
# import ell.providers.anthropic
# import ell.providers.groq
# import ell.providers.mistral
# import ell.providers.cohere
# import ell.providers.gemini
# import ell.providers.elevenlabs
# import ell.providers.replicate
# import ell.providers.huggingface
239
src/ell/providers/openai.py
Normal file
@@ -0,0 +1,239 @@
from abc import ABC, abstractmethod
from collections import defaultdict
from typing import Any, Dict, List, Optional, Tuple, Type, Union
from ell.provider import APICallResult, Provider
from ell.types import Message, ContentBlock, ToolCall
from ell.types._lstr import _lstr
import json
from ell.configurator import config, register_provider
from ell.types.message import LMP
from ell.util.serialization import serialize_image

try:
    import openai

    class OpenAIProvider(Provider):
        @staticmethod
        def content_block_to_openai_format(content_block: ContentBlock) -> Dict[str, Any]:
            if content_block.image:
                base64_image = serialize_image(content_block.image)
                return {
                    "type": "image_url",
                    "image_url": {
                        "url": base64_image
                    }
                }
            elif content_block.text:
                return {
                    "type": "text",
                    "text": content_block.text
                }
            elif content_block.parsed:
                return {
                    "type": "text",
                    "text": content_block.parsed.model_dump_json()
                }
            # Tool calls handled in message_to_openai_format.
            # XXX: Feel free to refactor this.
            else:
                return None

        @staticmethod
        def message_to_openai_format(message: Message) -> Dict[str, Any]:
            openai_message = {
                "role": "tool" if message.tool_results else message.role,
                "content": list(filter(None, [
                    OpenAIProvider.content_block_to_openai_format(c) for c in message.content
                ]))
            }
            if message.tool_calls:
                try:
                    openai_message["tool_calls"] = [
                        {
                            "id": tool_call.tool_call_id,
                            "type": "function",
                            "function": {
                                "name": tool_call.tool.__name__,
                                "arguments": json.dumps(tool_call.params.model_dump())
                            }
                        } for tool_call in message.tool_calls
                    ]
                except TypeError as e:
                    print(f"Error serializing tool calls: {e}. Did you fully type your @ell.tool decorated functions?")
                    raise
                openai_message["content"] = None  # Set content to null when there are tool calls

            if message.tool_results:
                openai_message["tool_call_id"] = message.tool_results[0].tool_call_id
                openai_message["content"] = message.tool_results[0].result[0].text
                assert len(message.tool_results[0].result) == 1, "Tool result should only have one content block"
                assert message.tool_results[0].result[0].type == "text", "Tool result should only have one text content block"
            return openai_message

        @classmethod
        def call_model(
            cls,
            client: Any,
            model: str,
            messages: List[Message],
            api_params: Dict[str, Any],
            tools: Optional[list[LMP]] = None,
        ) -> APICallResult:
            final_call_params = api_params.copy()
            openai_messages = [cls.message_to_openai_format(message) for message in messages]

            actual_n = api_params.get("n", 1)
            final_call_params["model"] = model
            final_call_params["messages"] = openai_messages

            if final_call_params.get("response_format"):
                final_call_params.pop("stream", None)
                final_call_params.pop("stream_options", None)
                response = client.beta.chat.completions.parse(**final_call_params)
            else:
                # Tools not working with structured API
                if tools:
                    final_call_params["tool_choice"] = "auto"
                    final_call_params["tools"] = [
                        {
                            "type": "function",
                            "function": {
                                "name": tool.__name__,
                                "description": tool.__doc__,
                                "parameters": tool.__ell_params_model__.model_json_schema(),
                            },
                        }
                        for tool in tools
                    ]
                    final_call_params.pop("stream", None)
                    final_call_params.pop("stream_options", None)
                else:
                    final_call_params["stream_options"] = {"include_usage": True}
                    final_call_params["stream"] = True

                response = client.chat.completions.create(**final_call_params)

            return APICallResult(
                response=response,
                actual_streaming=isinstance(response, openai.Stream),
                actual_n=actual_n,
                final_call_params=final_call_params,
            )

        @classmethod
        def process_response(
            cls, call_result: APICallResult, _invocation_origin: str, logger: Optional[Any] = None, tools: Optional[List[LMP]] = None,
        ) -> Tuple[List[Message], Dict[str, Any]]:
            choices_progress = defaultdict(list)
            api_params = call_result.final_call_params
            metadata = {}
            # XXX: Remove logger and refactor this API

            if not call_result.actual_streaming:
                response = [call_result.response]
            else:
                response = call_result.response

            for chunk in response:
                if hasattr(chunk, "usage") and chunk.usage:
                    metadata = chunk.to_dict()
                    if call_result.actual_streaming:
                        continue

                for choice in chunk.choices:
                    choices_progress[choice.index].append(choice)
                    if choice.index == 0 and logger:
                        # print(choice, streaming)
                        logger(choice.delta.content if call_result.actual_streaming else
                               choice.message.content or getattr(choice.message, "refusal", ""), is_refusal=getattr(choice.message, "refusal", False) if not call_result.actual_streaming else False)

            tracked_results = []
            for _, choice_deltas in sorted(choices_progress.items(), key=lambda x: x[0]):
                content = []

                if call_result.actual_streaming:
                    text_content = "".join(
                        (choice.delta.content or "" for choice in choice_deltas)
                    )
                    if text_content:
                        content.append(
                            ContentBlock(
                                text=_lstr(
                                    content=text_content, _origin_trace=_invocation_origin
                                )
                            )
                        )
                else:
                    choice = choice_deltas[0].message
                    if choice.refusal:
                        raise ValueError(choice.refusal)
                    if api_params.get("response_format"):
                        content.append(ContentBlock(parsed=choice.parsed))
                    elif choice.content:
                        content.append(
                            ContentBlock(
                                text=_lstr(
                                    content=choice.content, _origin_trace=_invocation_origin
                                )
                            )
                        )

                if not call_result.actual_streaming and hasattr(choice, "tool_calls") and choice.tool_calls:
                    assert tools is not None, "Tools not provided, yet tool calls in response. Did you manually specify a tool spec without using ell.tool?"
                    for tool_call in choice.tool_calls:
                        matching_tool = next(
                            (
                                tool
                                for tool in tools
                                if tool.__name__ == tool_call.function.name
                            ),
                            None,
                        )
                        if matching_tool:
                            params = matching_tool.__ell_params_model__(
                                **json.loads(tool_call.function.arguments)
                            )
                            content.append(
                                ContentBlock(
                                    tool_call=ToolCall(
                                        tool=matching_tool,
                                        tool_call_id=_lstr(
                                            tool_call.id, _origin_trace=_invocation_origin
                                        ),
                                        params=params,
                                    )
                                )
                            )

                tracked_results.append(
                    Message(
                        role=(
                            choice.role
                            if not call_result.actual_streaming
                            else choice_deltas[0].delta.role
                        ),
                        content=content,
                    )
                )

            return tracked_results, metadata

        @classmethod
        def supports_streaming(cls) -> bool:
            return True

        @classmethod
        def get_client_type(cls) -> Type:
            return openai.Client

        @classmethod
        def __init_subclass__(cls, **kwargs):
            super().__init_subclass__(**kwargs)
            cls.register()

    register_provider(OpenAIProvider)
except ImportError:
    pass
@@ -48,7 +48,7 @@ class ContentBlock(BaseModel):
    image: Optional[Union[PILImage.Image, str, np.ndarray]] = Field(default=None)
    audio: Optional[Union[np.ndarray, List[float]]] = Field(default=None)
    tool_call: Optional[ToolCall] = Field(default=None)
-    parsed: Optional[Union[Type[BaseModel], BaseModel]] = Field(default=None)
+    parsed: Optional[BaseModel] = Field(default=None)
    tool_result: Optional[ToolResult] = Field(default=None)

    @model_validator(mode='after')
@@ -174,29 +174,6 @@ class ContentBlock(BaseModel):
        return serialize_image(image)


-    def to_openai_content_block(self):
-        if self.image:
-            base64_image = self.serialize_image(self.image, None)
-            return {
-                "type": "image_url",
-                "image_url": {
-                    "url": base64_image
-                }
-            }
-        elif self.text:
-            return {
-                "type": "text",
-                "text": self.text
-            }
-        elif self.parsed:
-            return {
-                "type": "text",
-                "json": self.parsed.model_dump_json()
-            }
-        else:
-            return None


def coerce_content_list(content: Union[str, List[ContentBlock], List[Union[str, ContentBlock, ToolCall, ToolResult, BaseModel]]] = None, **content_block_kwargs) -> List[ContentBlock]:
    if not content:
        content = [ContentBlock(**content_block_kwargs)]
@@ -309,36 +286,6 @@ class Message(BaseModel):
        content = [c.tool_call.call_and_collect_as_message_block() for c in self.content if c.tool_call]
        return Message(role="user", content=content)

-    def to_openai_message(self) -> Dict[str, Any]:
-
-        message = {
-            "role": "tool" if self.tool_results else self.role,
-            "content": list(filter(None, [
-                c.to_openai_content_block() for c in self.content
-            ]))
-        }
-        if self.tool_calls:
-            message["tool_calls"] = [
-                {
-                    "id": tool_call.tool_call_id,
-                    "type": "function",
-                    "function": {
-                        "name": tool_call.tool.__name__,
-                        "arguments": json.dumps(tool_call.params.model_dump())
-                    }
-                } for tool_call in self.tool_calls
-            ]
-            message["content"] = None  # Set content to null when there are tool calls
-
-        if self.tool_results:
-            message["tool_call_id"] = self.tool_results[0].tool_call_id
-            # message["name"] = self.tool_results[0].tool_call_id.split('-')[0]  # Assuming the tool name is the first part of the tool_call_id
-            message["content"] = self.tool_results[0].result[0].text
-            # Let's assert no other type of content block in the tool result
-            assert len(self.tool_results[0].result) == 1, "Tool result should only have one content block"
-            assert self.tool_results[0].result[0].type == "text", "Tool result should only have one text content block"
-        return message

# HELPERS
def system(content: Union[str, List[ContentBlock]]) -> Message:
    """
@@ -139,7 +139,7 @@ class InvocationContentsBase(SQLModel):
        ]

        total_size = sum(
-            len(json.dumps(field, default=(lambda x: x.model_dump_json() if isinstance(x, BaseModel) else str(x))).encode('utf-8')) for field in json_fields if field is not None
+            len(json.dumps(field, default=(lambda x: json.dumps(x.model_dump(), default=str) if isinstance(x, BaseModel) else str(x))).encode('utf-8')) for field in json_fields if field is not None
        )
        # print("total_size", total_size)
@@ -17,16 +17,16 @@ To fix this:
import ell
import openai

-ell.lm(model, client=openai.Client(api_key=my_key))
+ell.simple(model, client=openai.Client(api_key=my_key))
def {name}(...):
    ...
```
* Or explicitly specify the client when calling the LMP:

```
-ell.lm(model, client=openai.Client(api_key=my_key))(...)
+ell.simple(model, client=openai.Client(api_key=my_key))(...)
```
-""" if long else " at time of definition. Can be okay if custom client specified later! <TODO: add link to docs> ") + f"{Style.RESET_ALL}"
+""" if long else " at time of definition. Can be okay if custom client specified later! https://docs.ell.so/core_concepts/models_and_api_clients.html ") + f"{Style.RESET_ALL}"


def _warnings(model, fn, default_client_from_decorator):
@@ -40,14 +40,14 @@ def _warnings(model, fn, default_client_from_decorator):
* If this is a mistake either specify a client explicitly in the decorator:
```python
import ell
-ell.lm(model, client=my_client)
+ell.simple(model, client=my_client)
def {fn.__name__}(...):
    ...
```
or explicitly specify the client when calling the LMP:

```python
-ell.lm(model, client=my_client)(...)
+ell.simple(model, client=my_client)(...)
```
{Style.RESET_ALL}""")
    elif (client_to_use := config.registry[model]) is None or not client_to_use.api_key:
@@ -1,170 +1,63 @@
from functools import partial
import json

# import anthropic
from ell.configurator import config
import openai

from collections import defaultdict
from ell.types._lstr import _lstr
from ell.types import Message, ContentBlock, ToolCall


-from typing import Any, Dict, Iterable, Optional, Tuple, Union
+from typing import Any, Dict, Iterable, List, Optional, Tuple, Union, Type
from ell.types.message import LMP, LMPParams, MessageOrDict

from ell.util.verbosity import model_usage_logger_post_end, model_usage_logger_post_intermediate, model_usage_logger_post_start
from ell.util._warnings import _no_api_key_warning
+from ell.provider import APICallResult, Provider

import logging
logger = logging.getLogger(__name__)


def process_messages_for_client(messages: list[Message], client: Any):
    if isinstance(client, openai.Client):
        return [
            message.to_openai_message()
            for message in messages]
    # elif isinstance(client, anthropic.Anthropic):
    #     return messages
    # XXX: or some such.


def call(
    *,
    model: str,
    messages: list[Message],
    api_params: Dict[str, Any],
    tools: Optional[list[LMP]] = None,
-    client: Optional[openai.Client] = None,
+    client: Optional[Any] = None,
    _invocation_origin: str,
    _exempt_from_tracking: bool,
-    _logging_color=None,
-    _name: str = None,
-) -> Tuple[Union[_lstr, Iterable[_lstr]], Optional[Dict[str, Any]]]:
+    _logging_color: Optional[str] = None,
+    _name: Optional[str] = None,
+) -> Tuple[Union[Message, List[Message]], Dict[str, Any], Dict[str, Any]]:
    """
    Helper function to run the language model with the provided messages and parameters.
    """
    # TODO: Decide whether the client specified via the context manager's default registry takes precedence, or the client specified via the LMP invocation args.
    if not client:
        client, was_fallback = config.get_client_for(model)
        if not client and not was_fallback:
            # The model was registered with a None client and there is no fallback to use.
            raise RuntimeError(_no_api_key_warning(model, _name, '', long=True, error=True))

    metadata = dict()
    if client is None:
        raise ValueError(f"No client found for model '{model}'. Ensure the model is registered using 'register_model' in 'config.py' or specify a client directly using the 'client' argument in the decorator or function call.")

    if not client.api_key:
        raise RuntimeError(_no_api_key_warning(model, _name, client, long=True, error=True))

    # todo: add support for streaming APIs that don't give a final usage in the API
    # print(api_params)
    if api_params.get("response_format", False):
        model_call = client.beta.chat.completions.parse
        api_params.pop("stream", None)
        api_params.pop("stream_options", None)
    elif tools:
        model_call = client.chat.completions.create
        api_params["tools"] = [
            {
                "type": "function",
                "function": {
                    "name": tool.__name__,
                    "description": tool.__doc__,
                    "parameters": tool.__ell_params_model__.model_json_schema()
                }
            } for tool in tools
        ]
        api_params["tool_choice"] = "auto"
        api_params.pop("stream", None)
        api_params.pop("stream_options", None)
    else:
        model_call = client.chat.completions.create
        api_params["stream"] = True
        api_params["stream_options"] = {"include_usage": True}
+    provider_class: Type[Provider] = config.get_provider_for(client)

    client_safe_messages_messages = process_messages_for_client(messages, client)
    # print(api_params)
    model_result = model_call(
        model=model, messages=client_safe_messages_messages, **api_params
    )
    streaming = api_params.get("stream", False)
    if not streaming:
        model_result = [model_result]

    choices_progress = defaultdict(list)
    n = api_params.get("n", 1)
    # XXX: Could actually delete this
+    call_result = provider_class.call_model(client, model, messages, api_params, tools)

    if config.verbose and not _exempt_from_tracking:
-        model_usage_logger_post_start(_logging_color, n)
+        model_usage_logger_post_start(_logging_color, call_result.actual_n)

    with model_usage_logger_post_intermediate(_logging_color, n) as _logger:
        for chunk in model_result:
            if hasattr(chunk, "usage") and chunk.usage:
                # Todo: is this a good decision.
                metadata = chunk.to_dict()
+    with model_usage_logger_post_intermediate(_logging_color, call_result.actual_n) as _logger:

            if streaming:
                continue
+        tracked_results, metadata = provider_class.process_response(call_result, _invocation_origin, _logger if config.verbose and not _exempt_from_tracking else None, tools)

            for choice in chunk.choices:
                choices_progress[choice.index].append(choice)
                if config.verbose and choice.index == 0 and not _exempt_from_tracking:
                    # print(choice, streaming)
                    _logger(choice.delta.content if streaming else
                        choice.message.content or getattr(choice.message, "refusal", ""), is_refusal=getattr(choice.message, "refusal", False) if not streaming else False)

    if config.verbose and not _exempt_from_tracking:
        model_usage_logger_post_end()
    n_choices = len(choices_progress)

    # coerce the streaming into a final message type
    tracked_results = []
    for _, choice_deltas in sorted(choices_progress.items(), key=lambda x: x[0]):
        content = []

        # Handle text content
        if streaming:
            text_content = "".join((choice.delta.content or "" for choice in choice_deltas))
            if text_content:
                content.append(ContentBlock(
                    text=_lstr(content=text_content, _origin_trace=_invocation_origin)
                ))
        else:
            choice = choice_deltas[0].message
            if choice.refusal:
                raise ValueError(choice.refusal)
            # XXX: is this the best practice? try catch a parser?
            if api_params.get("response_format", False):
                content.append(ContentBlock(
                    parsed=choice.parsed
                ))
            elif choice.content:
                content.append(ContentBlock(
                    text=_lstr(content=choice.content, _origin_trace=_invocation_origin)
                ))

        # Handle tool calls
        if not streaming and hasattr(choice, 'tool_calls'):
            for tool_call in choice.tool_calls or []:
                matching_tool = None
                for tool in tools:
                    if tool.__name__ == tool_call.function.name:
                        matching_tool = tool
                        break

                if matching_tool:
                    params = matching_tool.__ell_params_model__(**json.loads(tool_call.function.arguments))
                    content.append(ContentBlock(
                        tool_call=ToolCall(tool=matching_tool, tool_call_id=_lstr(tool_call.id, _origin_trace=_invocation_origin), params=params)
                    ))

        tracked_results.append(Message(
            role=choice.role if not streaming else choice_deltas[0].delta.role,
            content=content
        ))

    api_params = dict(model=model, messages=client_safe_messages_messages, api_params=api_params)

-    return tracked_results[0] if n_choices == 1 else tracked_results, api_params, metadata
+    return (tracked_results[0] if len(tracked_results) == 1 else tracked_results), call_result.final_call_params, metadata
@@ -43,7 +43,7 @@ def check_version_and_log():
    import ell

    try:
-        response = requests.get("https://docs.ell.so/_static/ell_version.txt", timeout=0.5)
+        response = requests.get("https://docs.ell.so/_static/ell_version.txt", timeout=0.1)
        if response.status_code == 200:
            latest_version = response.text.strip()
            if latest_version != ell.__version__:
129
tests/test_openai_provider.py
Normal file
@@ -0,0 +1,129 @@
import pytest
from unittest.mock import Mock, patch
from ell.provider import APICallResult
from ell.providers.openai import OpenAIProvider
from ell.types import Message, ContentBlock, ToolCall
from ell.types.message import LMP, ToolResult
from pydantic import BaseModel
import json
import ell

class DummyParams(BaseModel):
    param1: str
    param2: int

@pytest.fixture
def mock_openai_client():
    return Mock()

import openai

def test_content_block_to_openai_format():
    # Test text content
    text_block = ContentBlock(text="Hello, world!")
    assert OpenAIProvider.content_block_to_openai_format(text_block) == {
        "type": "text",
        "text": "Hello, world!"
    }

    # Test parsed content
    class DummyParsed(BaseModel):
        field: str
    parsed_block = ContentBlock(parsed=DummyParsed(field="value"))

    res = OpenAIProvider.content_block_to_openai_format(parsed_block)
    assert res["type"] == "text"
    assert res["text"] == '{"field":"value"}'

    # Test image content (mocked)
    with patch('ell.providers.openai.serialize_image', return_value="base64_image_data"):
        # Test random image content
        import numpy as np
        from PIL import Image

        # Generate a random image
        random_image = np.random.randint(0, 256, (100, 100, 3), dtype=np.uint8)
        pil_image = Image.fromarray(random_image)

        with patch('ell.providers.openai.serialize_image', return_value="random_base64_image_data"):
            random_image_block = ContentBlock(image=pil_image)
            assert OpenAIProvider.content_block_to_openai_format(random_image_block) == {
                "type": "image_url",
                "image_url": {
                    "url": "random_base64_image_data"
                }
            }


def test_message_to_openai_format():
    # Test simple message
    simple_message = Message(role="user", content=[ContentBlock(text="Hello")])
    assert OpenAIProvider.message_to_openai_format(simple_message) == {
        "role": "user",
        "content": [{"type": "text", "text": "Hello"}]
    }

    # Test message with tool calls
    def dummy_tool(param1: str, param2: int): pass
    tool_call = ToolCall(tool=dummy_tool, tool_call_id="123", params=DummyParams(param1="test", param2=42))
    tool_message = Message(role="assistant", content=[tool_call])
    formatted = OpenAIProvider.message_to_openai_format(tool_message)
    assert formatted["role"] == "assistant"
    assert formatted["content"] is None
    assert len(formatted["tool_calls"]) == 1
    assert formatted["tool_calls"][0]["id"] == "123"
    assert formatted["tool_calls"][0]["function"]["name"] == "dummy_tool"
    assert json.loads(formatted["tool_calls"][0]["function"]["arguments"]) == {"param1": "test", "param2": 42}

    # Test message with tool results
    tool_result_message = Message(
        role="user",
        content=[ToolResult(tool_call_id="123", result=[ContentBlock(text="Tool output")])],
    )
    formatted = OpenAIProvider.message_to_openai_format(tool_result_message)
    assert formatted["role"] == "tool"
    assert formatted["tool_call_id"] == "123"
    assert formatted["content"] == "Tool output"

def test_call_model(mock_openai_client):
    messages = [Message(role="user", content=[ContentBlock(text="Hello")], refusal=None)]
    api_params = {"temperature": 0.7}

    # Mock the client's chat.completions.create method
    mock_openai_client.chat.completions.create.return_value = Mock(choices=[Mock(message=Mock(content="Response", refusal=None))])

    @ell.tool()
    def dummy_tool(param1: str, param2: int): pass

    result = OpenAIProvider.call_model(mock_openai_client, "gpt-3.5-turbo", messages, api_params, tools=[dummy_tool])

    assert isinstance(result, APICallResult)
    assert "stream" not in result.final_call_params
    assert not result.actual_streaming
    assert result.actual_n == 1
    assert "messages" in result.final_call_params
    assert result.final_call_params["model"] == "gpt-3.5-turbo"


def test_process_response():
    # Mock APICallResult
    mock_response = Mock(
        choices=[Mock(message=Mock(role="assistant", content="Hello, world!", refusal=None, tool_calls=None))]
    )
    call_result = APICallResult(
        response=mock_response,
        actual_streaming=False,
        actual_n=1,
        final_call_params={}
    )

    processed_messages, metadata = OpenAIProvider.process_response(call_result, "test_origin")

    assert len(processed_messages) == 1
    assert processed_messages[0].role == "assistant"
    assert len(processed_messages[0].content) == 1
    assert processed_messages[0].content[0].text == "Hello, world!"

def test_supports_streaming():
    assert OpenAIProvider.supports_streaming() is True

# Add more tests as needed for other methods and edge cases