mirror of
https://github.com/MadcowD/ell.git
synced 2024-09-22 16:14:36 +03:00
moving tess
This commit is contained in:
40
tests/test_dict_sync_meta.py
Normal file
40
tests/test_dict_sync_meta.py
Normal file
@@ -0,0 +1,40 @@
|
||||
import pytest
|
||||
from ell.types import Message
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def message():
|
||||
return Message(role="user", content="Initial content")
|
||||
|
||||
|
||||
def test_initialization():
|
||||
"""Test that the dictionary is correctly initialized with dataclass fields."""
|
||||
msg = Message(role="admin", content="Hello, world!")
|
||||
assert msg["role"] == "admin"
|
||||
assert msg["content"] == "Hello, world!"
|
||||
assert msg.role == "admin"
|
||||
assert msg.content == "Hello, world!"
|
||||
|
||||
|
||||
def test_attribute_modification(message):
|
||||
"""Test that modifications to attributes update the dictionary."""
|
||||
# Modify the attributes
|
||||
message.role = "moderator"
|
||||
message.content = "Updated content"
|
||||
# Check dictionary synchronization
|
||||
assert message["role"] == "moderator"
|
||||
assert message["content"] == "Updated content"
|
||||
assert message.role == "moderator"
|
||||
assert message.content == "Updated content"
|
||||
|
||||
|
||||
def test_dictionary_modification(message):
|
||||
"""Test that direct dictionary modifications do not break attribute access."""
|
||||
# Directly modify the dictionary
|
||||
message["role"] = "admin"
|
||||
message["content"] = "New content"
|
||||
# Check if the attributes are not affected (they should not be, as this is one-way sync)
|
||||
assert message.role == "user"
|
||||
assert message.content == "Initial content"
|
||||
assert message["role"] == "admin"
|
||||
assert message["content"] == "New content"
|
||||
97
tests/test_lmp_to_prompt.py
Normal file
97
tests/test_lmp_to_prompt.py
Normal file
@@ -0,0 +1,97 @@
|
||||
"""
|
||||
Pytest for the LM function (mocks the openai api so we can pretend to generate completions through te typoical approach taken in the decorators (and adapters file.))
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
from ell.decorators import DEFAULT_SYSTEM_PROMPT, lm
|
||||
from ell.types import Message, LMPParams
|
||||
|
||||
|
||||
@lm(model="gpt-4-turbo", provider=None, temperature=0.1, max_tokens=5)
|
||||
def lmp_with_default_system_prompt(*args, **kwargs):
|
||||
return "Test user prompt"
|
||||
|
||||
|
||||
@lm(model="gpt-4-turbo", provider=None, temperature=0.1, max_tokens=5)
|
||||
def lmp_with_docstring_system_prompt(*args, **kwargs):
|
||||
"""Test system prompt""" # I personally prefer this sysntax but it's nto formattable so I'm not sure if it's the best approach. I think we can leave this in as a legacy feature but the default docs should be using the ell.system, ell.user, ...
|
||||
|
||||
return "Test user prompt"
|
||||
|
||||
|
||||
@lm(model="gpt-4-turbo", provider=None, temperature=0.1, max_tokens=5)
|
||||
def lmp_with_message_fmt(*args, **kwargs):
|
||||
"""Just a normal doc stirng"""
|
||||
|
||||
return [
|
||||
Message(role="system", content="Test system prompt from message fmt"),
|
||||
Message(role="user", content="Test user prompt 3"),
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client_mock():
|
||||
with patch("ell.adapter.client.chat.completions.create") as mock:
|
||||
yield mock
|
||||
|
||||
|
||||
def test_lm_decorator_with_params(client_mock):
|
||||
client_mock.return_value = MagicMock(
|
||||
choices=[MagicMock(message=MagicMock(content="Mocked content"))]
|
||||
)
|
||||
result = lmp_with_default_system_prompt("input", lm_params=dict(temperature=0.5))
|
||||
# It should have been called twice
|
||||
print("client_mock was called with:", client_mock.call_args)
|
||||
client_mock.assert_called_with(
|
||||
model="gpt-4-turbo",
|
||||
messages=[
|
||||
Message(role="system", content=DEFAULT_SYSTEM_PROMPT),
|
||||
Message(role="user", content="Test user prompt"),
|
||||
],
|
||||
temperature=0.5,
|
||||
max_tokens=5,
|
||||
)
|
||||
assert isinstance(result, str)
|
||||
assert result == "Mocked content"
|
||||
|
||||
|
||||
def test_lm_decorator_with_docstring_system_prompt(client_mock):
|
||||
client_mock.return_value = MagicMock(
|
||||
choices=[MagicMock(message=MagicMock(content="Mocked content"))]
|
||||
)
|
||||
result = lmp_with_docstring_system_prompt("input", lm_params=dict(temperature=0.5))
|
||||
print("client_mock was called with:", client_mock.call_args)
|
||||
client_mock.assert_called_with(
|
||||
model="gpt-4-turbo",
|
||||
messages=[
|
||||
Message(role="system", content="Test system prompt"),
|
||||
Message(role="user", content="Test user prompt"),
|
||||
],
|
||||
temperature=0.5,
|
||||
max_tokens=5,
|
||||
)
|
||||
assert isinstance(result, str)
|
||||
assert result == "Mocked content"
|
||||
|
||||
def test_lm_decorator_with_msg_fmt_system_prompt(client_mock):
|
||||
client_mock.return_value = MagicMock(
|
||||
choices=[
|
||||
MagicMock(message=MagicMock(content="Mocked content from msg fmt"))
|
||||
]
|
||||
)
|
||||
result = lmp_with_default_system_prompt(
|
||||
"input", lm_params=dict(temperature=0.5), message_format="msg fmt"
|
||||
)
|
||||
print("client_mock was called with:", client_mock.call_args)
|
||||
client_mock.assert_called_with(
|
||||
model="gpt-4-turbo",
|
||||
messages=[
|
||||
Message(role="system", content="Test system prompt from message fmt"),
|
||||
Message(role="user", content="Test user prompt 3"), # come on cursor.
|
||||
],
|
||||
temperature=0.5,
|
||||
max_tokens=5,
|
||||
)
|
||||
assert isinstance(result, str)
|
||||
assert result == "Mocked content from msg fmt"
|
||||
106
tests/test_lstr.py
Normal file
106
tests/test_lstr.py
Normal file
@@ -0,0 +1,106 @@
|
||||
import numpy as np
|
||||
import pytest
|
||||
from ell.lstr import lstr
|
||||
|
||||
|
||||
class TestLstr:
|
||||
def test_init(self):
|
||||
# Test initialization with string content only
|
||||
s = lstr("hello")
|
||||
assert str(s) == "hello"
|
||||
assert s.logits is None
|
||||
assert s._origin_trace == frozenset()
|
||||
|
||||
# Test initialization with logits and _origin_trace
|
||||
logits = np.array([0.1, 0.2])
|
||||
_origin_trace = "model1"
|
||||
s = lstr("world", logits=logits, _origin_trace=_origin_trace)
|
||||
assert str(s) == "world"
|
||||
assert np.array_equal(s.logits, logits)
|
||||
assert s._origin_trace == frozenset({_origin_trace})
|
||||
|
||||
def test_add(self):
|
||||
s1 = lstr("hello")
|
||||
s2 = lstr("world", _origin_trace="model2")
|
||||
assert isinstance(s1 + s2, str)
|
||||
result = s1 + s2
|
||||
assert str(result) == "helloworld"
|
||||
assert result.logits is None
|
||||
assert result._origin_trace == frozenset({"model2"})
|
||||
|
||||
def test_mod(self):
|
||||
s = lstr("hello %s")
|
||||
result = s % "world"
|
||||
assert str(result) == "hello world"
|
||||
assert result.logits is None
|
||||
assert result._origin_trace == frozenset()
|
||||
|
||||
def test_mul(self):
|
||||
s = lstr("ha", _origin_trace="model3")
|
||||
result = s * 3
|
||||
assert str(result) == "hahaha"
|
||||
assert result.logits is None
|
||||
assert result._origin_trace == frozenset({"model3"})
|
||||
|
||||
def test_getitem(self):
|
||||
s = lstr(
|
||||
"hello", logits=np.array([0.1, 0.2, 0.3, 0.4, 0.5]), _origin_trace="model4"
|
||||
)
|
||||
result = s[1:4]
|
||||
assert str(result) == "ell"
|
||||
assert result.logits is None
|
||||
assert result._origin_trace == frozenset({"model4"})
|
||||
|
||||
def test_upper(self):
|
||||
# Test upper method without _origin_trace and logits
|
||||
s = lstr("hello")
|
||||
result = s.upper()
|
||||
assert str(result) == "HELLO"
|
||||
assert result.logits is None
|
||||
assert result._origin_trace == frozenset()
|
||||
|
||||
# Test upper method with _origin_trace
|
||||
s = lstr("world", _origin_trace="model11")
|
||||
result = s.upper()
|
||||
assert str(result) == "WORLD"
|
||||
assert result.logits is None
|
||||
assert result._origin_trace == frozenset({"model11"})
|
||||
|
||||
def test_join(self):
|
||||
s = lstr(", ", _origin_trace="model5")
|
||||
parts = [lstr("hello"), lstr("world", _origin_trace="model6")]
|
||||
result = s.join(parts)
|
||||
assert str(result) == "hello, world"
|
||||
assert result.logits is None
|
||||
assert result._origin_trace == frozenset({"model5", "model6"})
|
||||
|
||||
def test_split(self):
|
||||
s = lstr("hello world", _origin_trace="model7")
|
||||
parts = s.split()
|
||||
assert [str(p) for p in parts] == ["hello", "world"]
|
||||
assert all(p.logits is None for p in parts)
|
||||
assert all(p._origin_trace == frozenset({"model7"}) for p in parts)
|
||||
|
||||
def test_partition(self):
|
||||
s = lstr("hello, world", _origin_trace="model8")
|
||||
part1, sep, part2 = s.partition(", ")
|
||||
assert (str(part1), str(sep), str(part2)) == ("hello", ", ", "world")
|
||||
assert all(p.logits is None for p in (part1, sep, part2))
|
||||
assert all(p._origin_trace == frozenset({"model8"}) for p in (part1, sep, part2))
|
||||
|
||||
def test_formatting(self):
|
||||
s = lstr("Hello {}!")
|
||||
filled = s.format(lstr("world", _origin_trace="model9"))
|
||||
assert str(filled) == "Hello world!"
|
||||
assert filled.logits is None
|
||||
assert filled._origin_trace == frozenset({"model9"})
|
||||
|
||||
def test_repr(self):
|
||||
s = lstr("test", logits=np.array([1.0]), _origin_trace="model10")
|
||||
assert "test" in repr(s)
|
||||
assert "model10" in repr(s._origin_trace)
|
||||
|
||||
|
||||
# Run the tests
|
||||
if __name__ == "__main__":
|
||||
pytest.main()
|
||||
Reference in New Issue
Block a user