o1 support

This commit is contained in:
William Guss
2024-09-12 11:05:45 -07:00
parent 96d7a67437
commit 4c786b50d6
8 changed files with 57 additions and 29 deletions

View File

@@ -10,9 +10,9 @@ def get_random_adjective():
def get_random_punctuation():
return random.choice(["!", "!!", "!!!"])
@ell.simple(model="gpt-3.5-turbo")
@ell.simple(model="o1-preview")
def hello(name: str):
"""You are a helpful and expressive assistant."""
# """You are a helpful and expressive assistant."""
adjective = get_random_adjective()
punctuation = get_random_punctuation()
return f"Say a {adjective} hello to {name}{punctuation}"

36
examples/o1.py Normal file
View File

@@ -0,0 +1,36 @@
import ell
@ell.simple(model="o1-preview")
def solve_complex_math_problem(equation: str, variables: dict, constraints: list, optimization_goal: str):
return f"""You are an expert mathematician and problem solver. Please solve the following complex mathematical problem:
Equation: {equation}
Variables: {variables}
Constraints: {constraints}
Optimization Goal: {optimization_goal}"""
@ell.simple(model="o1-preview")
def write_plot_code_for_problem_and_solution(solution :str):
return f"""You are an expert programmer and problem solver.
Please write code in python with matplotlib to plot the solution to the following problem: It should work in the terminal. Full script with imports.
IMPORTANT: Do not include any other text only the code.
Solution to plot: {solution}"""
def solve_and_plot(**kwargs):
solution = solve_complex_math_problem(**kwargs)
plot_code = write_plot_code_for_problem_and_solution(solution)
# remove backticks and ```python
plot_code = plot_code.replace("```python", "").replace("```", "").strip()
exec(plot_code)
return solution
if __name__ == "__main__":
ell.init(store='./logdir', autocommit=True, verbose=True)
result = solve_and_plot(
equation="y = ax^2 + bx + c",
variables={"a": 1, "b": -5, "c": 6},
constraints=["x >= 0", "x <= 10"],
optimization_goal="Find the minimum value of y within the given constraints"
)
print(result)

View File

@@ -1,6 +1,6 @@
[tool.poetry]
name = "ell-ai"
version = "0.0.4"
version = "0.0.5"
description = "ell - the language model programming library"
authors = ["William Guss <will@lrsys.xyz>"]
license = "MIT"

View File

@@ -19,7 +19,6 @@ class Config(BaseModel):
autocommit: bool = Field(default=False, description="If True, enables automatic committing of changes to the store.")
lazy_versioning: bool = Field(default=True, description="If True, enables lazy versioning for improved performance.")
default_lm_params: Dict[str, Any] = Field(default_factory=dict, description="Default parameters for language models.")
default_system_prompt: str = Field(default="You are a helpful AI assistant.", description="The default system prompt used for AI interactions.")
default_client: Optional[openai.Client] = Field(default=None, description="The default OpenAI client used when a specific model client is not found.")
providers: Dict[Type, Type[Provider]] = Field(default_factory=dict, description="A dictionary mapping client types to provider classes.")
@@ -138,14 +137,7 @@ class Config(BaseModel):
"""
self.default_lm_params = params
def set_default_system_prompt(self, prompt: str) -> None:
"""
Set the default system prompt.
:param prompt: The default system prompt to set.
:type prompt: str
"""
self.default_system_prompt = prompt
def set_default_client(self, client: openai.Client) -> None:
"""
@@ -186,7 +178,6 @@ def init(
autocommit: bool = True,
lazy_versioning: bool = True,
default_lm_params: Optional[Dict[str, Any]] = None,
default_system_prompt: Optional[str] = None,
default_openai_client: Optional[openai.Client] = None
) -> None:
"""
@@ -202,8 +193,6 @@ def init(
:type lazy_versioning: bool
:param default_lm_params: Set default parameters for language models.
:type default_lm_params: Dict[str, Any], optional
:param default_system_prompt: Set the default system prompt.
:type default_system_prompt: str, optional
:param default_openai_client: Set the default OpenAI client.
:type default_openai_client: openai.Client, optional
"""
@@ -216,8 +205,7 @@ def init(
if default_lm_params is not None:
config.set_default_lm_params(**default_lm_params)
if default_system_prompt is not None:
config.set_default_system_prompt(default_system_prompt)
if default_openai_client is not None:
config.set_default_client(default_openai_client)
@@ -235,10 +223,6 @@ def set_store(*args, **kwargs) -> None:
def set_default_lm_params(*args, **kwargs) -> None:
return config.set_default_lm_params(*args, **kwargs)
@wraps(config.set_default_system_prompt)
def set_default_system_prompt(*args, **kwargs) -> None:
return config.set_default_system_prompt(*args, **kwargs)
# You can add more helper functions here if needed
@wraps(config.register_provider)
def register_provider(*args, **kwargs) -> None:

View File

@@ -266,9 +266,10 @@ def _get_messages(prompt_ret: Union[str, list[MessageOrDict]], prompt: LMP) -> l
Helper function to convert the output of an LMP into a list of Messages.
"""
if isinstance(prompt_ret, str):
return [
Message(role="system", content=[ContentBlock(text=_lstr(prompt.__doc__) or config.default_system_prompt)]),
Message(role="user", content=[ContentBlock(text=prompt_ret)]),
has_system_prompt = prompt.__doc__ is not None and prompt.__doc__.strip() != ""
messages = [Message(role="system", content=[ContentBlock(text=_lstr(prompt.__doc__) )])] if has_system_prompt else []
return messages + [
Message(role="user", content=[ContentBlock(text=prompt_ret)])
]
else:
assert isinstance(

View File

@@ -73,7 +73,9 @@ def register(client: openai.Client):
('gpt-3.5-turbo-instruct', 'system'),
('gpt-4-0613', 'openai'),
('gpt-4', 'openai'),
('gpt-4-0314', 'openai')
('gpt-4-0314', 'openai'),
('o1-preview', 'system'),
('o1-mini', 'system'),
]
for model_id, owned_by in model_data:
config.register_model(model_id, client)

View File

@@ -87,7 +87,16 @@ try:
final_call_params["model"] = model
final_call_params["messages"] = openai_messages
if final_call_params.get("response_format"):
if model == "o1-preview" or model == "o1-mini":
# Ensure no system messages are present
assert all(msg['role'] != 'system' for msg in final_call_params['messages']), "System messages are not allowed for o1-preview or o1-mini models"
response = client.chat.completions.create(**final_call_params)
final_call_params.pop("stream", None)
final_call_params.pop("stream_options", None)
elif final_call_params.get("response_format"):
final_call_params.pop("stream", None)
final_call_params.pop("stream_options", None)
response = client.beta.chat.completions.parse(**final_call_params)

View File

@@ -10,10 +10,6 @@
# @lm(model="gpt-4-turbo", temperature=0.1, max_tokens=5)
# def lmp_with_default_system_prompt(*args, **kwargs):
# return "Test user prompt"
# @lm(model="gpt-4-turbo", temperature=0.1, max_tokens=5)
# def lmp_with_docstring_system_prompt(*args, **kwargs):