Mirror of https://github.com/abetlen/llama-cpp-python.git, synced 2023-09-07 17:34:22 +03:00
Add support for chat completion
@@ -517,6 +517,99 @@ class Llama:
             stream=stream,
         )
 
+    def _convert_text_completion_to_chat(
+        self, completion: Completion
+    ) -> ChatCompletion:
+        return {
+            "id": "chat" + completion["id"],
+            "object": "chat.completion",
+            "created": completion["created"],
+            "model": completion["model"],
+            "choices": [
+                {
+                    "index": 0,
+                    "message": {
+                        "role": "assistant",
+                        "content": completion["choices"][0]["text"],
+                    },
+                    "finish_reason": completion["choices"][0]["finish_reason"],
+                }
+            ],
+            "usage": completion["usage"],
+        }
+
+    def _convert_text_completion_chunks_to_chat(
+        self,
+        chunks: Iterator[CompletionChunk],
+    ) -> Iterator[ChatCompletionChunk]:
+        for i, chunk in enumerate(chunks):
+            if i == 0:
+                yield {
+                    "id": "chat" + chunk["id"],
+                    "model": chunk["model"],
+                    "created": chunk["created"],
+                    "object": "chat.completion.chunk",
+                    "choices": [
+                        {
+                            "index": 0,
+                            "delta": {
+                                "role": "assistant",
+                            },
+                            "finish_reason": None,
+                        }
+                    ],
+                }
+            yield {
+                "id": "chat" + chunk["id"],
+                "model": chunk["model"],
+                "created": chunk["created"],
+                "object": "chat.completion.chunk",
+                "choices": [
+                    {
+                        "index": 0,
+                        "delta": {
+                            "content": chunk["choices"][0]["text"],
+                        },
+                        "finish_reason": chunk["choices"][0]["finish_reason"],
+                    }
+                ],
+            }
+
+    def create_chat_completion(
+        self,
+        messages: List[ChatCompletionMessage],
+        temperature: float = 0.8,
+        top_p: float = 0.95,
+        top_k: int = 40,
+        stream: bool = False,
+        stop: List[str] = [],
+        max_tokens: int = 128,
+        repeat_penalty: float = 1.1,
+    ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
+        instructions = """Complete the following chat conversation between the user and the assistant. System messages should be strictly followed as additional instructions."""
+        chat_history = "\n".join(
+            f'{message["role"]} {message.get("user", "")}: {message["content"]}'
+            for message in messages
+        )
+        PROMPT = f" \n\n### Instructions:{instructions}\n\n### Inputs:{chat_history}\n\n### Response:\nassistant: "
+        PROMPT_STOP = ["###", "\nuser: ", "\nassistant: ", "\nsystem: "]
+        completion_or_chunks = self(
+            prompt=PROMPT,
+            stop=PROMPT_STOP + stop,
+            temperature=temperature,
+            top_p=top_p,
+            top_k=top_k,
+            stream=stream,
+            max_tokens=max_tokens,
+            repeat_penalty=repeat_penalty,
+        )
+        if stream:
+            chunks: Iterator[CompletionChunk] = completion_or_chunks  # type: ignore
+            return self._convert_text_completion_chunks_to_chat(chunks)
+        else:
+            completion: Completion = completion_or_chunks  # type: ignore
+            return self._convert_text_completion_to_chat(completion)
+
     def __del__(self):
         if self.ctx is not None:
             llama_cpp.llama_free(self.ctx)
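For reference, a minimal usage sketch of the new create_chat_completion API (not part of the commit; the model path, messages, and sampling settings below are placeholder assumptions):

from llama_cpp import Llama

# Placeholder path; any model file llama.cpp can load works here.
llm = Llama(model_path="./models/ggml-model.bin")

# Non-streaming call: returns a ChatCompletion dict shaped like the OpenAI
# chat API response, with the reply under choices[0]["message"]["content"].
result = llm.create_chat_completion(
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Name three uses of llama.cpp."},
    ],
    max_tokens=64,
)
print(result["choices"][0]["message"]["content"])

# Streaming call: yields ChatCompletionChunk dicts. The first chunk carries
# only the assistant role; later chunks carry incremental "content" deltas.
for chunk in llm.create_chat_completion(
    messages=[{"role": "user", "content": "Say hello."}],
    stream=True,
):
    delta = chunk["choices"][0]["delta"]
    if "content" in delta:
        print(delta["content"], end="", flush=True)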
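And a short sketch of how create_chat_completion flattens a conversation into a plain-text prompt, re-running the same formatting expressions from the diff on a made-up two-message history:

messages = [
    {"role": "system", "content": "Answer briefly."},
    {"role": "user", "content": "What is llama.cpp?"},
]

# Same join as in create_chat_completion; "user" is an optional per-message
# field, so it is empty here and leaves a space before the colon.
chat_history = "\n".join(
    f'{m["role"]} {m.get("user", "")}: {m["content"]}' for m in messages
)
# chat_history is now the two lines:
#   system : Answer briefly.
#   user : What is llama.cpp?

instructions = (
    "Complete the following chat conversation between the user and the "
    "assistant. System messages should be strictly followed as additional "
    "instructions."
)
prompt = f" \n\n### Instructions:{instructions}\n\n### Inputs:{chat_history}\n\n### Response:\nassistant: "

# The method then generates from this prompt and stops at "###", "\nuser: ",
# "\nassistant: " or "\nsystem: ", so the model only answers as the assistant.
print(prompt)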