From 16c6a814c89849f7ce1fd23415eb996b96ba76ac Mon Sep 17 00:00:00 2001 From: William Guss Date: Wed, 11 Sep 2024 13:17:49 -0700 Subject: [PATCH] shitty provider implementation to work from --- examples/future/knowledge_graph.py | 153 ++++++++++++++++++ examples/joke.py | 2 - src/ell/__init__.py | 1 + src/ell/configurator.py | 33 +++- src/ell/lmp/complex.py | 6 +- src/ell/lmp/simple.py | 5 +- src/ell/provider.py | 56 +++++++ src/ell/providers/__init__.py | 9 ++ src/ell/providers/openai.py | 239 +++++++++++++++++++++++++++++ src/ell/types/message.py | 55 +------ src/ell/types/studio.py | 2 +- src/ell/util/_warnings.py | 10 +- src/ell/util/api.py | 147 +++--------------- src/ell/util/verbosity.py | 2 +- tests/test_openai_provider.py | 129 ++++++++++++++++ 15 files changed, 648 insertions(+), 201 deletions(-) create mode 100644 examples/future/knowledge_graph.py create mode 100644 src/ell/provider.py create mode 100644 src/ell/providers/openai.py create mode 100644 tests/test_openai_provider.py diff --git a/examples/future/knowledge_graph.py b/examples/future/knowledge_graph.py new file mode 100644 index 0000000..a22e112 --- /dev/null +++ b/examples/future/knowledge_graph.py @@ -0,0 +1,153 @@ +"""" +Example originally from Instructor docs https://python.useinstructor.com/examples/knowledge_graph/ +All rights reserved to the original author. +""" + +from graphviz import Digraph +from pydantic import BaseModel, Field +from typing import List, Optional + + +class Node(BaseModel): + id: int + label: str + color: str + + +class Edge(BaseModel): + source: int + target: int + label: str + color: str = Field(description="The color of the edge. Defaults to black.") + + +class KnowledgeGraph(BaseModel): + nodes: Optional[List[Node]] = Field(..., default_factory=list) + edges: Optional[List[Edge]] = Field(..., default_factory=list) + + def update(self, other: "KnowledgeGraph") -> "KnowledgeGraph": + """Updates the current graph with the other graph, deduplicating nodes and edges.""" + # Create dictionaries to store unique nodes and edges + unique_nodes = {node.id: node for node in self.nodes} + unique_edges = {(edge.source, edge.target, edge.label): edge for edge in self.edges} + + # Update with nodes and edges from the other graph + for node in other.nodes: + unique_nodes[node.id] = node + for edge in other.edges: + unique_edges[(edge.source, edge.target, edge.label)] = edge + + return KnowledgeGraph( + nodes=list(unique_nodes.values()), + edges=list(unique_edges.values()), + ) + + def draw(self, prefix: str = None): + dot = Digraph(comment="Knowledge Graph") + + for node in self.nodes: + dot.node(str(node.id), node.label, color=node.color) + + for edge in self.edges: + dot.edge( + str(edge.source), str(edge.target), label=edge.label, color=edge.color + ) + dot.render(prefix, format="png", view=True) + + +import ell + +@ell.complex(model="gpt-4o-2024-08-06", response_format=KnowledgeGraph) +def update_knowledge_graph(cur_state: KnowledgeGraph, inp: str, i: int, num_iterations: int): + return [ + ell.system("""You are an iterative knowledge graph builder. 
+        You are given the current state of the graph, and you must append the nodes and edges
+        to it. Do not provide any duplicates and try to reuse nodes as much as possible."""),
+        ell.user(f"""Extract any new nodes and edges from the following:
+        # Part {i}/{num_iterations} of the input:
+
+        {inp}"""),
+        ell.user(f"""Here is the current state of the graph:
+        {cur_state.model_dump_json(indent=2)}""")
+    ]
+
+def generate_graph(input: List[str]) -> KnowledgeGraph:
+    cur_state = KnowledgeGraph()
+    num_iterations = len(input)
+    for i, inp in enumerate(input):
+        new_updates = update_knowledge_graph(cur_state, inp, i, num_iterations).parsed[0]
+        cur_state = cur_state.update(new_updates)
+        cur_state.draw(prefix=f"iteration_{i}")
+    return cur_state
+
+
+
+if __name__ == "__main__":
+    ell.init(verbose=True, store='./logdir', autocommit=True)
+    generate_graph(["This is a test", "This is another test", "This is a third test"])
+
+# Compare to: Original instructor example.
+# def generate_graph(input: List[str]) -> KnowledgeGraph:
+#     cur_state = KnowledgeGraph()
+#     num_iterations = len(input)
+#     for i, inp in enumerate(input):
+#         new_updates = client.chat.completions.create(
+#             model="gpt-3.5-turbo-16k",
+#             messages=[
+#                 {
+#                     "role": "system",
+#                     "content": """You are an iterative knowledge graph builder.
+#                     You are given the current state of the graph, and you must append the nodes and edges
+#                     to it Do not procide any duplcates and try to reuse nodes as much as possible.""",
+#                 },
+#                 {
+#                     "role": "user",
+#                     "content": f"""Extract any new nodes and edges from the following:
+#                     # Part {i}/{num_iterations} of the input:
+
+#                     {inp}""",
+#                 },
+#                 {
+#                     "role": "user",
+#                     "content": f"""Here is the current state of the graph:
+#                     {cur_state.model_dump_json(indent=2)}""",
+#                 },
+#             ],
+#             response_model=KnowledgeGraph,
+#         )  # type: ignore
+
+#         # Update the current state
+#         cur_state = cur_state.update(new_updates)
+#         cur_state.draw(prefix=f"iteration_{i}")
+#     return cur_state
+
+
+
+# Bonus: Generate with a chat history.
+# XXX: This illustrates the need for a dedicated chat type in ell.
+# @ell.complex(model="gpt-4o-2024-08-06", response_format=KnowledgeGraph)
+# def update_knowledge_graph_with_chat_history(cur_state: KnowledgeGraph, inp: str, i: int, num_iterations: int, chat_history):
+#     return [
+#         ell.system("""You are an iterative knowledge graph builder. 
+# You are given the current state of the graph, and you must append the nodes and edges +# to it Do not procide any duplcates and try to reuse nodes as much as possible."""), +# *chat_history, +# ell.user(f"""Extract any new nodes and edges from the following: +# # Part {i}/{num_iterations} of the input: + +# {inp}"""), +# ell.user(f"""Here is the current state of the graph: +# {cur_state.model_dump_json(indent=2)}""") +# ] + +# def generate_graph_with_chat_history(input: List[str]) -> KnowledgeGraph: +# chat_history = [] +# cur_state = KnowledgeGraph() +# num_iterations = len(input) +# for i, inp in enumerate(input): +# new_updates = update_knowledge_graph_with_chat_history(cur_state, inp, i, num_iterations, chat_history) +# cur_state = cur_state.update(new_updates.parsed[0]) +# chat_history.append(new_updates) +# cur_state.draw(prefix=f"iteration_{i}") +# return cur_state \ No newline at end of file diff --git a/examples/joke.py b/examples/joke.py index 57fb7f1..303ebe3 100644 --- a/examples/joke.py +++ b/examples/joke.py @@ -6,8 +6,6 @@ import numpy as np from ell.stores.sql import SQLiteStore -import openai - @ell.simple(model="gpt-4o-mini") def come_up_with_a_premise_for_a_joke_about(topic : str): """You are an incredibly funny comedian. Come up with a premise for a joke about topic""" diff --git a/src/ell/__init__.py b/src/ell/__init__.py index d3549cb..e7bd022 100644 --- a/src/ell/__init__.py +++ b/src/ell/__init__.py @@ -11,6 +11,7 @@ from ell.types.message import system, user, assistant, Message, ContentBlock from ell.__version__ import __version__ # Import all models +import ell.providers import ell.models diff --git a/src/ell/configurator.py b/src/ell/configurator.py index a247379..6e52be7 100644 --- a/src/ell/configurator.py +++ b/src/ell/configurator.py @@ -1,11 +1,12 @@ from functools import wraps -from typing import Dict, Any, Optional, Union +from typing import Dict, Any, Optional, Tuple, Union, Type import openai import logging from contextlib import contextmanager import threading from pydantic import BaseModel, ConfigDict, Field from ell.store import Store +from ell.provider import Provider _config_logger = logging.getLogger(__name__) class Config(BaseModel): @@ -20,13 +21,14 @@ class Config(BaseModel): default_lm_params: Dict[str, Any] = Field(default_factory=dict, description="Default parameters for language models.") default_system_prompt: str = Field(default="You are a helpful AI assistant.", description="The default system prompt used for AI interactions.") default_client: Optional[openai.Client] = Field(default=None, description="The default OpenAI client used when a specific model client is not found.") + providers: Dict[Type, Type[Provider]] = Field(default_factory=dict, description="A dictionary mapping client types to provider classes.") def __init__(self, **data): super().__init__(**data) self._lock = threading.Lock() self._local = threading.local() - def register_model(self, model_name: str, client: openai.Client) -> None: + def register_model(self, model_name: str, client: Any) -> None: """ Register an OpenAI client for a specific model name. @@ -49,7 +51,7 @@ class Config(BaseModel): return self.store is not None @contextmanager - def model_registry_override(self, overrides: Dict[str, openai.Client]): + def model_registry_override(self, overrides: Dict[str, Any]): """ Temporarily override the model registry with new client mappings. 
@@ -70,7 +72,7 @@ class Config(BaseModel): finally: self._local.stack.pop() - def get_client_for(self, model_name: str) -> Optional[openai.Client]: + def get_client_for(self, model_name: str) -> Tuple[Optional[Any], bool]: """ Get the OpenAI client for a specific model name. @@ -154,6 +156,26 @@ class Config(BaseModel): """ self.default_client = client + def register_provider(self, provider_class: Type[Provider]) -> None: + """ + Register a provider class for a specific client type. + + :param provider_class: The provider class to register. + :type provider_class: Type[AbstractProvider] + """ + with self._lock: + self.providers[provider_class.get_client_type()] = provider_class + + def get_provider_for(self, client: Any) -> Optional[Type[Provider]]: + """ + Get the provider class for a specific client instance. + + :param client: The client instance to get the provider for. + :type client: Any + :return: The provider class for the specified client, or None if not found. + :rtype: Optional[Type[AbstractProvider]] + """ + return next((provider for client_type, provider in self.providers.items() if isinstance(client, client_type)), None) # Singleton instance config = Config() @@ -218,3 +240,6 @@ def set_default_system_prompt(*args, **kwargs) -> None: return config.set_default_system_prompt(*args, **kwargs) # You can add more helper functions here if needed +@wraps(config.register_provider) +def register_provider(*args, **kwargs) -> None: + return config.register_provider(*args, **kwargs) \ No newline at end of file diff --git a/src/ell/lmp/complex.py b/src/ell/lmp/complex.py index bca13d6..cc0088f 100644 --- a/src/ell/lmp/complex.py +++ b/src/ell/lmp/complex.py @@ -9,12 +9,10 @@ from ell.util.api import call from ell.util.verbosity import compute_color, model_usage_logger_pre -import openai - from functools import wraps from typing import Any, Dict, Optional, List, Callable, Union -def complex(model: str, client: Optional[openai.Client] = None, exempt_from_tracking=False, tools: Optional[List[Callable]] = None, post_callback: Optional[Callable] = None, **api_params): +def complex(model: str, client: Optional[Any] = None, exempt_from_tracking=False, tools: Optional[List[Callable]] = None, post_callback: Optional[Callable] = None, **api_params): """ A sophisticated language model programming decorator for complex LLM interactions. @@ -228,7 +226,7 @@ def complex(model: str, client: Optional[openai.Client] = None, exempt_from_trac def model_call( *fn_args, _invocation_origin : str = None, - client: Optional[openai.Client] = None, + client: Optional[Any] = None, lm_params: Optional[LMPParams] = {}, invocation_api_params=False, **fn_kwargs, diff --git a/src/ell/lmp/simple.py b/src/ell/lmp/simple.py index 2293669..a57566f 100644 --- a/src/ell/lmp/simple.py +++ b/src/ell/lmp/simple.py @@ -1,11 +1,10 @@ from functools import wraps -from typing import Optional +from typing import Any, Optional -import openai from ell.lmp.complex import complex -def simple(model: str, client: Optional[openai.Client] = None, exempt_from_tracking=False, **api_params): +def simple(model: str, client: Optional[Any] = None, exempt_from_tracking=False, **api_params): """ The fundamental unit of language model programming in ell. 
diff --git a/src/ell/provider.py b/src/ell/provider.py new file mode 100644 index 0000000..9350be2 --- /dev/null +++ b/src/ell/provider.py @@ -0,0 +1,56 @@ +from abc import ABC, abstractmethod +from collections import defaultdict +from typing import Any, Dict, List, Optional, Tuple, Type, Union +from ell.types import Message, ContentBlock, ToolCall +from ell.types._lstr import _lstr +import json +from dataclasses import dataclass +from ell.types.message import LMP + + +@dataclass +class APICallResult: + response: Any + actual_streaming: bool + actual_n: int + final_call_params: Dict[str, Any] + + +class Provider(ABC): + """ + Abstract base class for all providers. Providers are API interfaces to language models, not necessarily API providers. + For example, the OpenAI provider is an API interface to OpenAI's API but also to Ollama and Azure OpenAI. + """ + + @classmethod + @abstractmethod + def call_model( + cls, + client: Any, + model: str, + messages: List[Any], + api_params: Dict[str, Any], + tools: Optional[list[LMP]] = None, + ) -> APICallResult: + """Make the API call to the language model and return the result along with actual streaming, n values, and final call parameters.""" + pass + + @classmethod + @abstractmethod + def process_response( + cls, call_result: APICallResult, _invocation_origin: str, logger: Optional[Any] = None, tools: Optional[List[LMP]] = None, + ) -> Tuple[List[Message], Dict[str, Any]]: + """Process the API response and convert it to ell format.""" + pass + + @classmethod + @abstractmethod + def supports_streaming(cls) -> bool: + """Check if the provider supports streaming.""" + pass + + @classmethod + @abstractmethod + def get_client_type(cls) -> Type: + """Return the type of client this provider supports.""" + pass diff --git a/src/ell/providers/__init__.py b/src/ell/providers/__init__.py index e69de29..90ec14c 100644 --- a/src/ell/providers/__init__.py +++ b/src/ell/providers/__init__.py @@ -0,0 +1,9 @@ +import ell.providers.openai +# import ell.providers.anthropic +# import ell.providers.groq +# import ell.providers.mistral +# import ell.providers.cohere +# import ell.providers.gemini +# import ell.providers.elevenlabs +# import ell.providers.replicate +# import ell.providers.huggingface \ No newline at end of file diff --git a/src/ell/providers/openai.py b/src/ell/providers/openai.py new file mode 100644 index 0000000..2fd96c7 --- /dev/null +++ b/src/ell/providers/openai.py @@ -0,0 +1,239 @@ +from abc import ABC, abstractmethod +from collections import defaultdict +from typing import Any, Dict, List, Optional, Tuple, Type, Union +from ell.provider import APICallResult, Provider +from ell.types import Message, ContentBlock, ToolCall +from ell.types._lstr import _lstr +import json +from ell.configurator import config, register_provider +from ell.types.message import LMP +from ell.util.serialization import serialize_image + +try: + import openai + class OpenAIProvider(Provider): + @staticmethod + def content_block_to_openai_format(content_block: ContentBlock) -> Dict[str, Any]: + if content_block.image: + base64_image = serialize_image(content_block.image) + return { + "type": "image_url", + "image_url": { + "url": base64_image + } + } + elif content_block.text: + return { + "type": "text", + "text": content_block.text + } + elif content_block.parsed: + return { + "type": "text", + "text": content_block.parsed.model_dump_json() + } + # Tool calls handled in message_to_openai_format. + #XXX: Feel free to refactor this. 
+ else: + return None + + @staticmethod + def message_to_openai_format(message: Message) -> Dict[str, Any]: + openai_message = { + "role": "tool" if message.tool_results else message.role, + "content": list(filter(None, [ + OpenAIProvider.content_block_to_openai_format(c) for c in message.content + ])) + } + if message.tool_calls: + try: + openai_message["tool_calls"] = [ + { + "id": tool_call.tool_call_id, + "type": "function", + "function": { + "name": tool_call.tool.__name__, + "arguments": json.dumps(tool_call.params.model_dump()) + } + } for tool_call in message.tool_calls + ] + except TypeError as e: + print(f"Error serializing tool calls: {e}. Did you fully type your @ell.tool decorated functions?") + raise + openai_message["content"] = None # Set content to null when there are tool calls + + if message.tool_results: + openai_message["tool_call_id"] = message.tool_results[0].tool_call_id + openai_message["content"] = message.tool_results[0].result[0].text + assert len(message.tool_results[0].result) == 1, "Tool result should only have one content block" + assert message.tool_results[0].result[0].type == "text", "Tool result should only have one text content block" + return openai_message + + @classmethod + def call_model( + cls, + client: Any, + model: str, + messages: List[Message], + api_params: Dict[str, Any], + tools: Optional[list[LMP]] = None, + ) -> APICallResult: + final_call_params = api_params.copy() + openai_messages = [cls.message_to_openai_format(message) for message in messages] + + actual_n = api_params.get("n", 1) + final_call_params["model"] = model + final_call_params["messages"] = openai_messages + + if final_call_params.get("response_format"): + final_call_params.pop("stream", None) + final_call_params.pop("stream_options", None) + response = client.beta.chat.completions.parse(**final_call_params) + else: + # Tools not workign with structured API + if tools: + final_call_params["tool_choice"] = "auto" + final_call_params["tools"] = [ + { + "type": "function", + "function": { + "name": tool.__name__, + "description": tool.__doc__, + "parameters": tool.__ell_params_model__.model_json_schema(), + }, + } + for tool in tools + ] + final_call_params.pop("stream", None) + final_call_params.pop("stream_options", None) + else: + final_call_params["stream_options"] = {"include_usage": True} + final_call_params["stream"] = True + + response = client.chat.completions.create(**final_call_params) + + + return APICallResult( + response=response, + actual_streaming=isinstance(response, openai.Stream), + actual_n=actual_n, + final_call_params=final_call_params, + ) + + @classmethod + def process_response( + cls, call_result: APICallResult, _invocation_origin: str, logger : Optional[Any] = None, tools: Optional[List[LMP]] = None, + ) -> Tuple[List[Message], Dict[str, Any]]: + choices_progress = defaultdict(list) + api_params = call_result.final_call_params + metadata = {} + #XXX: Remove logger and refactor this API + + if not call_result.actual_streaming: + response = [call_result.response] + else: + response = call_result.response + + for chunk in response: + if hasattr(chunk, "usage") and chunk.usage: + metadata = chunk.to_dict() + if call_result.actual_streaming: + continue + + + for choice in chunk.choices: + choices_progress[choice.index].append(choice) + if choice.index == 0 and logger: + # print(choice, streaming) + logger(choice.delta.content if call_result.actual_streaming else + choice.message.content or getattr(choice.message, "refusal", ""), 
is_refusal=getattr(choice.message, "refusal", False) if not call_result.actual_streaming else False) + + + + tracked_results = [] + for _, choice_deltas in sorted(choices_progress.items(), key=lambda x: x[0]): + content = [] + + if call_result.actual_streaming: + text_content = "".join( + (choice.delta.content or "" for choice in choice_deltas) + ) + if text_content: + content.append( + ContentBlock( + text=_lstr( + content=text_content, _origin_trace=_invocation_origin + ) + ) + ) + else: + choice = choice_deltas[0].message + if choice.refusal: + raise ValueError(choice.refusal) + if api_params.get("response_format"): + content.append(ContentBlock(parsed=choice.parsed)) + elif choice.content: + content.append( + ContentBlock( + text=_lstr( + content=choice.content, _origin_trace=_invocation_origin + ) + ) + ) + + if not call_result.actual_streaming and hasattr(choice, "tool_calls") and choice.tool_calls: + assert tools is not None, "Tools not provided, yet tool calls in response. Did you manually specify a tool spec without using ell.tool?" + for tool_call in choice.tool_calls: + matching_tool = next( + ( + tool + for tool in tools + if tool.__name__ == tool_call.function.name + ), + None, + ) + if matching_tool: + params = matching_tool.__ell_params_model__( + **json.loads(tool_call.function.arguments) + ) + content.append( + ContentBlock( + tool_call=ToolCall( + tool=matching_tool, + tool_call_id=_lstr( + tool_call.id, _origin_trace=_invocation_origin + ), + params=params, + ) + ) + ) + + tracked_results.append( + Message( + role=( + choice.role + if not call_result.actual_streaming + else choice_deltas[0].delta.role + ), + content=content, + ) + ) + + return tracked_results, metadata + + @classmethod + def supports_streaming(cls) -> bool: + return True + + @classmethod + def get_client_type(cls) -> Type: + return openai.Client + + @classmethod + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) + cls.register() + + register_provider(OpenAIProvider) +except ImportError: + pass \ No newline at end of file diff --git a/src/ell/types/message.py b/src/ell/types/message.py index 3ad4391..25290e9 100644 --- a/src/ell/types/message.py +++ b/src/ell/types/message.py @@ -48,7 +48,7 @@ class ContentBlock(BaseModel): image: Optional[Union[PILImage.Image, str, np.ndarray]] = Field(default=None) audio: Optional[Union[np.ndarray, List[float]]] = Field(default=None) tool_call: Optional[ToolCall] = Field(default=None) - parsed: Optional[Union[Type[BaseModel], BaseModel]] = Field(default=None) + parsed: Optional[BaseModel] = Field(default=None) tool_result: Optional[ToolResult] = Field(default=None) @model_validator(mode='after') @@ -174,29 +174,6 @@ class ContentBlock(BaseModel): return serialize_image(image) - def to_openai_content_block(self): - if self.image: - base64_image = self.serialize_image(self.image, None) - return { - "type": "image_url", - "image_url": { - "url": base64_image - } - } - elif self.text: - return { - "type": "text", - "text": self.text - } - elif self.parsed: - return { - "type": "text", - "json": self.parsed.model_dump_json() - } - else: - return None - - def coerce_content_list(content: Union[str, List[ContentBlock], List[Union[str, ContentBlock, ToolCall, ToolResult, BaseModel]]] = None, **content_block_kwargs) -> List[ContentBlock]: if not content: content = [ContentBlock(**content_block_kwargs)] @@ -309,36 +286,6 @@ class Message(BaseModel): content = [c.tool_call.call_and_collect_as_message_block() for c in self.content if c.tool_call] return 
Message(role="user", content=content) - def to_openai_message(self) -> Dict[str, Any]: - - message = { - "role": "tool" if self.tool_results else self.role, - "content": list(filter(None, [ - c.to_openai_content_block() for c in self.content - ])) - } - if self.tool_calls: - message["tool_calls"] = [ - { - "id": tool_call.tool_call_id, - "type": "function", - "function": { - "name": tool_call.tool.__name__, - "arguments": json.dumps(tool_call.params.model_dump()) - } - } for tool_call in self.tool_calls - ] - message["content"] = None # Set content to null when there are tool calls - - if self.tool_results: - message["tool_call_id"] = self.tool_results[0].tool_call_id - # message["name"] = self.tool_results[0].tool_call_id.split('-')[0] # Assuming the tool name is the first part of the tool_call_id - message["content"] = self.tool_results[0].result[0].text - # Let';s assert no other type of content block in the tool result - assert len(self.tool_results[0].result) == 1, "Tool result should only have one content block" - assert self.tool_results[0].result[0].type == "text", "Tool result should only have one text content block" - return message - # HELPERS def system(content: Union[str, List[ContentBlock]]) -> Message: """ diff --git a/src/ell/types/studio.py b/src/ell/types/studio.py index 18f45cf..4ca0468 100644 --- a/src/ell/types/studio.py +++ b/src/ell/types/studio.py @@ -139,7 +139,7 @@ class InvocationContentsBase(SQLModel): ] total_size = sum( - len(json.dumps(field, default=(lambda x: x.model_dump_json() if isinstance(x, BaseModel) else str(x))).encode('utf-8')) for field in json_fields if field is not None + len(json.dumps(field, default=(lambda x: json.dumps(x.model_dump(), default=str) if isinstance(x, BaseModel) else str(x))).encode('utf-8')) for field in json_fields if field is not None ) # print("total_size", total_size) diff --git a/src/ell/util/_warnings.py b/src/ell/util/_warnings.py index a84e654..021246c 100644 --- a/src/ell/util/_warnings.py +++ b/src/ell/util/_warnings.py @@ -17,16 +17,16 @@ To fix this: import ell import openai - ell.lm(model, client=openai.Client(api_key=my_key)) + ell.simple(model, client=openai.Client(api_key=my_key)) def {name}(...): ... ``` * Or explicitly specify the client when the calling the LMP: ``` - ell.lm(model, client=openai.Client(api_key=my_key))(...) + ell.simple(model, client=openai.Client(api_key=my_key))(...) ``` -""" if long else " at time of definition. Can be okay if custom client specified later! ") + f"{Style.RESET_ALL}" +""" if long else " at time of definition. Can be okay if custom client specified later! https://docs.ell.so/core_concepts/models_and_api_clients.html ") + f"{Style.RESET_ALL}" def _warnings(model, fn, default_client_from_decorator): @@ -40,14 +40,14 @@ def _warnings(model, fn, default_client_from_decorator): * If this is a mistake either specify a client explicitly in the decorator: ```python import ell -ell.lm(model, client=my_client) +ell.simple(model, client=my_client) def {fn.__name__}(...): ... ``` or explicitly specify the client when the calling the LMP: ```python -ell.lm(model, client=my_client)(...) +ell.simple(model, client=my_client)(...) 
``` {Style.RESET_ALL}""") elif (client_to_use := config.registry[model]) is None or not client_to_use.api_key: diff --git a/src/ell/util/api.py b/src/ell/util/api.py index 971de3a..5a78a4a 100644 --- a/src/ell/util/api.py +++ b/src/ell/util/api.py @@ -1,170 +1,63 @@ from functools import partial -import json -# import anthropic from ell.configurator import config -import openai + from collections import defaultdict from ell.types._lstr import _lstr from ell.types import Message, ContentBlock, ToolCall - -from typing import Any, Dict, Iterable, Optional, Tuple, Union +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union, Type from ell.types.message import LMP, LMPParams, MessageOrDict from ell.util.verbosity import model_usage_logger_post_end, model_usage_logger_post_intermediate, model_usage_logger_post_start from ell.util._warnings import _no_api_key_warning +from ell.provider import APICallResult, Provider import logging logger = logging.getLogger(__name__) - -def process_messages_for_client(messages: list[Message], client: Any): - if isinstance(client, openai.Client): - return [ - message.to_openai_message() - for message in messages] - # elif isinstance(client, anthropic.Anthropic): - # return messages - # XXX: or some such. - - def call( *, model: str, messages: list[Message], api_params: Dict[str, Any], tools: Optional[list[LMP]] = None, - client: Optional[openai.Client] = None, - _invocation_origin : str, + client: Optional[Any] = None, + _invocation_origin: str, _exempt_from_tracking: bool, - _logging_color=None, - _name: str = None, -) -> Tuple[Union[_lstr, Iterable[_lstr]], Optional[Dict[str, Any]]]: + _logging_color: Optional[str] = None, + _name: Optional[str] = None, +) -> Tuple[Union[Message, List[Message]], Dict[str, Any], Dict[str, Any]]: """ Helper function to run the language model with the provided messages and parameters. """ - # Todo: Decide if the client specified via the context amanger default registry is the shit or if the cliennt specified via lmp invocation args are the hing. if not client: client, was_fallback = config.get_client_for(model) if not client and not was_fallback: - # Someone registered you as None and you're trying to use this shit raise RuntimeError(_no_api_key_warning(model, _name, '', long=True, error=True)) - - metadata = dict() + if client is None: raise ValueError(f"No client found for model '{model}'. 
Ensure the model is registered using 'register_model' in 'config.py' or specify a client directly using the 'client' argument in the decorator or function call.") if not client.api_key: raise RuntimeError(_no_api_key_warning(model, _name, client, long=True, error=True)) - # todo: add suupport for streaming apis that dont give a final usage in the api - # print(api_params) - if api_params.get("response_format", False): - model_call = client.beta.chat.completions.parse - api_params.pop("stream", None) - api_params.pop("stream_options", None) - elif tools: - model_call = client.chat.completions.create - api_params["tools"] = [ - { - "type": "function", - "function": { - "name": tool.__name__, - "description": tool.__doc__, - "parameters": tool.__ell_params_model__.model_json_schema() - } - } for tool in tools - ] - api_params["tool_choice"] = "auto" - api_params.pop("stream", None) - api_params.pop("stream_options", None) - else: - model_call = client.chat.completions.create - api_params["stream"] = True - api_params["stream_options"] = {"include_usage": True} + provider_class: Type[Provider] = config.get_provider_for(client) + + + # XXX: Could actually delete htis + call_result = provider_class.call_model(client, model, messages, api_params, tools) - client_safe_messages_messages = process_messages_for_client(messages, client) - # print(api_params) - model_result = model_call( - model=model, messages=client_safe_messages_messages, **api_params - ) - streaming = api_params.get("stream", False) - if not streaming: - model_result = [model_result] - - choices_progress = defaultdict(list) - n = api_params.get("n", 1) - if config.verbose and not _exempt_from_tracking: - model_usage_logger_post_start(_logging_color, n) + model_usage_logger_post_start(_logging_color, call_result.actual_n) - with model_usage_logger_post_intermediate(_logging_color, n) as _logger: - for chunk in model_result: - if hasattr(chunk, "usage") and chunk.usage: - # Todo: is this a good decision. - metadata = chunk.to_dict() - - if streaming: - continue - - for choice in chunk.choices: - choices_progress[choice.index].append(choice) - if config.verbose and choice.index == 0 and not _exempt_from_tracking: - # print(choice, streaming) - _logger(choice.delta.content if streaming else - choice.message.content or getattr(choice.message, "refusal", ""), is_refusal=getattr(choice.message, "refusal", False) if not streaming else False) + with model_usage_logger_post_intermediate(_logging_color, call_result.actual_n) as _logger: + tracked_results, metadata = provider_class.process_response(call_result, _invocation_origin, _logger if config.verbose and not _exempt_from_tracking else None, tools) + + if config.verbose and not _exempt_from_tracking: model_usage_logger_post_end() - n_choices = len(choices_progress) - # coerce the streaming into a final message type - tracked_results = [] - for _, choice_deltas in sorted(choices_progress.items(), key=lambda x: x[0]): - content = [] - - # Handle text content - if streaming: - text_content = "".join((choice.delta.content or "" for choice in choice_deltas)) - if text_content: - content.append(ContentBlock( - text=_lstr(content=text_content, _origin_trace=_invocation_origin) - )) - else: - choice = choice_deltas[0].message - if choice.refusal: - raise ValueError(choice.refusal) - # XXX: is this the best practice? try catch a parser? 
- if api_params.get("response_format", False): - content.append(ContentBlock( - parsed=choice.parsed - )) - elif choice.content: - content.append(ContentBlock( - text=_lstr(content=choice.content, _origin_trace=_invocation_origin) - )) - - # Handle tool calls - if not streaming and hasattr(choice, 'tool_calls'): - for tool_call in choice.tool_calls or []: - matching_tool = None - for tool in tools: - if tool.__name__ == tool_call.function.name: - matching_tool = tool - break - - if matching_tool: - params = matching_tool.__ell_params_model__(**json.loads(tool_call.function.arguments)) - content.append(ContentBlock( - tool_call=ToolCall(tool=matching_tool, tool_call_id=_lstr(tool_call.id, _origin_trace=_invocation_origin), params=params) - )) - - tracked_results.append(Message( - role=choice.role if not streaming else choice_deltas[0].delta.role, - content=content - )) - api_params = dict(model=model, messages=client_safe_messages_messages, api_params=api_params) - - return tracked_results[0] if n_choices == 1 else tracked_results, api_params, metadata \ No newline at end of file + return (tracked_results[0] if len(tracked_results) == 1 else tracked_results), call_result.final_call_params, metadata \ No newline at end of file diff --git a/src/ell/util/verbosity.py b/src/ell/util/verbosity.py index 8141dfa..5ae3e6e 100644 --- a/src/ell/util/verbosity.py +++ b/src/ell/util/verbosity.py @@ -43,7 +43,7 @@ def check_version_and_log(): import ell try: - response = requests.get("https://docs.ell.so/_static/ell_version.txt", timeout=0.5) + response = requests.get("https://docs.ell.so/_static/ell_version.txt", timeout=0.1) if response.status_code == 200: latest_version = response.text.strip() if latest_version != ell.__version__: diff --git a/tests/test_openai_provider.py b/tests/test_openai_provider.py new file mode 100644 index 0000000..e8ff291 --- /dev/null +++ b/tests/test_openai_provider.py @@ -0,0 +1,129 @@ +import pytest +from unittest.mock import Mock, patch +from ell.provider import APICallResult +from ell.providers.openai import OpenAIProvider +from ell.types import Message, ContentBlock, ToolCall +from ell.types.message import LMP, ToolResult +from pydantic import BaseModel +import json +import ell +class DummyParams(BaseModel): + param1: str + param2: int + +@pytest.fixture +def mock_openai_client(): + return Mock() +import openai +def test_content_block_to_openai_format(): + # Test text content + text_block = ContentBlock(text="Hello, world!") + assert OpenAIProvider.content_block_to_openai_format(text_block) == { + "type": "text", + "text": "Hello, world!" 
+ } + + # Test parsed content + class DummyParsed(BaseModel): + field: str + parsed_block = ContentBlock(parsed=DummyParsed(field="value")) + + + res = OpenAIProvider.content_block_to_openai_format(parsed_block) + assert res["type"] == "text" + assert (res["text"]) == '{"field":"value"}' + + + # Test image content (mocked) + with patch('ell.providers.openai.serialize_image', return_value="base64_image_data"): + # Test random image content + import numpy as np + from PIL import Image + + # Generate a random image + random_image = np.random.randint(0, 256, (100, 100, 3), dtype=np.uint8) + pil_image = Image.fromarray(random_image) + + with patch('ell.providers.openai.serialize_image', return_value="random_base64_image_data"): + random_image_block = ContentBlock(image=pil_image) + assert OpenAIProvider.content_block_to_openai_format(random_image_block) == { + "type": "image_url", + "image_url": { + "url": "random_base64_image_data" + } + } + + +def test_message_to_openai_format(): + # Test simple message + simple_message = Message(role="user", content=[ContentBlock(text="Hello")]) + assert OpenAIProvider.message_to_openai_format(simple_message) == { + "role": "user", + "content": [{"type": "text", "text": "Hello"}] + } + + # Test message with tool calls + def dummy_tool(param1: str, param2: int): pass + tool_call = ToolCall(tool=dummy_tool, tool_call_id="123", params=DummyParams(param1="test", param2=42)) + tool_message = Message(role="assistant", content=[tool_call]) + formatted = OpenAIProvider.message_to_openai_format(tool_message) + assert formatted["role"] == "assistant" + assert formatted["content"] is None + assert len(formatted["tool_calls"]) == 1 + assert formatted["tool_calls"][0]["id"] == "123" + assert formatted["tool_calls"][0]["function"]["name"] == "dummy_tool" + assert json.loads(formatted["tool_calls"][0]["function"]["arguments"]) == {"param1": "test", "param2": 42} + + # Test message with tool results + tool_result_message = Message( + role="user", + content=[ToolResult(tool_call_id="123", result=[ContentBlock(text="Tool output")])], + ) + formatted = OpenAIProvider.message_to_openai_format(tool_result_message) + assert formatted["role"] == "tool" + assert formatted["tool_call_id"] == "123" + assert formatted["content"] == "Tool output" + +def test_call_model(mock_openai_client): + messages = [Message(role="user", content=[ContentBlock(text="Hello")], refusal=None)] + api_params = {"temperature": 0.7} + + # Mock the client's chat.completions.create method + mock_openai_client.chat.completions.create.return_value = Mock(choices=[Mock(message=Mock(content="Response", refusal=None))]) + + @ell.tool() + def dummy_tool(param1: str, param2: int): pass + + result = OpenAIProvider.call_model(mock_openai_client, "gpt-3.5-turbo", messages, api_params, tools=[dummy_tool]) + + assert isinstance(result, APICallResult) + assert not "stream" in result.final_call_params + assert not result.actual_streaming + assert result.actual_n == 1 + assert "messages" in result.final_call_params + assert result.final_call_params["model"] == "gpt-3.5-turbo" + + +def test_process_response(): + # Mock APICallResult + mock_response = Mock( + choices=[Mock(message=Mock(role="assistant", content="Hello, world!", refusal=None, tool_calls=None))] + ) + call_result = APICallResult( + response=mock_response, + actual_streaming=False, + actual_n=1, + final_call_params={} + ) + + processed_messages, metadata = OpenAIProvider.process_response(call_result, "test_origin") + + assert len(processed_messages) == 1 + 
assert processed_messages[0].role == "assistant"
+    assert len(processed_messages[0].content) == 1
+    assert processed_messages[0].content[0].text == "Hello, world!"
+
+def test_supports_streaming():
+    assert OpenAIProvider.supports_streaming() is True
+
+# Add more tests as needed for other methods and edge cases
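
Not part of the patch above: a minimal sketch of how a downstream custom provider could plug into the new Provider ABC and register_provider() helper this change introduces. EchoClient and EchoProvider are hypothetical names used purely for illustration; the only APIs assumed are the ones defined in src/ell/provider.py and src/ell/configurator.py in this diff.

from typing import Any, Dict, List, Optional, Tuple, Type

from ell.configurator import register_provider
from ell.provider import APICallResult, Provider
from ell.types import ContentBlock, Message
from ell.types._lstr import _lstr


class EchoClient:
    """Hypothetical stand-in client; a real provider would wrap an SDK client."""
    api_key = "not-needed"  # ell.util.api.call checks for a truthy api_key


class EchoProvider(Provider):
    @classmethod
    def call_model(cls, client: Any, model: str, messages: List[Message],
                   api_params: Dict[str, Any], tools: Optional[list] = None) -> APICallResult:
        # No network call: just echo the text of the last message back.
        last_text = (messages[-1].content[0].text or "") if messages else ""
        return APICallResult(
            response=last_text,
            actual_streaming=False,
            actual_n=1,
            final_call_params={"model": model, **api_params},
        )

    @classmethod
    def process_response(cls, call_result: APICallResult, _invocation_origin: str,
                         logger: Optional[Any] = None, tools: Optional[list] = None
                         ) -> Tuple[List[Message], Dict[str, Any]]:
        # Wrap the echoed text in an lstr so origin tracing works as in the OpenAI provider.
        text = _lstr(content=call_result.response or "", _origin_trace=_invocation_origin)
        return [Message(role="assistant", content=[ContentBlock(text=text)])], {}

    @classmethod
    def supports_streaming(cls) -> bool:
        return False

    @classmethod
    def get_client_type(cls) -> Type:
        return EchoClient


register_provider(EchoProvider)
# config.get_provider_for(EchoClient()) now resolves to EchoProvider via the isinstance lookup.

With an EchoClient instance registered for a model name via config.register_model, calls made through @ell.simple or @ell.complex would be routed to EchoProvider by ell.util.api.call, following the registry flow shown in configurator.py and api.py above.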