diff --git a/poetry.lock b/poetry.lock index 8102836..8a74d2f 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,5 +1,24 @@ # This file is automatically @generated by Poetry 1.4.1 and should not be changed by hand. +[[package]] +name = "attrs" +version = "22.2.0" +description = "Classes Without Boilerplate" +category = "dev" +optional = false +python-versions = ">=3.6" +files = [ + {file = "attrs-22.2.0-py3-none-any.whl", hash = "sha256:29e95c7f6778868dbd49170f98f8818f78f3dc5e0e37c0b1f474e3561b240836"}, + {file = "attrs-22.2.0.tar.gz", hash = "sha256:c9227bfc2f01993c03f68db37d1d15c9690188323c067c641f1a35ca58185f99"}, +] + +[package.extras] +cov = ["attrs[tests]", "coverage-enable-subprocess", "coverage[toml] (>=5.3)"] +dev = ["attrs[docs,tests]"] +docs = ["furo", "myst-parser", "sphinx", "sphinx-notfound-page", "sphinxcontrib-towncrier", "towncrier", "zope.interface"] +tests = ["attrs[tests-no-zope]", "zope.interface"] +tests-no-zope = ["cloudpickle", "cloudpickle", "hypothesis", "hypothesis", "mypy (>=0.971,<0.990)", "mypy (>=0.971,<0.990)", "pympler", "pympler", "pytest (>=4.3.0)", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-mypy-plugins", "pytest-xdist[psutil]", "pytest-xdist[psutil]"] + [[package]] name = "black" version = "23.1.0" @@ -328,6 +347,21 @@ files = [ {file = "docutils-0.19.tar.gz", hash = "sha256:33995a6753c30b7f577febfc2c50411fec6aac7f7ffeb7c4cfe5991072dcf9e6"}, ] +[[package]] +name = "exceptiongroup" +version = "1.1.1" +description = "Backport of PEP 654 (exception groups)" +category = "dev" +optional = false +python-versions = ">=3.7" +files = [ + {file = "exceptiongroup-1.1.1-py3-none-any.whl", hash = "sha256:232c37c63e4f682982c8b6459f33a8981039e5fb8756b2074364e5055c498c9e"}, + {file = "exceptiongroup-1.1.1.tar.gz", hash = "sha256:d484c3090ba2889ae2928419117447a14daf3c1231d5e30d0aae34f354f01785"}, +] + +[package.extras] +test = ["pytest (>=6)"] + [[package]] name = "ghp-import" version = "2.1.0" @@ -415,6 +449,18 @@ zipp = {version = ">=3.1.0", markers = "python_version < \"3.10\""} docs = ["furo", "jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] testing = ["flake8 (<5)", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.3)", "pytest-flake8", "pytest-mypy (>=0.9.1)"] +[[package]] +name = "iniconfig" +version = "2.0.0" +description = "brain-dead simple config-ini parsing" +category = "dev" +optional = false +python-versions = ">=3.7" +files = [ + {file = "iniconfig-2.0.0-py3-none-any.whl", hash = "sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374"}, + {file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"}, +] + [[package]] name = "jaraco-classes" version = "3.2.3" @@ -821,6 +867,22 @@ files = [ docs = ["furo (>=2022.12.7)", "proselint (>=0.13)", "sphinx (>=6.1.3)", "sphinx-autodoc-typehints (>=1.22,!=1.23.4)"] test = ["appdirs (==1.4.4)", "covdefaults (>=2.2.2)", "pytest (>=7.2.1)", "pytest-cov (>=4)", "pytest-mock (>=3.10)"] +[[package]] +name = "pluggy" +version = "1.0.0" +description = "plugin and hook calling mechanisms for python" +category = "dev" +optional = false +python-versions = ">=3.6" +files = [ + {file = "pluggy-1.0.0-py2.py3-none-any.whl", hash = "sha256:74134bbf457f031a36d68416e1509f34bd5ccc019f0bcc952c7b909d06b37bd3"}, + {file = "pluggy-1.0.0.tar.gz", hash = "sha256:4224373bacce55f955a878bf9cfa763c1e360858e330072059e10bad68531159"}, +] + +[package.extras] +dev = ["pre-commit", "tox"] +testing = ["pytest", "pytest-benchmark"] + [[package]] name = "pycparser" version = "2.21" @@ -864,6 +926,30 @@ files = [ markdown = ">=3.2" pyyaml = "*" +[[package]] +name = "pytest" +version = "7.2.2" +description = "pytest: simple powerful testing with Python" +category = "dev" +optional = false +python-versions = ">=3.7" +files = [ + {file = "pytest-7.2.2-py3-none-any.whl", hash = "sha256:130328f552dcfac0b1cec75c12e3f005619dc5f874f0a06e8ff7263f0ee6225e"}, + {file = "pytest-7.2.2.tar.gz", hash = "sha256:c99ab0c73aceb050f68929bc93af19ab6db0558791c6a0715723abe9d0ade9d4"}, +] + +[package.dependencies] +attrs = ">=19.2.0" +colorama = {version = "*", markers = "sys_platform == \"win32\""} +exceptiongroup = {version = ">=1.0.0rc8", markers = "python_version < \"3.11\""} +iniconfig = "*" +packaging = "*" +pluggy = ">=0.12,<2.0" +tomli = {version = ">=1.0.0", markers = "python_version < \"3.11\""} + +[package.extras] +testing = ["argcomplete", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "xmlschema"] + [[package]] name = "python-dateutil" version = "2.8.2" @@ -1281,4 +1367,4 @@ testing = ["big-O", "flake8 (<5)", "jaraco.functools", "jaraco.itertools", "more [metadata] lock-version = "2.0" python-versions = "^3.8.1" -content-hash = "cffaf5e2e66ade4f429d0e938277d4fa2c4878ca7338c3c4f91721a7d3aff91b" +content-hash = "cc9babcdfdc3679a4d84f68912408a005619a576947b059146ed1b428850ece9" diff --git a/pyproject.toml b/pyproject.toml index c640615..ce54ea9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,6 +23,7 @@ twine = "^4.0.2" mkdocs = "^1.4.2" mkdocstrings = {extras = ["python"], version = "^0.20.0"} mkdocs-material = "^9.1.4" +pytest = "^7.2.2" [build-system] requires = [ diff --git a/tests/test_llama.py b/tests/test_llama.py new file mode 100644 index 0000000..6843ec6 --- /dev/null +++ b/tests/test_llama.py @@ -0,0 +1,79 @@ +import llama_cpp + +MODEL = "./vendor/llama.cpp/models/ggml-vocab.bin" + + +def test_llama(): + llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True) + + assert llama + assert llama.ctx is not None + + text = b"Hello World" + + assert llama.detokenize(llama.tokenize(text)) == text + + +def test_llama_patch(monkeypatch): + llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True) + + ## Set up mock function + def mock_eval(*args, **kwargs): + return 0 + + monkeypatch.setattr("llama_cpp.llama_cpp.llama_eval", mock_eval) + + output_text = " jumps over the lazy dog." + output_tokens = llama.tokenize(output_text.encode("utf-8")) + token_eos = llama.token_eos() + n = 0 + + def mock_sample(*args, **kwargs): + nonlocal n + if n < len(output_tokens): + n += 1 + return output_tokens[n - 1] + else: + return token_eos + + monkeypatch.setattr("llama_cpp.llama_cpp.llama_sample_top_p_top_k", mock_sample) + + text = "The quick brown fox" + + ## Test basic completion until eos + n = 0 # reset + completion = llama.create_completion(text, max_tokens=20) + assert completion["choices"][0]["text"] == output_text + assert completion["choices"][0]["finish_reason"] == "stop" + + ## Test streaming completion until eos + n = 0 # reset + chunks = llama.create_completion(text, max_tokens=20, stream=True) + assert "".join(chunk["choices"][0]["text"] for chunk in chunks) == output_text + assert completion["choices"][0]["finish_reason"] == "stop" + + ## Test basic completion until stop sequence + n = 0 # reset + completion = llama.create_completion(text, max_tokens=20, stop=["lazy"]) + assert completion["choices"][0]["text"] == " jumps over the " + assert completion["choices"][0]["finish_reason"] == "stop" + + ## Test streaming completion until stop sequence + n = 0 # reset + chunks = llama.create_completion(text, max_tokens=20, stream=True, stop=["lazy"]) + assert ( + "".join(chunk["choices"][0]["text"] for chunk in chunks) == " jumps over the " + ) + assert completion["choices"][0]["finish_reason"] == "stop" + + ## Test basic completion until length + n = 0 # reset + completion = llama.create_completion(text, max_tokens=2) + assert completion["choices"][0]["text"] == " j" + assert completion["choices"][0]["finish_reason"] == "length" + + ## Test streaming completion until length + n = 0 # reset + chunks = llama.create_completion(text, max_tokens=2, stream=True) + assert "".join(chunk["choices"][0]["text"] for chunk in chunks) == " j" + assert completion["choices"][0]["finish_reason"] == "length"